1use std::{
2 io,
3 sync::{
4 atomic::{
5 AtomicBool,
6 Ordering,
7 },
8 Arc,
9 },
10};
11
12use async_trait::async_trait;
13use futures::{
14 channel::{
15 mpsc,
16 oneshot,
17 },
18 prelude::*,
19 select,
20 stream::StreamExt,
21 SinkExt,
22};
23use tokio::io::{
24 AsyncRead,
25 AsyncWrite,
26};
27use tokio_util::codec::Framed;
28use tracing::{
29 debug,
30 debug_span,
31 instrument,
32 trace,
33 Instrument,
34};
35
36use crate::{
37 codec::FrameCodec,
38 errors::Error,
39 frame::{
40 Body,
41 Frame,
42 Header,
43 HeaderType,
44 StreamID,
45 },
46 stream::Stream,
47 stream_manager::{
48 OpenReq,
49 SharedStreamManager,
50 StreamManager,
51 },
52};
53
/// Window size given to each stream unless overridden with
/// `SessionBuilder::window_size` (256 * 1024 == 0x40000).
const DEFAULT_WINDOW: usize = 256 * 1024;

/// Default buffer size of the queue holding remote-opened streams awaiting `accept`.
const DEFAULT_ACCEPT: usize = 64;

/// Default stream limit handed to `StreamManager::new`.
const DEFAULT_STREAMS: usize = 512;
57
/// Configures a muxado session over the byte stream `S`; call `start` to run it.
pub struct SessionBuilder<S> {
    /// Transport that frames are read from and written to.
    io_stream: S,
    /// `true` for the client side, `false` for the server side.
    client: bool,
    /// Window size passed to `OpenReq::create` for every stream.
    window: usize,
    /// Buffer size of the channel holding remote-opened streams awaiting `accept`.
    accept_queue_size: usize,
    /// Stream limit passed to `StreamManager::new`.
    stream_limit: usize,
}
69
70impl<S> SessionBuilder<S>
71where
72 S: AsyncRead + AsyncWrite + Send + 'static,
73{
74 pub fn new(io_stream: S) -> Self {
76 SessionBuilder {
77 io_stream,
78 window: DEFAULT_WINDOW,
79 accept_queue_size: DEFAULT_ACCEPT,
80 stream_limit: DEFAULT_STREAMS,
81 client: true,
82 }
83 }
84
85 pub fn window_size(mut self, size: usize) -> Self {
88 self.window = size;
89 self
90 }
91
92 pub fn accept_queue_size(mut self, size: usize) -> Self {
98 self.accept_queue_size = size;
99 self
100 }
101
102 pub fn stream_limit(mut self, count: usize) -> Self {
106 self.stream_limit = count;
107 self
108 }
109
110 pub fn client(mut self) -> Self {
112 self.client = true;
113 self
114 }
115
116 pub fn server(mut self) -> Self {
118 self.client = false;
119 self
120 }
121
122 pub fn start(self) -> MuxadoSession {
124 let SessionBuilder {
125 io_stream,
126 window,
127 accept_queue_size,
128 stream_limit,
129 client,
130 } = self;
131
132 let (accept_tx, accept_rx) = mpsc::channel(accept_queue_size);
133 let (open_tx, open_rx) = mpsc::channel(512);
134
135 let manager = StreamManager::new(stream_limit, client);
136 let sys_tx = manager.sys_sender();
137 let (m1, m2) = manager.split();
138
139 let (io_tx, io_rx) = Framed::new(io_stream, FrameCodec::default()).split();
140
141 let read_task = Reader {
142 io: io_rx,
143 accept_tx,
144 window,
145 manager: m1,
146 last_stream_processed: StreamID::clamp(0),
147 sys_tx: sys_tx.clone(),
148 };
149
150 let write_task = Writer {
151 window,
152 io: io_tx,
153 manager: m2,
154 open_reqs: open_rx,
155 };
156
157 let (dropref, waiter) = awaitdrop::awaitdrop();
158
159 tokio::spawn(
160 futures::future::select(
161 async move {
162 let result = read_task.run().await;
163 debug!(?result, "read_task exited");
164 }
165 .boxed(),
166 waiter.wait(),
167 )
168 .instrument(debug_span!("read_task")),
169 );
170 tokio::spawn(
171 futures::future::select(
172 async move {
173 let result = write_task.run().await;
174 debug!(?result, "write_task exited");
175 }
176 .boxed(),
177 waiter.wait(),
178 )
179 .instrument(debug_span!("write_task")),
180 );
181
182 MuxadoSession {
183 incoming: MuxadoAccept(dropref.clone(), accept_rx),
184 outgoing: MuxadoOpen {
185 dropref,
186 open_tx,
187 sys_tx,
188 closed: AtomicBool::from(false).into(),
189 },
190 }
191 }
192}
193
/// Read half of a session: decodes frames from the transport and routes them
/// to streams, to the accept queue, or to the stream manager.
struct Reader<R> {
    /// Decoded frames from the transport (the read half of the `Framed` built in `SessionBuilder::start`).
    io: R,
    /// Sender from `StreamManager::sys_sender`, used for the RST and GOAWAY frames the reader generates.
    sys_tx: mpsc::Sender<Frame>,
    /// Delivers remote-opened streams to `MuxadoAccept`.
    accept_tx: mpsc::Sender<Stream>,
    /// Window size for streams created in response to a remote SYN.
    window: usize,
    /// Reader's handle from `StreamManager::split`; the writer holds the other one.
    manager: SharedStreamManager,
    /// ID of the last stream a frame was successfully delivered to; reported in GOAWAY frames.
    last_stream_processed: StreamID,
}
205
206impl<R> Reader<R>
207where
208 R: futures::stream::Stream<Item = Result<Frame, io::Error>> + Unpin,
209{
210 #[instrument(level = "trace", skip(self))]
212 async fn handle_frame(&mut self, frame: Frame) -> Result<(), Error> {
213 if frame.is_syn() {
215 let (req, stream) = OpenReq::create(self.window, false);
216 self.manager
217 .lock()
218 .await
219 .create_stream(frame.header.stream_id.into(), req)?;
220 self.accept_tx
221 .send(stream)
222 .map_err(|_| Error::SessionClosed)
223 .await?;
224 }
225
226 let needs_close = frame.is_fin();
227
228 let Frame {
229 header:
230 Header {
231 length: _,
232 flags: _,
233 stream_id,
234 typ,
235 },
236 ..
237 } = frame;
238
239 match typ {
240 HeaderType::Data | HeaderType::Rst | HeaderType::WndInc => {
242 if let Err(error) = self.manager.send_to_stream(frame).await {
243 debug!(
247 stream_id = display(stream_id),
248 error = display(error),
249 "error sending to stream, generating rst"
250 );
251 self.sys_tx
252 .send(Frame::rst(stream_id, error))
253 .map_err(|_| Error::SessionClosed)
254 .await?;
255 } else {
256 self.last_stream_processed = stream_id;
257 if needs_close {
258 if let Ok(handle) = self.manager.lock().await.get_stream(stream_id) {
259 handle.data_write_closed = true;
260 }
261 }
262 }
263 }
264
265 HeaderType::GoAway => {
268 if let Body::GoAway { error, .. } = frame.body {
269 self.manager.go_away(error).await;
270 return Err(Error::RemoteGoneAway);
271 }
272
273 unreachable!()
274 }
275 HeaderType::Invalid(_) => {
276 self.sys_tx
277 .send(Frame::goaway(
278 self.last_stream_processed,
279 Error::Protocol,
280 "invalid frame".into(),
281 ))
282 .map_err(|_| Error::StreamClosed)
283 .await?
284 }
285 }
286 Ok(())
287 }
288
289 async fn run(mut self) -> Result<(), Error> {
291 let _e: Result<(), _> = async {
292 loop {
293 match self.io.try_next().await {
294 Ok(Some(frame)) => {
295 trace!(?frame, "received frame from remote");
296 self.handle_frame(frame).await?
297 }
298 Ok(None) | Err(_) => {
299 return Err(Error::SessionClosed);
300 }
301 }
302 }
303 }
304 .await;
305
306 self.manager.close_senders().await;
307
308 Err(Error::SessionClosed)
309 }
310}
311
/// Write half of a session: sends the frames the stream manager produces and
/// services local requests to open streams.
struct Writer<W> {
    /// Writer's handle from `StreamManager::split`; polled as a stream of outgoing frames.
    manager: SharedStreamManager,
    /// Window size for locally-opened streams.
    window: usize,
    /// Pending `open` requests; each carries the oneshot the result is sent back on.
    open_reqs: mpsc::Receiver<oneshot::Sender<Result<Stream, Error>>>,
    /// Sink half of the framed transport.
    io: W,
}
320
321impl<W> Writer<W>
322where
323 W: Sink<Frame, Error = io::Error> + Unpin + Send + 'static,
324{
325 async fn run(mut self) -> Result<(), Error> {
326 loop {
327 select! {
328 frame = self.manager.next() => {
331 if let Some(frame) = frame {
332 let is_goaway = matches!(frame.header.typ, HeaderType::GoAway);
333 trace!(?frame, "sending frame to remote");
334 if let Err(_e) = self.io.send(frame).await {
335 return Err(Error::SessionClosed);
336 }
337 if is_goaway {
338 return Ok(())
339 }
340 }
341 },
342 req = self.open_reqs.next() => {
346 if let Some(resp_tx) = req {
347 let (req, stream) = OpenReq::create(self.window, true);
348
349 let mut manager = self.manager.lock().await;
350 let res = manager.create_stream(None, req);
351 let _ = resp_tx.send(res.map(move |_| stream));
352 }
353 },
354 complete => {
356 return Ok(());
357 }
358 }
359 }
360 }
361}
362
363pub trait Session: Accept + OpenClose {
368 type OpenClose: OpenClose;
370 type Accept: Accept;
372 fn split(self) -> (Self::OpenClose, Self::Accept);
374}
375
#[async_trait]
/// Something that yields streams opened by the remote end of a session.
pub trait Accept {
    /// Wait for the next remote-opened stream.
    /// `None` means no more streams will arrive.
    async fn accept(&mut self) -> Option<Stream>;
}
382
383#[async_trait]
385pub trait OpenClose {
386 async fn open(&mut self) -> Result<Stream, Error>;
388 async fn close(&mut self, error: Error, msg: String) -> Result<(), Error>;
390}
391
#[derive(Clone)]
/// The half of a muxado session that opens streams and closes the session.
///
/// Clones share one `closed` flag, so after `close` on any clone, `open`
/// fails on all of them.
pub struct MuxadoOpen {
    /// Session drop reference. A clone is stored in every stream `open` returns.
    /// (Presumably this keeps the session's background tasks alive; see the
    /// `awaitdrop` race in `SessionBuilder::start`.)
    dropref: awaitdrop::Ref,
    /// Sends open requests to the writer task, along with the oneshot the result comes back on.
    open_tx: mpsc::Sender<oneshot::Sender<Result<Stream, Error>>>,
    /// Sends session-level frames (used for GOAWAY by `close`).
    sys_tx: mpsc::Sender<Frame>,
    /// Set by `close`; checked at the start of `open`.
    closed: Arc<AtomicBool>,
}
400
/// The half of a muxado session that yields remote-opened streams.
///
/// Field 0 is never read. It is held only so the session's `awaitdrop::Ref`
/// stays alive while this half exists, hence `dead_code` is allowed on it.
/// (Presumably dropping every `Ref` stops the tasks spawned in
/// `SessionBuilder::start`, which race against `waiter.wait()`.)
/// Field 1 receives streams from the reader task.
pub struct MuxadoAccept(#[allow(dead_code)] awaitdrop::Ref, mpsc::Receiver<Stream>);
403
404#[async_trait]
405impl Accept for MuxadoAccept {
406 async fn accept(&mut self) -> Option<Stream> {
407 self.1.next().await
408 }
409}
410
411#[async_trait]
412impl OpenClose for MuxadoOpen {
413 async fn open(&mut self) -> Result<Stream, Error> {
414 if self.closed.load(Ordering::SeqCst) {
415 return Err(Error::SessionClosed);
416 }
417 let (resp_tx, resp_rx) = oneshot::channel();
418
419 self.open_tx
420 .send(resp_tx)
421 .await
422 .map_err(|_| Error::SessionClosed)?;
423
424 let mut res = resp_rx
425 .await
426 .map_err(|_| Error::SessionClosed)
427 .and_then(|r| r);
428
429 if let Ok(ref mut stream) = &mut res {
430 stream.dropref = self.dropref.clone().into();
431 }
432
433 res
434 }
435
436 async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
437 let res = self
438 .sys_tx
439 .send(Frame::goaway(
440 StreamID::clamp(0),
441 error,
442 msg.into_bytes().into(),
443 ))
444 .await
445 .map_err(|_| Error::SessionClosed);
446 self.closed.store(true, Ordering::SeqCst);
447 res
448 }
449}
450
/// A running muxado session, as returned by `SessionBuilder::start`.
///
/// It implements `Accept` and `OpenClose` directly, or can be split into
/// its two halves with `Session::split`.
pub struct MuxadoSession {
    /// Yields streams opened by the remote.
    incoming: MuxadoAccept,
    /// Opens streams and closes the session.
    outgoing: MuxadoOpen,
}
459
460#[async_trait]
461impl Accept for MuxadoSession {
462 async fn accept(&mut self) -> Option<Stream> {
463 self.incoming.accept().await
464 }
465}
466
467#[async_trait]
468impl OpenClose for MuxadoSession {
469 async fn open(&mut self) -> Result<Stream, Error> {
470 self.outgoing.open().await
471 }
472
473 async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
474 self.outgoing.close(error, msg).await
475 }
476}
477
478impl Session for MuxadoSession {
479 type Accept = MuxadoAccept;
480 type OpenClose = MuxadoOpen;
481 fn split(self) -> (Self::OpenClose, Self::Accept) {
482 (self.outgoing, self.incoming)
483 }
484}
485
#[cfg(test)]
mod test {
    use tokio::io::{
        self,
        AsyncReadExt,
        AsyncWriteExt,
    };

    use super::*;

    /// Round-trips a message over an in-memory duplex.
    ///
    /// The client writes on a stream it opened; the server reads it, then
    /// echoes it back on a stream the server opened.
    #[tokio::test]
    async fn test_session() {
        let (left, right) = io::duplex(512);
        let mut server = SessionBuilder::new(left).server().start();
        let mut client = SessionBuilder::new(right).client().start();

        let server_task = tokio::spawn(async move {
            let mut stream = server.accept().await.expect("accept stream");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).await.expect("read stream");
            drop(stream);
            let mut stream = server.open().await.expect("open stream");
            stream.write_all(&buf).await.expect("write to stream");
        });

        let mut stream = client.open().await.expect("open stream");
        stream
            .write_all(b"Hello, world!")
            .await
            .expect("write to stream");
        drop(stream);

        let mut stream = client.accept().await.expect("accept stream");
        let mut buf = Vec::new();
        stream
            .read_to_end(&mut buf)
            .await
            .expect("read from stream");

        assert_eq!(b"Hello, world!", &*buf,);

        // Report a server-side panic with its own message instead of leaving
        // the spawned task detached.
        server_task.await.expect("server task panicked");
    }
}