muxado/
session.rs

1use std::{
2    io,
3    sync::{
4        atomic::{
5            AtomicBool,
6            Ordering,
7        },
8        Arc,
9    },
10};
11
12use async_trait::async_trait;
13use futures::{
14    channel::{
15        mpsc,
16        oneshot,
17    },
18    prelude::*,
19    select,
20    stream::StreamExt,
21    SinkExt,
22};
23use tokio::io::{
24    AsyncRead,
25    AsyncWrite,
26};
27use tokio_util::codec::Framed;
28use tracing::{
29    debug,
30    debug_span,
31    instrument,
32    trace,
33    Instrument,
34};
35
36use crate::{
37    codec::FrameCodec,
38    errors::Error,
39    frame::{
40        Body,
41        Frame,
42        Header,
43        HeaderType,
44        StreamID,
45    },
46    stream::Stream,
47    stream_manager::{
48        OpenReq,
49        SharedStreamManager,
50        StreamManager,
51    },
52};
53
/// Default per-stream window size: 256 KiB.
const DEFAULT_WINDOW: usize = 256 * 1024;
/// Default capacity of the queue of remotely-opened streams awaiting accept.
const DEFAULT_ACCEPT: usize = 64;
/// Default maximum number of streams allowed at once.
const DEFAULT_STREAMS: usize = 512;
57
/// Builder for a muxado session.
///
/// The defaults are appropriate for most uses; only change them if you are
/// sure you understand the consequences.
pub struct SessionBuilder<S> {
    /// Transport that the session multiplexes streams over.
    io_stream: S,
    /// Stream window size (defaults to 256 KiB).
    window: usize,
    /// Capacity of the queue of remotely-opened streams waiting to be accepted.
    accept_queue_size: usize,
    /// Maximum number of streams allowed at a given time.
    stream_limit: usize,
    /// `true` when acting as a client, `false` when acting as a server.
    client: bool,
}
69
70impl<S> SessionBuilder<S>
71where
72    S: AsyncRead + AsyncWrite + Send + 'static,
73{
74    /// Start building a new muxado session using the provided IO stream.
75    pub fn new(io_stream: S) -> Self {
76        SessionBuilder {
77            io_stream,
78            window: DEFAULT_WINDOW,
79            accept_queue_size: DEFAULT_ACCEPT,
80            stream_limit: DEFAULT_STREAMS,
81            client: true,
82        }
83    }
84
85    /// Set the stream window size.
86    /// Defaults to 256kb.
87    pub fn window_size(mut self, size: usize) -> Self {
88        self.window = size;
89        self
90    }
91
92    /// Set the accept queue size.
93    /// This is the size of the channel that will hold "open stream" requests
94    /// from the remote. If [Accept::accept] isn't called and the
95    /// channel fills up, the session will block.
96    /// Defaults to 64.
97    pub fn accept_queue_size(mut self, size: usize) -> Self {
98        self.accept_queue_size = size;
99        self
100    }
101
102    /// Set the maximum number of streams allowed at a given time.
103    /// If this limit is reached, new streams will be refused.
104    /// Defaults to 512.
105    pub fn stream_limit(mut self, count: usize) -> Self {
106        self.stream_limit = count;
107        self
108    }
109
110    /// Set this session to act as a client.
111    pub fn client(mut self) -> Self {
112        self.client = true;
113        self
114    }
115
116    /// Set this session to act as a server.
117    pub fn server(mut self) -> Self {
118        self.client = false;
119        self
120    }
121
122    /// Start a muxado session with the current options.
123    pub fn start(self) -> MuxadoSession {
124        let SessionBuilder {
125            io_stream,
126            window,
127            accept_queue_size,
128            stream_limit,
129            client,
130        } = self;
131
132        let (accept_tx, accept_rx) = mpsc::channel(accept_queue_size);
133        let (open_tx, open_rx) = mpsc::channel(512);
134
135        let manager = StreamManager::new(stream_limit, client);
136        let sys_tx = manager.sys_sender();
137        let (m1, m2) = manager.split();
138
139        let (io_tx, io_rx) = Framed::new(io_stream, FrameCodec::default()).split();
140
141        let read_task = Reader {
142            io: io_rx,
143            accept_tx,
144            window,
145            manager: m1,
146            last_stream_processed: StreamID::clamp(0),
147            sys_tx: sys_tx.clone(),
148        };
149
150        let write_task = Writer {
151            window,
152            io: io_tx,
153            manager: m2,
154            open_reqs: open_rx,
155        };
156
157        let (dropref, waiter) = awaitdrop::awaitdrop();
158
159        tokio::spawn(
160            futures::future::select(
161                async move {
162                    let result = read_task.run().await;
163                    debug!(?result, "read_task exited");
164                }
165                .boxed(),
166                waiter.wait(),
167            )
168            .instrument(debug_span!("read_task")),
169        );
170        tokio::spawn(
171            futures::future::select(
172                async move {
173                    let result = write_task.run().await;
174                    debug!(?result, "write_task exited");
175                }
176                .boxed(),
177                waiter.wait(),
178            )
179            .instrument(debug_span!("write_task")),
180        );
181
182        MuxadoSession {
183            incoming: MuxadoAccept(dropref.clone(), accept_rx),
184            outgoing: MuxadoOpen {
185                dropref,
186                open_tx,
187                sys_tx,
188                closed: AtomicBool::from(false).into(),
189            },
190        }
191    }
192}
193
194// read task - runs until there are no more frames coming from the remote
195// Reads frames from the underlying stream and forwards them to the stream
196// manager.
197struct Reader<R> {
198    io: R,
199    sys_tx: mpsc::Sender<Frame>,
200    accept_tx: mpsc::Sender<Stream>,
201    window: usize,
202    manager: SharedStreamManager,
203    last_stream_processed: StreamID,
204}
205
206impl<R> Reader<R>
207where
208    R: futures::stream::Stream<Item = Result<Frame, io::Error>> + Unpin,
209{
210    /// Handle an incoming frame from the remote
211    #[instrument(level = "trace", skip(self))]
212    async fn handle_frame(&mut self, frame: Frame) -> Result<(), Error> {
213        // If the remote sent a syn, create a new stream and add it to the accept channel.
214        if frame.is_syn() {
215            let (req, stream) = OpenReq::create(self.window, false);
216            self.manager
217                .lock()
218                .await
219                .create_stream(frame.header.stream_id.into(), req)?;
220            self.accept_tx
221                .send(stream)
222                .map_err(|_| Error::SessionClosed)
223                .await?;
224        }
225
226        let needs_close = frame.is_fin();
227
228        let Frame {
229            header:
230                Header {
231                    length: _,
232                    flags: _,
233                    stream_id,
234                    typ,
235                },
236            ..
237        } = frame;
238
239        match typ {
240            // These frame types are stream-specific
241            HeaderType::Data | HeaderType::Rst | HeaderType::WndInc => {
242                if let Err(error) = self.manager.send_to_stream(frame).await {
243                    // If the stream manager couldn't send this frame to the
244                    // stream for some reason, generate an RST to tell the other
245                    // end to stop sending on this stream.
246                    debug!(
247                        stream_id = display(stream_id),
248                        error = display(error),
249                        "error sending to stream, generating rst"
250                    );
251                    self.sys_tx
252                        .send(Frame::rst(stream_id, error))
253                        .map_err(|_| Error::SessionClosed)
254                        .await?;
255                } else {
256                    self.last_stream_processed = stream_id;
257                    if needs_close {
258                        if let Ok(handle) = self.manager.lock().await.get_stream(stream_id) {
259                            handle.data_write_closed = true;
260                        }
261                    }
262                }
263            }
264
265            // GoAway is a system-level frame, so send it along the special
266            // system channel.
267            HeaderType::GoAway => {
268                if let Body::GoAway { error, .. } = frame.body {
269                    self.manager.go_away(error).await;
270                    return Err(Error::RemoteGoneAway);
271                }
272
273                unreachable!()
274            }
275            HeaderType::Invalid(_) => {
276                self.sys_tx
277                    .send(Frame::goaway(
278                        self.last_stream_processed,
279                        Error::Protocol,
280                        "invalid frame".into(),
281                    ))
282                    .map_err(|_| Error::StreamClosed)
283                    .await?
284            }
285        }
286        Ok(())
287    }
288
289    // The actual read/process loop
290    async fn run(mut self) -> Result<(), Error> {
291        let _e: Result<(), _> = async {
292            loop {
293                match self.io.try_next().await {
294                    Ok(Some(frame)) => {
295                        trace!(?frame, "received frame from remote");
296                        self.handle_frame(frame).await?
297                    }
298                    Ok(None) | Err(_) => {
299                        return Err(Error::SessionClosed);
300                    }
301                }
302            }
303        }
304        .await;
305
306        self.manager.close_senders().await;
307
308        Err(Error::SessionClosed)
309    }
310}
311
312// The writer task responsible for receiving frames from streams or open
313// requests and writing them to the underlying stream.
314struct Writer<W> {
315    manager: SharedStreamManager,
316    window: usize,
317    open_reqs: mpsc::Receiver<oneshot::Sender<Result<Stream, Error>>>,
318    io: W,
319}
320
321impl<W> Writer<W>
322where
323    W: Sink<Frame, Error = io::Error> + Unpin + Send + 'static,
324{
325    async fn run(mut self) -> Result<(), Error> {
326        loop {
327            select! {
328                // The stream manager produced a frame that needs to be sent to
329                // the remote.
330                frame = self.manager.next() => {
331                    if let Some(frame) = frame {
332                        let is_goaway = matches!(frame.header.typ, HeaderType::GoAway);
333                        trace!(?frame, "sending frame to remote");
334                        if let Err(_e) = self.io.send(frame).await {
335                            return Err(Error::SessionClosed);
336                        }
337                        if is_goaway {
338                            return Ok(())
339                        }
340                    }
341                },
342                // If a request for a new stream originated locally, tell the
343                // stream manager to create it. The first dataframe from it will
344                // have the SYN flag set.
345                req = self.open_reqs.next() => {
346                    if let Some(resp_tx) = req {
347                        let (req, stream) = OpenReq::create(self.window, true);
348
349                        let mut manager = self.manager.lock().await;
350                        let res = manager.create_stream(None, req);
351                        let _ = resp_tx.send(res.map(move |_| stream));
352                    }
353                },
354                // All senders have been dropped - exit.
355                complete => {
356                    return Ok(());
357                }
358            }
359        }
360    }
361}
362
363/// A muxado session.
364///
365/// Can be used directly to open and accept streams, or split into dedicated
366/// open/accept parts.
367pub trait Session: Accept + OpenClose {
368    /// The open half of the session.
369    type OpenClose: OpenClose;
370    /// The accept half of the session.
371    type Accept: Accept;
372    /// Split the session into dedicated open/accept components.
373    fn split(self) -> (Self::OpenClose, Self::Accept);
374}
375
376/// Trait for accepting incoming streams in a muxado [Session].
377#[async_trait]
378pub trait Accept {
379    /// Accept an incoming stream that was opened by the remote.
380    async fn accept(&mut self) -> Option<Stream>;
381}
382
383/// Trait for opening new streams in a muxado [Session].
384#[async_trait]
385pub trait OpenClose {
386    /// Open a new stream.
387    async fn open(&mut self) -> Result<Stream, Error>;
388    /// Close the session by sending a GOAWAY
389    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error>;
390}
391
392/// The [Open] half of a muxado session.
393#[derive(Clone)]
394pub struct MuxadoOpen {
395    dropref: awaitdrop::Ref,
396    open_tx: mpsc::Sender<oneshot::Sender<Result<Stream, Error>>>,
397    sys_tx: mpsc::Sender<Frame>,
398    closed: Arc<AtomicBool>,
399}
400
401/// The [Accept] half of a muxado session.
402pub struct MuxadoAccept(#[allow(dead_code)] awaitdrop::Ref, mpsc::Receiver<Stream>);
403
404#[async_trait]
405impl Accept for MuxadoAccept {
406    async fn accept(&mut self) -> Option<Stream> {
407        self.1.next().await
408    }
409}
410
411#[async_trait]
412impl OpenClose for MuxadoOpen {
413    async fn open(&mut self) -> Result<Stream, Error> {
414        if self.closed.load(Ordering::SeqCst) {
415            return Err(Error::SessionClosed);
416        }
417        let (resp_tx, resp_rx) = oneshot::channel();
418
419        self.open_tx
420            .send(resp_tx)
421            .await
422            .map_err(|_| Error::SessionClosed)?;
423
424        let mut res = resp_rx
425            .await
426            .map_err(|_| Error::SessionClosed)
427            .and_then(|r| r);
428
429        if let Ok(ref mut stream) = &mut res {
430            stream.dropref = self.dropref.clone().into();
431        }
432
433        res
434    }
435
436    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
437        let res = self
438            .sys_tx
439            .send(Frame::goaway(
440                StreamID::clamp(0),
441                error,
442                msg.into_bytes().into(),
443            ))
444            .await
445            .map_err(|_| Error::SessionClosed);
446        self.closed.store(true, Ordering::SeqCst);
447        res
448    }
449}
450
451/// The base muxado [Session] implementation.
452///
453/// See the [Session], [Accept], and [Open] trait implementations for
454/// available methods.
455pub struct MuxadoSession {
456    incoming: MuxadoAccept,
457    outgoing: MuxadoOpen,
458}
459
460#[async_trait]
461impl Accept for MuxadoSession {
462    async fn accept(&mut self) -> Option<Stream> {
463        self.incoming.accept().await
464    }
465}
466
467#[async_trait]
468impl OpenClose for MuxadoSession {
469    async fn open(&mut self) -> Result<Stream, Error> {
470        self.outgoing.open().await
471    }
472
473    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
474        self.outgoing.close(error, msg).await
475    }
476}
477
478impl Session for MuxadoSession {
479    type Accept = MuxadoAccept;
480    type OpenClose = MuxadoOpen;
481    fn split(self) -> (Self::OpenClose, Self::Accept) {
482        (self.outgoing, self.incoming)
483    }
484}
485
#[cfg(test)]
mod test {
    use tokio::io::{
        self,
        AsyncReadExt,
        AsyncWriteExt,
    };

    use super::*;

    /// The client opens a stream and sends a message. The server reads it and
    /// echoes it back on a stream it opens, which the client then accepts.
    #[tokio::test]
    async fn test_session() {
        let (left, right) = io::duplex(512);
        let mut server = SessionBuilder::new(left).server().start();
        let mut client = SessionBuilder::new(right).client().start();

        // Keep the join handle. A panic inside a detached task is only printed;
        // awaiting it makes a server-side failure fail this test with the
        // server's own message.
        let server_task = tokio::spawn(async move {
            let mut stream = server.accept().await.expect("accept stream");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).await.expect("read stream");
            drop(stream);
            let mut stream = server.open().await.expect("open stream");
            stream.write_all(&buf).await.expect("write to stream");
        });

        let mut stream = client.open().await.expect("open stream");
        stream
            .write_all(b"Hello, world!")
            .await
            .expect("write to stream");
        drop(stream);

        let mut stream = client.accept().await.expect("accept stream");
        let mut buf = Vec::new();
        stream
            .read_to_end(&mut buf)
            .await
            .expect("read from stream");

        assert_eq!(b"Hello, world!", &*buf);
        server_task.await.expect("server task panicked");
    }
}
526}