muxado/
stream_manager.rs

1use std::{
2    collections::HashMap,
3    pin::Pin,
4    sync::{
5        atomic::{
6            AtomicBool,
7            Ordering,
8        },
9        Arc,
10    },
11    task::{
12        Context,
13        Poll,
14        Waker,
15    },
16};
17
18use futures::{
19    channel::mpsc,
20    future::poll_fn,
21    lock::{
22        BiLock,
23        BiLockGuard,
24    },
25    prelude::*,
26    ready,
27    stream::{
28        FusedStream,
29        FuturesUnordered,
30        Stream as StreamT,
31        StreamFuture,
32    },
33};
34use tracing::{
35    debug,
36    error,
37    instrument,
38    trace,
39};
40
41use crate::{
42    errors::Error,
43    frame::{
44        Body,
45        Frame,
46        HeaderType,
47        StreamID,
48    },
49    stream::Stream,
50    stream_output::*,
51};
52
/// One half of a [`StreamManager`] produced by [`StreamManager::split`].
///
/// Field `0` is the [`BiLock`] half guarding the shared manager; field `1` is
/// a termination flag shared by both halves and reported by
/// [`FusedStream::is_terminated`].
#[derive(Debug)]
pub struct SharedStreamManager(BiLock<StreamManager>, Arc<AtomicBool>);
55
/// Session-side bookkeeping for a single open stream.
#[derive(Clone, Debug)]
pub(crate) struct StreamHandle {
    // Channel to send frames from the remote to the stream.
    pub to_stream: mpsc::Sender<Frame>,

    // Handle to close the stream's frame sink with a code from an `rst` or
    // similar.
    pub sink_closer: SinkCloser,

    // Whether a fin still has to be sent to the remote for this stream.
    // Starts out `true` when the stream is created and is cleared once a fin
    // has been forwarded, or when a reset/error makes one unnecessary.
    pub needs_fin: bool,

    // Whether our writer is closed. Data frames arriving from the remote for
    // a stream with this flag set are rejected with `Error::StreamClosed`.
    pub data_write_closed: bool,

    // Flow-control window for data coming from the remote: shrunk by the size
    // of each incoming data frame and grown by each window increment the
    // stream sends out. Incoming data larger than the remaining window is
    // rejected with `Error::FlowControl`.
    pub window: usize,
}
74
/// Set of pending "receive next frame" futures, one per live stream, each
/// tagged with the ID of the stream it belongs to.
type StreamTasks = FuturesUnordered<WithID<StreamFuture<mpsc::Receiver<Frame>>>>;

/// Tracks a session's open streams and merges their outgoing frames, plus
/// session-level "system" frames, into a single stream of frames to be
/// written to the remote.
#[derive(Debug)]
pub struct StreamManager {
    // Maximum number of concurrently open streams; `create_stream` refuses to
    // open more.
    stream_limit: usize,

    // Open streams, keyed by ID.
    streams: HashMap<StreamID, StreamHandle>,
    // Channel for frames not produced by a stream (e.g. go-away). Frames on
    // it are polled before any stream frame.
    sys_tx: mpsc::Sender<Frame>,
    sys_rx: mpsc::Receiver<Frame>,
    // One pending receive per open stream.
    tasks: StreamTasks,

    // Most recently used IDs for locally and remotely opened streams.
    last_local_id: StreamID,
    last_remote_id: StreamID,

    // Set once the session has gone away; the frame stream then ends.
    gone_away: bool,

    // If we run out of streams to poll, the task collection will be put to
    // sleep. We can't immediately poll it when we add a new stream since that
    // may lose a frame. Instead, the poll_next implementation will store its
    // waker here, and we'll wake it up in create_stream to get it polling
    // again.
    new_streams: Option<Waker>,
}
98
99impl StreamManager {
100    fn tasks(&mut self) -> Pin<&mut StreamTasks> {
101        Pin::new(&mut self.tasks)
102    }
103
104    fn sys_rx(&mut self) -> Pin<&mut mpsc::Receiver<Frame>> {
105        Pin::new(&mut self.sys_rx)
106    }
107
108    pub fn new(stream_limit: usize, client: bool) -> Self {
109        let (sys_tx, sys_rx) = mpsc::channel(512);
110        let mut last_local_id = 0;
111        let mut last_remote_id = 0;
112        if client {
113            last_local_id += 1;
114        } else {
115            last_remote_id += 1;
116        }
117        StreamManager {
118            streams: Default::default(),
119            stream_limit,
120            sys_tx,
121            sys_rx,
122            last_local_id: StreamID::clamp(last_local_id),
123            last_remote_id: StreamID::clamp(last_remote_id),
124            tasks: Default::default(),
125            gone_away: false,
126            new_streams: None,
127        }
128    }
129
130    // Split the manager into two shared halves.
131    pub fn split(self) -> (SharedStreamManager, SharedStreamManager) {
132        let (l, r) = BiLock::new(self);
133        let terminated = Arc::new(AtomicBool::new(false));
134        (
135            SharedStreamManager(l, terminated.clone()),
136            SharedStreamManager(r, terminated),
137        )
138    }
139
140    pub fn go_away(&mut self, error: Error) {
141        self.gone_away = true;
142        for (_id, handle) in self.streams.drain() {
143            handle.sink_closer.close_with(error);
144        }
145    }
146
147    pub fn sys_sender(&self) -> mpsc::Sender<Frame> {
148        self.sys_tx.clone()
149    }
150
151    pub fn close_senders(&mut self) {
152        for (_, stream) in self.streams.iter_mut() {
153            stream.to_stream.close_channel()
154        }
155    }
156}
157
// The manager's frame stream is finished once it has gone away.
impl FusedStream for StreamManager {
    fn is_terminated(&self) -> bool {
        self.gone_away
    }
}

// Reports the termination flag shared by both halves of a split manager.
impl FusedStream for SharedStreamManager {
    fn is_terminated(&self) -> bool {
        self.1.load(Ordering::SeqCst)
    }
}
169
// Stream implementation for StreamManager
// This is used as the "output" from all of the streams and will produce frames
// that need to be sent to the remote via the underlying IO stream.
impl StreamT for StreamManager {
    type Item = Frame;

    /// Yields the next frame to write to the remote.
    ///
    /// Pending system frames are returned before any stream frame. A go-away
    /// system frame gets its `last_stream_id` set to the last remote stream
    /// ID. It also drops all stream tasks and makes the manager go away,
    /// closing every stream's sink with `Error::SessionClosed`. From then on,
    /// `Ready(None)` is returned.
    ///
    /// Frames from streams are stamped with the stream's ID. When a stream's
    /// channel ends, a zero-length data fin is produced for it, unless one
    /// was already forwarded or a reset/error cleared `needs_fin`. Streams
    /// whose sink was closed and that need no fin are dropped without
    /// producing a frame.
    #[instrument(level = "trace", skip_all)]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // There will only be no new frames if we've gone away.
        if self.gone_away {
            return Poll::Ready(None);
        }

        // Go ahead and store the latest waker for use by newly started streams.
        // In order to start receiving wakeups from them, we have to ensure that
        // our task collection gets polled here after one is added to it.
        self.new_streams = Some(cx.waker().clone());

        // Handle system frames first, but don't return if it's not ready, or
        // it's somehow closed (shouldn't happen).
        if let Poll::Ready(Some(mut frame)) = self.as_mut().sys_rx().poll_next(cx) {
            if let Body::GoAway {
                ref mut last_stream_id,
                error: _,
                message: _,
            } = &mut frame.body
            {
                *last_stream_id = self.last_remote_id;
                // We won't be sending any more frames from streams.
                self.as_mut().tasks().clear();
                self.as_mut().go_away(Error::SessionClosed);
            }
            return Some(frame).into();
        }

        // Otherwise, get the next frame from a stream.
        // `None` means there are no streams left to poll; `create_stream` will
        // wake us through `new_streams` once a new one is added.
        let (id, (item, rest)) = if let Some(i) = ready!(self.as_mut().tasks().poll_next(cx)) {
            i
        } else {
            return Poll::Pending;
        };

        let handle = if let Ok(handle) = self.get_stream(id) {
            handle
        } else {
            // Streams leave the map either here (and then their receiver is
            // not pushed back into the task set) or through `go_away`, after
            // which this function returns early because `gone_away` is set.
            // So every task we get back still has a handle.
            unreachable!();
        };

        // If the sink is closed and we don't need a fin, don't return a frame.
        // We should never really see a case where we have a closed sink while a
        // fin is needed, but make double sure.
        // The sink closer is only closed from this end if a reset is received
        // or issued. It's only closed from the other end if this end has gone
        // away.
        if handle.sink_closer.is_closed() && !handle.needs_fin {
            debug!(needs_fin = handle.needs_fin, "removing stream without fin");
            self.remove_stream(id);
            // Dropping `rest` here discards the stream's receiver. Wake
            // ourselves so the remaining tasks are polled again.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        let frame = if let Some(frame) = item {
            // Outgoing window increments grow the window the remote may fill.
            if let Body::WndInc(inc) = frame.body {
                handle.window += *inc as usize;
            }

            if frame.is_fin() {
                debug!(stream_id = debug(id), "setting needs_fin to false");
                handle.needs_fin = false;
            }

            // Re-arm the receive for this stream's next frame.
            self.push_task(id, rest);

            frame
        } else {
            // If we got None from the stream, it means its channel is closed
            // because it got dropped on the other end. Maybe generate a fin and
            // remove it from our map.

            // Make sure we haven't already sent a fin for this stream, and
            // that a reset or error hasn't made one unnecessary; both clear
            // `needs_fin`.
            let needs_fin = handle.needs_fin;
            handle.needs_fin = false;
            self.remove_stream(id);
            debug!(needs_fin, "got none from stream, trying to send a fin");
            if needs_fin {
                debug!("removing stream and sending fin");
                Frame::from(Body::Data([][..].into())).fin()
            } else {
                debug!("removing stream that's already fin'd");
                // Could introduce a loop and `continue` here, or we could just
                // return `Pending` and wake ourselves back up.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
        .stream_id(id);

        Some(frame).into()
    }
}
275
276impl StreamT for SharedStreamManager {
277    type Item = <StreamManager as StreamT>::Item;
278    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279        ready!(self.0.poll_lock(cx)).as_pin_mut().poll_next(cx)
280    }
281}
282
impl SharedStreamManager {
    /// Sets the shared termination flag, then locks the manager and makes it
    /// go away, closing every open stream's sink with `error`.
    #[instrument(level = "trace", skip(self))]
    pub async fn go_away(&mut self, error: Error) {
        self.1.store(true, Ordering::SeqCst);
        self.lock().await.go_away(error);
    }

    /// Delivers a frame received from the remote to the stream it addresses.
    ///
    /// Data frames shrink the stream's flow-control window by their length.
    /// `Rst` frames are not forwarded. Instead they close the stream's sink
    /// with the reset's error, mark its writer closed and cancel fin
    /// generation. All other frames are sent on the stream's channel. While
    /// the channel has no capacity, the manager lock is not held.
    ///
    /// # Errors
    ///
    /// * `Error::Internal` if the frame is a `GoAway` or `Invalid` frame.
    /// * For data frames other than an empty fin:
    ///   * `Error::StreamClosed` if the stream is unknown, its writer is
    ///     closed, or its channel has been closed;
    ///   * `Error::FlowControl` if the frame exceeds the remaining window.
    ///
    ///   When the stream is known, its sink is also closed, its writer is
    ///   marked closed and fin generation is cancelled. The caller is then
    ///   expected to send a reset.
    ///
    /// For any other frame, including frames for unknown streams, such
    /// errors are swallowed and `Ok(())` is returned.
    #[instrument(level = "trace", skip(self))]
    pub async fn send_to_stream(&mut self, frame: Frame) -> Result<(), Error> {
        let id = frame.header.stream_id;
        let typ = frame.header.typ;
        // If we see data coming in, reduce the stream's window. If it would go
        // below 0, we'll reset the remote with a flow control error.
        let mut shrink_window = if let Body::Data(bs) = &frame.body {
            bs.len()
        } else {
            0
        };

        match frame.body {
            Body::GoAway { .. } | Body::Invalid { .. } => {
                error!(
                    body = ?frame,
                    id = %id,
                    "attempt to send invalid frame type to stream",
                );
                return Err(Error::Internal);
            }
            _ => {}
        }

        // An empty fin is tolerated even if the stream is gone.
        let is_fin = frame.is_fin() && frame.body.len() == 0;

        // Wrapped in an `Option` so the poll closure can move the frame out
        // exactly once, when the stream's channel is ready.
        let mut frame = Some(frame);

        let mut handle_frame = |handle: &mut StreamHandle, cx: &mut Context| {
            if typ == HeaderType::Data && handle.data_write_closed {
                debug!("attempt to send data on closed stream");
                return Err(Error::StreamClosed).into();
            }

            // Don't send resets to the stream, just close its channel with the
            // error.
            if let Some(Frame {
                body: Body::Rst(err),
                ..
            }) = frame
            {
                debug!(
                    stream_id = display(id),
                    error = display(err),
                    "received rst from remote, closing stream"
                );
                // Close the writer on the other end, mark *our* writer as
                // closed, and disable fin generation.
                handle.sink_closer.close_with(err);
                handle.data_write_closed = true;
                handle.needs_fin = false;
                return Ok(()).into();
            }

            // Keep track of how much data has been sent to the stream. If it
            // goes over, send a reset.
            if shrink_window <= handle.window {
                handle.window -= shrink_window;
                // We're polling this function, so we need to avoid shrinking
                // more than once.
                shrink_window = 0;
            } else {
                debug!(
                    frame_size = shrink_window,
                    stream_window = handle.window,
                    "remote violated flow control"
                );
                return Err(Error::FlowControl).into();
            }

            let sink = &mut handle.to_stream;
            trace!("checking stream for readiness");
            // A closed channel is only an error if this isn't an empty fin.
            ready!(sink.poll_ready(cx))
                .and_then(|_| sink.start_send(frame.take().unwrap()))
                .map_err(|_| Error::StreamClosed)
                .or_else(|res| if is_fin { Ok(()) } else { Err(res) })?;
            Ok(()).into()
        };

        // The rest of this is in a `poll_fn` so that we don't hold the lock for
        // any longer than necessary to check if the stream is ready. If we did
        // it await-style, we'd continue holding the lock even if the stream was
        // still pending.
        poll_fn(move |cx| -> Poll<Result<_, Error>> {
            // Lock self and look up the stream. If it doesn't exist, only data
            // frames (other than an empty fin) report the error.
            let mut lock = ready!(self.0.poll_lock(cx));
            let handle = match lock.get_stream(id) {
                Ok(handle) => handle,
                Err(_e) if HeaderType::Data != typ || is_fin => {
                    return Ok(()).into();
                }
                Err(e) => return Err(e).into(),
            };

            let res = ready!(handle_frame(handle, cx));

            // Any errors from data frames should cause a reset to be sent by
            // the session.
            if HeaderType::Data == typ && !is_fin {
                // If we're sending a reset, close all the writers to prevent
                // any more frames from being sent.
                if let Err(e) = res {
                    debug!(error = display(e), "error handling frame");
                    handle.sink_closer.close_with(Error::StreamClosed);
                    handle.data_write_closed = true;
                    handle.needs_fin = false;
                }
                res.into()
            } else {
                Ok(()).into()
            }
        })
        .await
    }

    /// Locks the manager and closes the channel feeding every open stream.
    pub async fn close_senders(&mut self) {
        self.lock().await.close_senders()
    }

    /// Waits for exclusive access to the shared manager.
    pub async fn lock(&mut self) -> BiLockGuard<'_, StreamManager> {
        self.0.lock().await
    }
}
416
417pub struct OpenReq {
418    channel: (mpsc::Sender<Frame>, mpsc::Receiver<Frame>),
419    closer: SinkCloser,
420    window: usize,
421}
422
423impl OpenReq {
424    pub fn create(window: usize, needs_syn: bool) -> (OpenReq, Stream) {
425        let (to_stream, from_session) = mpsc::channel(window);
426        let (to_session, from_stream) = mpsc::channel(window);
427        let to_session = StreamSender::wrap(to_session);
428        let req = OpenReq {
429            channel: (to_stream, from_stream),
430            closer: to_session.closer(),
431            window,
432        };
433        let stream = Stream::new(to_session, from_session, window, needs_syn);
434        (req, stream)
435    }
436}
437
438impl StreamManager {
439    #[instrument(level = "trace", skip(self))]
440    pub(crate) fn get_stream(&mut self, id: StreamID) -> Result<&mut StreamHandle, Error> {
441        if let Some(handle) = self.streams.get_mut(&id) {
442            Ok(handle)
443        } else {
444            trace!("stream not found");
445            Err(Error::StreamClosed)
446        }
447    }
448
449    #[instrument(level = "trace", skip(self, req))]
450    pub fn create_stream(&mut self, id: Option<StreamID>, req: OpenReq) -> Result<StreamID, Error> {
451        // Only return an error if we're at the stream limit.
452        if self.streams.len() == self.stream_limit {
453            return Err(Error::StreamsExhausted);
454        }
455
456        let (to_stream, from_stream) = req.channel;
457        let closer = req.closer;
458        let window = req.window;
459        let id = if let Some(remote_id) = id {
460            self.last_remote_id = remote_id;
461            remote_id
462        } else {
463            let new_id = StreamID::clamp(*self.last_local_id + 2);
464            self.last_local_id = new_id;
465            new_id
466        };
467        self.streams.insert(
468            id,
469            StreamHandle {
470                window,
471                to_stream,
472                sink_closer: closer,
473                needs_fin: true,
474                data_write_closed: false,
475            },
476        );
477        self.push_task(id, from_stream);
478        // wake up the main stream if it put itself to sleep.
479        if let Some(w) = self.new_streams.take() {
480            w.wake()
481        }
482        Ok(id)
483    }
484
485    fn push_task(&mut self, id: StreamID, recv: mpsc::Receiver<Frame>) {
486        self.tasks.push(recv.into_future().with_id(id));
487    }
488
489    #[instrument(level = "debug", skip(self))]
490    fn remove_stream(&mut self, id: StreamID) -> Option<StreamHandle> {
491        self.streams.remove(&id)
492    }
493}
494
495struct WithID<F: ?Sized> {
496    id: StreamID,
497    fut: F,
498}
499
500impl<F: Unpin> WithID<F> {
501    fn fut(&mut self) -> Pin<&mut F> {
502        Pin::new(&mut self.fut)
503    }
504}
505
506impl<F> Future for WithID<F>
507where
508    F: Future + Unpin,
509{
510    type Output = (StreamID, F::Output);
511
512    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
513        let out = ready!(self.as_mut().fut().poll(cx));
514        Poll::Ready((self.id, out))
515    }
516}
517
518trait WithIDExt {
519    fn with_id(self, id: StreamID) -> WithID<Self>;
520}
521
522impl<F> WithIDExt for F
523where
524    F: Future,
525{
526    fn with_id(self, id: StreamID) -> WithID<Self> {
527        WithID { id, fut: self }
528    }
529}