1use std::{
2 cmp,
3 fmt,
4 io,
5 pin::Pin,
6 task::{
7 Context,
8 Poll,
9 Waker,
10 },
11};
12
13use bytes::BytesMut;
14use futures::{
15 channel::mpsc,
16 ready,
17 sink::Sink,
18 stream::Stream as StreamT,
19};
20use pin_project::pin_project;
21use tokio::io::{
22 AsyncRead,
23 AsyncWrite,
24 ReadBuf,
25};
26use tracing::instrument;
27
28use crate::{
29 errors::Error,
30 frame::{
31 Body,
32 Frame,
33 HeaderType,
34 Length,
35 WndInc,
36 },
37 stream_output::StreamSender,
38 window::Window,
39};
40
/// One bidirectional stream exposed through tokio's `AsyncRead`/`AsyncWrite`.
///
/// Inbound frames arrive on `fin`; outbound frames leave through `fout`.
/// NOTE(review): presumably both ends are wired to a session that multiplexes
/// many streams over one connection — confirm against the session code.
#[pin_project(project = StreamProj, PinnedDrop)]
pub struct Stream {
    /// Drop-tracking handle. `Stream::new` leaves it `None`; crate code may set
    /// it afterwards. NOTE(review): presumably lets another party observe this
    /// stream being dropped (`awaitdrop` crate) — confirm.
    pub(crate) dropref: Option<awaitdrop::Ref>,

    /// Send window: how many more bytes may be written before the peer grants
    /// more via a window-increment frame.
    window: Window,

    /// Inbound data that has been received but not yet handed to a reader.
    read_buf: BytesMut,

    /// Inbound frames addressed to this stream.
    #[pin]
    fin: mpsc::Receiver<Frame>,
    /// Sink for every outbound frame of this stream.
    #[pin]
    fout: StreamSender,

    /// Waker of a task parked waiting for inbound data, if any.
    read_waker: Option<Waker>,
    /// Waker of a task parked waiting for send window, if any.
    write_waker: Option<Waker>,

    /// Why the write half is closed, if it is: the sink error that closed it,
    /// or `Error::StreamClosed` after a local shutdown.
    write_closed: Option<Error>,

    /// Set once no more inbound data will arrive (a FIN was received or the
    /// inbound channel ended).
    data_read_closed: bool,

    /// Whether the next data frame sent must carry the SYN flag; cleared once
    /// that frame has been queued.
    needs_syn: bool,
}
70
71impl fmt::Debug for Stream {
72 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73 f.debug_struct("Stream")
74 .field("window", &self.window)
75 .field("read_buf", &self.read_buf)
76 .field("read_waker", &self.read_waker)
77 .field("write_waker", &self.write_waker)
78 .field("reset", &self.write_closed)
79 .field("read_closed", &self.data_read_closed)
80 .finish()
81 }
82}
83
84impl Stream {
85 pub(crate) fn new(
86 fout: StreamSender,
87 fin: mpsc::Receiver<Frame>,
88 window_size: usize,
89 needs_syn: bool,
90 ) -> Self {
91 Self {
92 dropref: None,
93 window: Window::new(window_size),
94 fin,
95 fout,
96 read_buf: Default::default(),
97 read_waker: Default::default(),
98 write_waker: Default::default(),
99 write_closed: Default::default(),
100 data_read_closed: false,
101 needs_syn,
102 }
103 }
104
105 #[instrument(level = "trace", skip_all)]
106 fn poll_recv_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Frame>> {
107 let mut this = self.project();
108 let fin = this.fin.as_mut();
109 fin.poll_next(cx)
110 }
111
112 #[instrument(level = "trace", skip_all)]
117 fn poll_recv_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
118 self.poll_recv_frame_type(cx, HeaderType::Data)
119 }
120
121 #[instrument(level = "trace", skip_all)]
122 fn poll_recv_wndinc(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
123 self.poll_recv_frame_type(cx, HeaderType::WndInc)
124 }
125
126 #[instrument(level = "trace", skip(self, cx))]
127 fn poll_recv_frame_type(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 target_typ: HeaderType,
131 ) -> Poll<io::Result<()>> {
132 loop {
133 let frame = if let Some(frame) = ready!(self.as_mut().poll_recv_frame(cx)) {
134 frame
135 } else {
136 self.data_read_closed = true;
137 return Ok(()).into();
138 };
139
140 let typ = self.handle_frame(frame, Some(cx));
141
142 if typ == target_typ {
143 return Poll::Ready(Ok(()));
144 }
145 }
146 }
147
148 #[instrument(level = "trace", skip(self, cx))]
149 fn poll_send_wndinc(self: Pin<&mut Self>, cx: &mut Context<'_>, by: WndInc) -> Poll<()> {
150 let mut this = self.project();
151 if ready!(this.fout.as_mut().poll_ready(cx)).is_err() {
153 return Poll::Ready(());
154 }
155
156 let _ = this.fout.as_mut().start_send(Body::WndInc(by).into());
158
159 Poll::Ready(())
160 }
161
162 #[instrument(level = "trace", skip(self, cx))]
163 fn handle_frame(&mut self, frame: Frame, cx: Option<&Context<'_>>) -> HeaderType {
164 if frame.is_fin() {
165 self.data_read_closed = true;
166 }
167 match frame.body {
168 Body::Data(bs) => {
169 self.read_buf.extend_from_slice(&bs);
170 self.maybe_wake_read(cx);
171 }
172 Body::WndInc(by) => {
173 self.window.inc(*by as usize);
174 self.maybe_wake_write(cx);
175 }
176 _ => unreachable!("stream should never receive GoAway, Rst or Invalid frames"),
177 }
178 frame.header.typ
179 }
180
181 #[instrument(level = "trace", skip_all)]
182 fn maybe_wake_read(&mut self, cx: Option<&Context>) {
183 maybe_wake(cx, self.read_waker.take())
184 }
185 #[instrument(level = "trace", skip_all)]
186 fn maybe_wake_write(&mut self, cx: Option<&Context>) {
187 maybe_wake(cx, self.write_waker.take())
188 }
189}
190
191impl<'a> StreamProj<'a> {
192 fn closed_err(&mut self, code: Error) -> io::Error {
193 *self.write_closed = Some(code);
194 io::Error::new(io::ErrorKind::ConnectionReset, code)
195 }
196}
197
/// Wakes `other` unless it would wake the same task as `me`'s waker.
///
/// * `me` - context of the task currently polling, if one is available. With
///   `None`, `other` is always woken.
/// * `other` - the parked waker to notify; nothing happens when it is `None`.
///
/// A waker that matches `me` (per `Waker::will_wake`) is simply dropped.
fn maybe_wake(me: Option<&Context>, other: Option<Waker>) {
    if let Some(target) = other {
        let is_current_task = me.map_or(false, |cx| target.will_wake(cx.waker()));
        if !is_current_task {
            target.wake();
        }
    }
}
205
206impl AsyncRead for Stream {
207 #[instrument(level = "trace", skip_all)]
208 fn poll_read(
209 mut self: Pin<&mut Self>,
210 cx: &mut Context<'_>,
211 buf: &mut ReadBuf<'_>,
212 ) -> Poll<io::Result<()>> {
213 loop {
214 if !self.read_buf.is_empty() {
216 let max = cmp::min(self.read_buf.len(), buf.remaining());
217 let clamped = WndInc::clamp(max as u32);
218 let n = *clamped as usize;
219
220 if n > 0 {
221 ready!(self.as_mut().poll_send_wndinc(cx, clamped));
224
225 buf.put_slice(self.read_buf.split_to(n).as_ref());
226 }
227
228 return Poll::Ready(Ok(()));
229 }
230
231 if self.data_read_closed {
233 return Poll::Ready(Ok(()));
234 }
235
236 self.read_waker = Some(cx.waker().clone());
239
240 ready!(self.as_mut().poll_recv_data(cx))?;
242 }
243 }
244}
245
246impl AsyncWrite for Stream {
247 #[instrument(level = "trace", skip(self, cx))]
248 fn poll_write(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 buf: &[u8],
252 ) -> Poll<Result<usize, io::Error>> {
253 if let Some(code) = self.write_closed {
254 return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, code)));
255 }
256
257 if self.window.capacity() == 0 {
258 self.write_waker = Some(cx.waker().clone());
259
260 ready!(self.as_mut().poll_recv_wndinc(cx))?;
261 }
262
263 let mut this = self.project();
264
265 ready!(this.fout.as_mut().poll_ready(cx)).map_err(|e| this.closed_err(e))?;
266
267 let wincap = this.window.capacity();
268
269 let max_len = Length::clamp(buf.len() as u32);
270
271 let send_len = cmp::min(wincap, *max_len as usize);
272
273 let bs = BytesMut::from(&buf[..send_len]);
274
275 let mut frame: Frame = Body::Data(bs.freeze()).into();
276 if *this.needs_syn {
277 *this.needs_syn = false;
278 frame = frame.syn();
279 }
280
281 this.fout
282 .as_mut()
283 .start_send(frame)
284 .map_err(|e| this.closed_err(e))?;
285
286 let _dec = this.window.dec(send_len);
287 debug_assert!(_dec == send_len);
288
289 Poll::Ready(Ok(send_len))
290 }
291
292 #[instrument(level = "trace", skip_all)]
293 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
294 let mut this = self.project();
295 this.fout
296 .as_mut()
297 .poll_flush(cx)
298 .map_err(|e| this.closed_err(e))
299 }
300
301 #[instrument(level = "trace", skip_all)]
302 fn poll_shutdown(
303 mut self: Pin<&mut Self>,
304 cx: &mut Context<'_>,
305 ) -> Poll<Result<(), io::Error>> {
306 let mut this = self.as_mut().project();
310 if this.write_closed.is_none() {
311 ready!(this.fout.as_mut().poll_ready(cx))
312 .and_then(|_| {
313 this.fout
314 .as_mut()
315 .start_send(Frame::from(Body::Data([][..].into())).fin())
316 })
317 .map_err(|e| this.closed_err(e))?;
318 *this.write_closed = Some(Error::StreamClosed);
319 }
320
321 this.fout
322 .as_mut()
323 .poll_flush(cx)
324 .map_err(|e| this.closed_err(e))
325 }
326}
327
328#[pin_project::pinned_drop]
329impl PinnedDrop for Stream {
330 #[instrument(level = "trace", skip_all)]
331 fn drop(self: Pin<&mut Self>) {}
332}
333
#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use tokio::{
        io::{
            AsyncReadExt,
            AsyncWriteExt,
        },
        time,
    };
    use tracing_test::traced_test;

    use super::*;

    /// Upper bound for each individual stream operation in these tests.
    const TIMEOUT: Duration = Duration::from_secs(1);

    /// End-to-end exercise of one stream.
    ///
    /// Covers: a window-limited first write carrying SYN; reading peer data,
    /// which returns a window increment; a follow-up write using the granted
    /// window; and shutdown, which emits FIN.
    #[traced_test]
    #[tokio::test]
    async fn test_stream() {
        const MSG: &str = "Hello, world!";
        const MSG2: &str = "Hello to you too!";

        // `remote_tx` plays the peer feeding frames in; `remote_rx` observes
        // everything the stream sends out.
        let (mut remote_tx, stream_rx) = mpsc::channel(512);
        let (stream_tx, mut remote_rx) = mpsc::channel(512);
        let mut stream = Stream::new(StreamSender::wrap(stream_tx), stream_rx, 5, true);

        // The initial window of 5 caps the first write, which opens with SYN.
        let written = time::timeout(TIMEOUT, stream.write(MSG2.as_bytes()))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(written, 5);
        let sent = remote_rx.try_next().unwrap().unwrap();
        assert_eq!(sent, Frame::from(Body::Data(MSG2[0..5].into())).syn());

        // The peer grants 5 more bytes, sends a message, then hangs up.
        remote_tx.try_send(Body::WndInc(WndInc::clamp(5)).into()).unwrap();
        remote_tx
            .try_send(Body::Data(MSG.as_bytes().into()).into())
            .unwrap();
        drop(remote_tx);

        let mut received = String::new();
        time::timeout(TIMEOUT, stream.read_to_string(&mut received))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, "Hello, world!");
        // Consuming the data hands its size back to the peer.
        let sent = remote_rx.try_next().unwrap().unwrap();
        assert_eq!(sent, Body::WndInc(WndInc::clamp(MSG.len() as u32)).into());

        // The granted window allows the next 5 bytes, now without SYN.
        let written = time::timeout(TIMEOUT, stream.write(&MSG2.as_bytes()[5..]))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(written, 5);
        let sent = remote_rx.try_next().unwrap().unwrap();
        assert_eq!(sent, Body::Data(MSG2[5..10].into()).into());

        stream.shutdown().await.unwrap();

        assert!(remote_rx.try_next().unwrap().unwrap().is_fin());
    }
}