muxado/
stream.rs

1use std::{
2    cmp,
3    fmt,
4    io,
5    pin::Pin,
6    task::{
7        Context,
8        Poll,
9        Waker,
10    },
11};
12
13use bytes::BytesMut;
14use futures::{
15    channel::mpsc,
16    ready,
17    sink::Sink,
18    stream::Stream as StreamT,
19};
20use pin_project::pin_project;
21use tokio::io::{
22    AsyncRead,
23    AsyncWrite,
24    ReadBuf,
25};
26use tracing::instrument;
27
28use crate::{
29    errors::Error,
30    frame::{
31        Body,
32        Frame,
33        HeaderType,
34        Length,
35        WndInc,
36    },
37    stream_output::StreamSender,
38    window::Window,
39};
40
/// A muxado stream.
///
/// This is an [AsyncRead]/[AsyncWrite] struct that's backed by a muxado
/// session. Inbound frames for this stream arrive on `fin`, and outbound
/// frames leave on `fout`.
#[pin_project(project = StreamProj, PinnedDrop)]
pub struct Stream {
    // Drop-tracking handle; `Stream::new` leaves it `None`.
    // NOTE(review): presumably filled in elsewhere in the crate so the
    // session can observe when this stream is dropped - confirm at the caller.
    pub(crate) dropref: Option<awaitdrop::Ref>,

    // Outgoing flow-control window: `poll_write` never sends more bytes than
    // its capacity and decrements it by what it sent; inbound WndInc frames
    // increase it.
    window: Window,

    // Payload of received Data frames that `poll_read` has not handed out yet.
    read_buf: BytesMut,

    // These are the two channels that are used to shuttle data back and forth
    // between the stream and the stream manager, which is responsible for
    // routing frames to their proper stream.
    #[pin]
    fin: mpsc::Receiver<Frame>,
    #[pin]
    fout: StreamSender,

    // Wakers of a reader / writer that is parked waiting for inbound frames.
    // Either half may ingest a frame the other half is waiting for, so
    // `handle_frame` wakes the interested half's waker.
    read_waker: Option<Waker>,
    write_waker: Option<Waker>,

    // Why the write side is closed: an error reported by the outbound
    // channel, or `Error::StreamClosed` once `poll_shutdown` has queued its
    // FIN frame. `poll_write` fails with `BrokenPipe` while this is set.
    write_closed: Option<Error>,

    // Set once a frame with FIN has arrived or the inbound channel has
    // closed; `poll_read` reports EOF after `read_buf` has been drained.
    data_read_closed: bool,

    // Whether the next outgoing data frame must be tagged with SYN; cleared
    // when a frame is tagged.
    needs_syn: bool,
}
70
71impl fmt::Debug for Stream {
72    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
73        f.debug_struct("Stream")
74            .field("window", &self.window)
75            .field("read_buf", &self.read_buf)
76            .field("read_waker", &self.read_waker)
77            .field("write_waker", &self.write_waker)
78            .field("reset", &self.write_closed)
79            .field("read_closed", &self.data_read_closed)
80            .finish()
81    }
82}
83
84impl Stream {
85    pub(crate) fn new(
86        fout: StreamSender,
87        fin: mpsc::Receiver<Frame>,
88        window_size: usize,
89        needs_syn: bool,
90    ) -> Self {
91        Self {
92            dropref: None,
93            window: Window::new(window_size),
94            fin,
95            fout,
96            read_buf: Default::default(),
97            read_waker: Default::default(),
98            write_waker: Default::default(),
99            write_closed: Default::default(),
100            data_read_closed: false,
101            needs_syn,
102        }
103    }
104
105    #[instrument(level = "trace", skip_all)]
106    fn poll_recv_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Frame>> {
107        let mut this = self.project();
108        let fin = this.fin.as_mut();
109        fin.poll_next(cx)
110    }
111
112    // Receive data and fill the read buffer.
113    // Handle frames of any other type along the way.
114    // Returns `Poll::Ready` once there are new bytes to read, or EOF/RST has
115    // been reached.
116    #[instrument(level = "trace", skip_all)]
117    fn poll_recv_data(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
118        self.poll_recv_frame_type(cx, HeaderType::Data)
119    }
120
121    #[instrument(level = "trace", skip_all)]
122    fn poll_recv_wndinc(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
123        self.poll_recv_frame_type(cx, HeaderType::WndInc)
124    }
125
126    #[instrument(level = "trace", skip(self, cx))]
127    fn poll_recv_frame_type(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        target_typ: HeaderType,
131    ) -> Poll<io::Result<()>> {
132        loop {
133            let frame = if let Some(frame) = ready!(self.as_mut().poll_recv_frame(cx)) {
134                frame
135            } else {
136                self.data_read_closed = true;
137                return Ok(()).into();
138            };
139
140            let typ = self.handle_frame(frame, Some(cx));
141
142            if typ == target_typ {
143                return Poll::Ready(Ok(()));
144            }
145        }
146    }
147
148    #[instrument(level = "trace", skip(self, cx))]
149    fn poll_send_wndinc(self: Pin<&mut Self>, cx: &mut Context<'_>, by: WndInc) -> Poll<()> {
150        let mut this = self.project();
151        // Treat a closed send channel as "success"
152        if ready!(this.fout.as_mut().poll_ready(cx)).is_err() {
153            return Poll::Ready(());
154        }
155
156        // Same as above
157        let _ = this.fout.as_mut().start_send(Body::WndInc(by).into());
158
159        Poll::Ready(())
160    }
161
162    #[instrument(level = "trace", skip(self, cx))]
163    fn handle_frame(&mut self, frame: Frame, cx: Option<&Context<'_>>) -> HeaderType {
164        if frame.is_fin() {
165            self.data_read_closed = true;
166        }
167        match frame.body {
168            Body::Data(bs) => {
169                self.read_buf.extend_from_slice(&bs);
170                self.maybe_wake_read(cx);
171            }
172            Body::WndInc(by) => {
173                self.window.inc(*by as usize);
174                self.maybe_wake_write(cx);
175            }
176            _ => unreachable!("stream should never receive GoAway, Rst or Invalid frames"),
177        }
178        frame.header.typ
179    }
180
181    #[instrument(level = "trace", skip_all)]
182    fn maybe_wake_read(&mut self, cx: Option<&Context>) {
183        maybe_wake(cx, self.read_waker.take())
184    }
185    #[instrument(level = "trace", skip_all)]
186    fn maybe_wake_write(&mut self, cx: Option<&Context>) {
187        maybe_wake(cx, self.write_waker.take())
188    }
189}
190
191impl<'a> StreamProj<'a> {
192    fn closed_err(&mut self, code: Error) -> io::Error {
193        *self.write_closed = Some(code);
194        io::Error::new(io::ErrorKind::ConnectionReset, code)
195    }
196}
197
/// Wakes `other` unless it would wake the task that owns `me`.
///
/// When there is no current task (`me` is `None`), any waker is woken. When
/// `other` belongs to the current task, the wake is skipped because that task
/// is already running. The waker is consumed in every case.
fn maybe_wake(me: Option<&Context>, other: Option<Waker>) {
    if let Some(waker) = other {
        let same_task = match me {
            Some(cx) => waker.will_wake(cx.waker()),
            None => false,
        };
        if !same_task {
            waker.wake();
        }
    }
}
205
206impl AsyncRead for Stream {
207    #[instrument(level = "trace", skip_all)]
208    fn poll_read(
209        mut self: Pin<&mut Self>,
210        cx: &mut Context<'_>,
211        buf: &mut ReadBuf<'_>,
212    ) -> Poll<io::Result<()>> {
213        loop {
214            // If we have data, return it
215            if !self.read_buf.is_empty() {
216                let max = cmp::min(self.read_buf.len(), buf.remaining());
217                let clamped = WndInc::clamp(max as u32);
218                let n = *clamped as usize;
219
220                if n > 0 {
221                    // Wait till there's window capacity to receive an increment.
222                    // If this fails, continue anyway.
223                    ready!(self.as_mut().poll_send_wndinc(cx, clamped));
224
225                    buf.put_slice(self.read_buf.split_to(n).as_ref());
226                }
227
228                return Poll::Ready(Ok(()));
229            }
230
231            // EOF's should return Ok without modifying the output buffer.
232            if self.data_read_closed {
233                return Poll::Ready(Ok(()));
234            }
235
236            // Data frames may be ingested by the writer as well, so make sure
237            // we don't get forgotten.
238            self.read_waker = Some(cx.waker().clone());
239
240            // Otherwise, try to get more.
241            ready!(self.as_mut().poll_recv_data(cx))?;
242        }
243    }
244}
245
246impl AsyncWrite for Stream {
247    #[instrument(level = "trace", skip(self, cx))]
248    fn poll_write(
249        mut self: Pin<&mut Self>,
250        cx: &mut Context<'_>,
251        buf: &[u8],
252    ) -> Poll<Result<usize, io::Error>> {
253        if let Some(code) = self.write_closed {
254            return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, code)));
255        }
256
257        if self.window.capacity() == 0 {
258            self.write_waker = Some(cx.waker().clone());
259
260            ready!(self.as_mut().poll_recv_wndinc(cx))?;
261        }
262
263        let mut this = self.project();
264
265        ready!(this.fout.as_mut().poll_ready(cx)).map_err(|e| this.closed_err(e))?;
266
267        let wincap = this.window.capacity();
268
269        let max_len = Length::clamp(buf.len() as u32);
270
271        let send_len = cmp::min(wincap, *max_len as usize);
272
273        let bs = BytesMut::from(&buf[..send_len]);
274
275        let mut frame: Frame = Body::Data(bs.freeze()).into();
276        if *this.needs_syn {
277            *this.needs_syn = false;
278            frame = frame.syn();
279        }
280
281        this.fout
282            .as_mut()
283            .start_send(frame)
284            .map_err(|e| this.closed_err(e))?;
285
286        let _dec = this.window.dec(send_len);
287        debug_assert!(_dec == send_len);
288
289        Poll::Ready(Ok(send_len))
290    }
291
292    #[instrument(level = "trace", skip_all)]
293    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
294        let mut this = self.project();
295        this.fout
296            .as_mut()
297            .poll_flush(cx)
298            .map_err(|e| this.closed_err(e))
299    }
300
301    #[instrument(level = "trace", skip_all)]
302    fn poll_shutdown(
303        mut self: Pin<&mut Self>,
304        cx: &mut Context<'_>,
305    ) -> Poll<Result<(), io::Error>> {
306        // Rather than close the output channel, send a fin frame.
307        // This lets us use the actual channel closure as the "stream is gone
308        // for good" signal.
309        let mut this = self.as_mut().project();
310        if this.write_closed.is_none() {
311            ready!(this.fout.as_mut().poll_ready(cx))
312                .and_then(|_| {
313                    this.fout
314                        .as_mut()
315                        .start_send(Frame::from(Body::Data([][..].into())).fin())
316                })
317                .map_err(|e| this.closed_err(e))?;
318            *this.write_closed = Some(Error::StreamClosed);
319        }
320
321        this.fout
322            .as_mut()
323            .poll_flush(cx)
324            .map_err(|e| this.closed_err(e))
325    }
326}
327
#[pin_project::pinned_drop]
impl PinnedDrop for Stream {
    // Pinned-drop hook requested by the `PinnedDrop` argument on the
    // `#[pin_project]` attribute. It performs no explicit teardown: the
    // channels, wakers and `dropref` are released by their own destructors.
    // The `instrument` attribute still enters a trace-level span each time a
    // stream is dropped.
    #[instrument(level = "trace", skip_all)]
    fn drop(self: Pin<&mut Self>) {}
}
333
#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use tokio::{
        io::{
            AsyncReadExt,
            AsyncWriteExt,
        },
        time,
    };
    use tracing_test::traced_test;

    use super::*;

    /// Upper bound on how long any single stream operation may take.
    const OP_TIMEOUT: Duration = Duration::from_secs(1);

    /// Pops the next frame the stream queued on its outbound channel,
    /// panicking if none is waiting or the channel is closed.
    fn next_outbound<T>(rx: &mut mpsc::Receiver<T>) -> T {
        rx.try_next()
            .expect("no outbound frame waiting")
            .expect("outbound channel closed")
    }

    /// Walks one stream through its lifecycle:
    /// * a window-limited first write that carries SYN;
    /// * reading peer data, which returns window to the peer;
    /// * a second write enabled by the peer's window increment;
    /// * a FIN on shutdown.
    #[traced_test]
    #[tokio::test]
    async fn test_stream() {
        const MSG: &str = "Hello, world!";
        const MSG2: &str = "Hello to you too!";

        // `peer_tx` stands in for the stream manager delivering inbound
        // frames; `rx` collects everything the stream sends out.
        let (mut peer_tx, stream_rx) = mpsc::channel(512);
        let (stream_tx, mut rx) = mpsc::channel(512);
        let mut stream = Stream::new(StreamSender::wrap(stream_tx), stream_rx, 5, true);

        // The 5-byte window makes the first write short; it must carry SYN.
        let written = time::timeout(OP_TIMEOUT, stream.write(MSG2.as_bytes()))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(written, 5);
        assert_eq!(
            next_outbound(&mut rx),
            Frame::from(Body::Data(MSG2[0..5].into())).syn()
        );

        // Deliver a window increment followed by data, then hang up.
        peer_tx
            .try_send(Body::WndInc(WndInc::clamp(5)).into())
            .unwrap();
        peer_tx
            .try_send(Body::Data(MSG.as_bytes().into()).into())
            .unwrap();
        drop(peer_tx);

        // Reading drains the data and processes the increment on the way.
        let mut received = String::new();
        time::timeout(OP_TIMEOUT, stream.read_to_string(&mut received))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received, "Hello, world!");
        // Consuming the data must return the same amount of window.
        assert_eq!(
            next_outbound(&mut rx),
            Body::WndInc(WndInc::clamp(MSG.len() as u32)).into()
        );

        // The processed increment restored 5 bytes of window for this write.
        let written = time::timeout(OP_TIMEOUT, stream.write(&MSG2.as_bytes()[5..]))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(written, 5);
        assert_eq!(next_outbound(&mut rx), Body::Data(MSG2[5..10].into()).into());

        // Shutting down queues a FIN.
        stream.shutdown().await.unwrap();
        assert!(next_outbound(&mut rx).is_fin());
    }
}