1use std::{
2 collections::HashMap,
3 pin::Pin,
4 sync::{
5 atomic::{
6 AtomicBool,
7 Ordering,
8 },
9 Arc,
10 },
11 task::{
12 Context,
13 Poll,
14 Waker,
15 },
16};
17
18use futures::{
19 channel::mpsc,
20 future::poll_fn,
21 lock::{
22 BiLock,
23 BiLockGuard,
24 },
25 prelude::*,
26 ready,
27 stream::{
28 FusedStream,
29 FuturesUnordered,
30 Stream as StreamT,
31 StreamFuture,
32 },
33};
34use tracing::{
35 debug,
36 error,
37 instrument,
38 trace,
39};
40
41use crate::{
42 errors::Error,
43 frame::{
44 Body,
45 Frame,
46 HeaderType,
47 StreamID,
48 },
49 stream::Stream,
50 stream_output::*,
51};
52
/// One half of a [`StreamManager`] produced by [`StreamManager::split`].
///
/// Field `0` is this half's `BiLock` on the shared manager. Field `1` is a
/// termination flag shared by both halves; it is what this type's
/// `FusedStream::is_terminated` reports.
#[derive(Debug)]
pub struct SharedStreamManager(BiLock<StreamManager>, Arc<AtomicBool>);
55
/// Per-stream bookkeeping kept by the [`StreamManager`] for every live stream.
#[derive(Clone, Debug)]
pub(crate) struct StreamHandle {
    /// Sender into the local stream. Frames received from the remote are
    /// delivered through it by `SharedStreamManager::send_to_stream`.
    pub to_stream: mpsc::Sender<Frame>,

    /// Closer obtained from the stream's stream-to-session sender
    /// (`StreamSender::closer`). It is closed on a remote reset, on a
    /// delivery error, and when the manager goes away.
    pub sink_closer: SinkCloser,

    /// Whether a fin still has to be emitted for this stream. Starts `true`
    /// and is cleared once the stream emits a fin, or when the stream is
    /// reset or closed because of an error.
    pub needs_fin: bool,

    /// Set after a reset or delivery error. Further incoming data frames for
    /// the stream are then rejected with `Error::StreamClosed`.
    pub data_write_closed: bool,

    /// Receive window: how many data bytes the remote may still send on this
    /// stream. Incoming data decreases it; `WndInc` frames emitted by the
    /// local stream increase it.
    pub window: usize,
}
74
/// One pending "next frame" future per live stream, tagged with its stream ID.
/// Each future resolves to the stream's next outgoing frame (or `None` once
/// the channel ends) together with the receiver itself, so the receiver can
/// be re-queued.
type StreamTasks = FuturesUnordered<WithID<StreamFuture<mpsc::Receiver<Frame>>>>;
76
/// Tracks per-stream state and merges the frames produced by all local
/// streams, plus session-level "system" frames, into one outgoing `Stream`
/// of frames.
#[derive(Debug)]
pub struct StreamManager {
    /// Maximum number of streams that may be registered at once.
    stream_limit: usize,

    /// Live streams, keyed by stream ID.
    streams: HashMap<StreamID, StreamHandle>,
    /// Sending half of the system-frame channel; cloned out by `sys_sender`.
    sys_tx: mpsc::Sender<Frame>,
    /// System frames are polled before any per-stream frame.
    sys_rx: mpsc::Receiver<Frame>,
    /// One pending receive per live stream.
    tasks: StreamTasks,

    /// Most recently allocated locally-initiated stream ID.
    last_local_id: StreamID,
    /// Most recently accepted remote-initiated stream ID.
    last_remote_id: StreamID,

    /// Set once the session has gone away; `poll_next` then yields `None`.
    gone_away: bool,

    /// Waker from the most recent `poll_next`, woken by `create_stream`.
    new_streams: Option<Waker>,
}
98
99impl StreamManager {
100 fn tasks(&mut self) -> Pin<&mut StreamTasks> {
101 Pin::new(&mut self.tasks)
102 }
103
104 fn sys_rx(&mut self) -> Pin<&mut mpsc::Receiver<Frame>> {
105 Pin::new(&mut self.sys_rx)
106 }
107
108 pub fn new(stream_limit: usize, client: bool) -> Self {
109 let (sys_tx, sys_rx) = mpsc::channel(512);
110 let mut last_local_id = 0;
111 let mut last_remote_id = 0;
112 if client {
113 last_local_id += 1;
114 } else {
115 last_remote_id += 1;
116 }
117 StreamManager {
118 streams: Default::default(),
119 stream_limit,
120 sys_tx,
121 sys_rx,
122 last_local_id: StreamID::clamp(last_local_id),
123 last_remote_id: StreamID::clamp(last_remote_id),
124 tasks: Default::default(),
125 gone_away: false,
126 new_streams: None,
127 }
128 }
129
130 pub fn split(self) -> (SharedStreamManager, SharedStreamManager) {
132 let (l, r) = BiLock::new(self);
133 let terminated = Arc::new(AtomicBool::new(false));
134 (
135 SharedStreamManager(l, terminated.clone()),
136 SharedStreamManager(r, terminated),
137 )
138 }
139
140 pub fn go_away(&mut self, error: Error) {
141 self.gone_away = true;
142 for (_id, handle) in self.streams.drain() {
143 handle.sink_closer.close_with(error);
144 }
145 }
146
147 pub fn sys_sender(&self) -> mpsc::Sender<Frame> {
148 self.sys_tx.clone()
149 }
150
151 pub fn close_senders(&mut self) {
152 for (_, stream) in self.streams.iter_mut() {
153 stream.to_stream.close_channel()
154 }
155 }
156}
157
158impl FusedStream for StreamManager {
159 fn is_terminated(&self) -> bool {
160 self.gone_away
161 }
162}
163
164impl FusedStream for SharedStreamManager {
165 fn is_terminated(&self) -> bool {
166 self.1.load(Ordering::SeqCst)
167 }
168}
169
impl StreamT for StreamManager {
    /// Outgoing frames: system frames from `sys_rx` and frames produced by
    /// the individual local streams.
    type Item = Frame;

    /// Produce the next outgoing frame.
    ///
    /// Order of work, as implemented below:
    /// 1. If the manager has gone away, the stream is finished (`None`).
    /// 2. A ready system frame is returned first. A `GoAway` frame is stamped
    ///    with `last_remote_id`, all stream tasks are dropped, and every
    ///    stream is closed with `Error::SessionClosed`.
    /// 3. Otherwise the next ready per-stream receive is handled. Its frame is
    ///    returned (tagged with the stream ID), or an empty fin data frame is
    ///    synthesized when the stream's channel ended without a fin.
    #[instrument(level = "trace", skip_all)]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.gone_away {
            return Poll::Ready(None);
        }

        // Remember the current task so `create_stream` can wake it after a
        // new stream's receiver is pushed into `tasks`.
        self.new_streams = Some(cx.waker().clone());

        if let Poll::Ready(Some(mut frame)) = self.as_mut().sys_rx().poll_next(cx) {
            if let Body::GoAway {
                ref mut last_stream_id,
                error: _,
                message: _,
            } = &mut frame.body
            {
                // Report the last remote-initiated stream ID we accepted, then
                // tear everything down; the next poll returns `None`.
                *last_stream_id = self.last_remote_id;
                self.as_mut().tasks().clear();
                self.as_mut().go_away(Error::SessionClosed);
            }
            return Some(frame).into();
        }

        // An empty `tasks` set yields `None`. That means "nothing to send
        // yet", not end-of-stream: the waker saved above lets `create_stream`
        // resume this task.
        let (id, (item, rest)) = if let Some(i) = ready!(self.as_mut().tasks().poll_next(cx)) {
            i
        } else {
            return Poll::Pending;
        };

        let handle = if let Ok(handle) = self.get_stream(id) {
            handle
        } else {
            // NOTE(review): this relies on every task in `tasks` having a live
            // entry in `streams`. `create_stream` inserts without checking for
            // an existing ID, so a reused stream ID could break this invariant
            // and panic here. Confirm that callers reject duplicate IDs.
            unreachable!();
        };

        // The stream's sink was closed (e.g. by a remote reset or a delivery
        // error in `send_to_stream`) and no fin is owed. Drop the stream and
        // its receiver, discarding `item`, and ask to be polled again so the
        // remaining ready tasks are still drained.
        if handle.sink_closer.is_closed() && !handle.needs_fin {
            debug!(needs_fin = handle.needs_fin, "removing stream without fin");
            self.remove_stream(id);
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        let frame = if let Some(frame) = item {
            // The local stream is granting the remote more window; incoming
            // data is checked against `window` in `send_to_stream`.
            if let Body::WndInc(inc) = frame.body {
                handle.window += *inc as usize;
            }

            if frame.is_fin() {
                debug!(stream_id = debug(id), "setting needs_fin to false");
                handle.needs_fin = false;
            }

            // Re-queue the receiver to wait for this stream's next frame.
            self.push_task(id, rest);

            frame
        } else {
            // The stream's channel ended. Forget the stream, and emit an empty
            // fin on its behalf if it never sent one.
            let needs_fin = handle.needs_fin;
            handle.needs_fin = false;
            self.remove_stream(id);
            debug!(needs_fin, "got none from stream, trying to send a fin");
            if needs_fin {
                debug!("removing stream and sending fin");
                Frame::from(Body::Data([][..].into())).fin()
            } else {
                debug!("removing stream that's already fin'd");
                // Nothing to emit; ask to be polled again for other streams.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
        .stream_id(id);

        Some(frame).into()
    }
}
275
276impl StreamT for SharedStreamManager {
277 type Item = <StreamManager as StreamT>::Item;
278 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279 ready!(self.0.poll_lock(cx)).as_pin_mut().poll_next(cx)
280 }
281}
282
impl SharedStreamManager {
    /// Mark both halves as terminated, then close every stream with `error`
    /// via [`StreamManager::go_away`].
    #[instrument(level = "trace", skip(self))]
    pub async fn go_away(&mut self, error: Error) {
        // The flag is set before the lock is acquired, so `is_terminated`
        // reports true on both halves right away.
        self.1.store(true, Ordering::SeqCst);
        self.lock().await.go_away(error);
    }

    /// Deliver a frame received from the remote to its local stream.
    ///
    /// Behavior, as implemented below:
    /// - `GoAway` and `Invalid` frames are rejected with `Error::Internal`.
    /// - A frame for an unknown stream is accepted and dropped. The exception
    ///   is a non-fin data frame, which yields the lookup error.
    /// - Data frames on a stream marked `data_write_closed` fail with
    ///   `Error::StreamClosed`.
    /// - `Rst` closes the stream's sink with the carried error and marks the
    ///   stream write-closed. The frame itself is not delivered.
    /// - Data frames are charged against the stream's window. Exceeding it
    ///   fails with `Error::FlowControl`.
    /// - Errors while handling a non-fin data frame also close the stream.
    ///   Errors for any other frame are swallowed (`Ok(())`).
    ///
    /// When the stream's inbound channel is full, this waits asynchronously
    /// (`poll_ready`) for capacity.
    #[instrument(level = "trace", skip(self))]
    pub async fn send_to_stream(&mut self, frame: Frame) -> Result<(), Error> {
        let id = frame.header.stream_id;
        let typ = frame.header.typ;
        // Only data frames consume flow-control window.
        let mut shrink_window = if let Body::Data(bs) = &frame.body {
            bs.len()
        } else {
            0
        };

        match frame.body {
            Body::GoAway { .. } | Body::Invalid { .. } => {
                error!(
                    body = ?frame,
                    id = %id,
                    "attempt to send invalid frame type to stream",
                );
                return Err(Error::Internal);
            }
            _ => {}
        }

        // An empty fin frame; failures to deliver it are tolerated below.
        let is_fin = frame.is_fin() && frame.body.len() == 0;

        // Held in an `Option` so the frame is moved into the sink exactly
        // once, even though `handle_frame` may run on several polls.
        let mut frame = Some(frame);

        let mut handle_frame = |handle: &mut StreamHandle, cx: &mut Context| {
            if typ == HeaderType::Data && handle.data_write_closed {
                debug!("attempt to send data on closed stream");
                return Err(Error::StreamClosed).into();
            }

            if let Some(Frame {
                body: Body::Rst(err),
                ..
            }) = frame
            {
                debug!(
                    stream_id = display(id),
                    error = display(err),
                    "received rst from remote, closing stream"
                );
                handle.sink_closer.close_with(err);
                handle.data_write_closed = true;
                handle.needs_fin = false;
                return Ok(()).into();
            }

            // Charge the window only once. Zeroing `shrink_window` keeps a
            // re-poll (after `poll_ready` returned Pending) from charging it
            // again.
            if shrink_window <= handle.window {
                handle.window -= shrink_window;
                shrink_window = 0;
            } else {
                debug!(
                    frame_size = shrink_window,
                    stream_window = handle.window,
                    "remote violated flow control"
                );
                return Err(Error::FlowControl).into();
            }

            let sink = &mut handle.to_stream;
            trace!("checking stream for readiness");
            ready!(sink.poll_ready(cx))
                .and_then(|_| sink.start_send(frame.take().unwrap()))
                .map_err(|_| Error::StreamClosed)
                // Delivering an empty fin to an already-closed stream is not
                // an error.
                .or_else(|res| if is_fin { Ok(()) } else { Err(res) })?;
            Ok(()).into()
        };

        poll_fn(move |cx| -> Poll<Result<_, Error>> {
            // The lock is re-acquired on every poll. The guard is dropped
            // whenever this closure returns, including on Pending.
            let mut lock = ready!(self.0.poll_lock(cx));
            let handle = match lock.get_stream(id) {
                Ok(handle) => handle,
                Err(_e) if HeaderType::Data != typ || is_fin => {
                    return Ok(()).into();
                }
                Err(e) => return Err(e).into(),
            };

            let res = ready!(handle_frame(handle, cx));

            if HeaderType::Data == typ && !is_fin {
                // A failed data delivery shuts the stream down.
                if let Err(e) = res {
                    debug!(error = display(e), "error handling frame");
                    handle.sink_closer.close_with(Error::StreamClosed);
                    handle.data_write_closed = true;
                    handle.needs_fin = false;
                }
                res.into()
            } else {
                Ok(()).into()
            }
        })
        .await
    }

    /// Close the session-to-stream channel of every registered stream (see
    /// [`StreamManager::close_senders`]).
    pub async fn close_senders(&mut self) {
        self.lock().await.close_senders()
    }

    /// Acquire this half's lock on the shared [`StreamManager`].
    pub async fn lock(&mut self) -> BiLockGuard<'_, StreamManager> {
        self.0.lock().await
    }
}
416
/// Session-side half of a new stream.
///
/// Produced by [`OpenReq::create`] and consumed by
/// [`StreamManager::create_stream`].
pub struct OpenReq {
    /// `(sender into the stream, receiver of frames coming from the stream)`.
    channel: (mpsc::Sender<Frame>, mpsc::Receiver<Frame>),
    /// Closer for the stream's stream-to-session sender.
    closer: SinkCloser,
    /// Initial receive window for the stream.
    window: usize,
}
422
423impl OpenReq {
424 pub fn create(window: usize, needs_syn: bool) -> (OpenReq, Stream) {
425 let (to_stream, from_session) = mpsc::channel(window);
426 let (to_session, from_stream) = mpsc::channel(window);
427 let to_session = StreamSender::wrap(to_session);
428 let req = OpenReq {
429 channel: (to_stream, from_stream),
430 closer: to_session.closer(),
431 window,
432 };
433 let stream = Stream::new(to_session, from_session, window, needs_syn);
434 (req, stream)
435 }
436}
437
438impl StreamManager {
439 #[instrument(level = "trace", skip(self))]
440 pub(crate) fn get_stream(&mut self, id: StreamID) -> Result<&mut StreamHandle, Error> {
441 if let Some(handle) = self.streams.get_mut(&id) {
442 Ok(handle)
443 } else {
444 trace!("stream not found");
445 Err(Error::StreamClosed)
446 }
447 }
448
449 #[instrument(level = "trace", skip(self, req))]
450 pub fn create_stream(&mut self, id: Option<StreamID>, req: OpenReq) -> Result<StreamID, Error> {
451 if self.streams.len() == self.stream_limit {
453 return Err(Error::StreamsExhausted);
454 }
455
456 let (to_stream, from_stream) = req.channel;
457 let closer = req.closer;
458 let window = req.window;
459 let id = if let Some(remote_id) = id {
460 self.last_remote_id = remote_id;
461 remote_id
462 } else {
463 let new_id = StreamID::clamp(*self.last_local_id + 2);
464 self.last_local_id = new_id;
465 new_id
466 };
467 self.streams.insert(
468 id,
469 StreamHandle {
470 window,
471 to_stream,
472 sink_closer: closer,
473 needs_fin: true,
474 data_write_closed: false,
475 },
476 );
477 self.push_task(id, from_stream);
478 if let Some(w) = self.new_streams.take() {
480 w.wake()
481 }
482 Ok(id)
483 }
484
485 fn push_task(&mut self, id: StreamID, recv: mpsc::Receiver<Frame>) {
486 self.tasks.push(recv.into_future().with_id(id));
487 }
488
489 #[instrument(level = "debug", skip(self))]
490 fn remove_stream(&mut self, id: StreamID) -> Option<StreamHandle> {
491 self.streams.remove(&id)
492 }
493}
494
/// A future paired with a stream ID.
///
/// When the future completes, it resolves to `(id, output)`.
struct WithID<F: ?Sized> {
    id: StreamID,
    fut: F,
}
499
500impl<F: Unpin> WithID<F> {
501 fn fut(&mut self) -> Pin<&mut F> {
502 Pin::new(&mut self.fut)
503 }
504}
505
506impl<F> Future for WithID<F>
507where
508 F: Future + Unpin,
509{
510 type Output = (StreamID, F::Output);
511
512 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
513 let out = ready!(self.as_mut().fut().poll(cx));
514 Poll::Ready((self.id, out))
515 }
516}
517
/// Extension trait for tagging a future with a stream ID.
trait WithIDExt {
    /// Wrap `self` so that its output is returned alongside `id`.
    fn with_id(self, id: StreamID) -> WithID<Self>;
}
521
522impl<F> WithIDExt for F
523where
524 F: Future,
525{
526 fn with_id(self, id: StreamID) -> WithID<Self> {
527 WithID { id, fut: self }
528 }
529}