muxado/
stream_output.rs

1use std::{
2    pin::Pin,
3    sync::{
4        atomic::{
5            AtomicU32,
6            Ordering,
7        },
8        Arc,
9    },
10    task::{
11        Context,
12        Poll,
13    },
14};
15
16use futures::{
17    channel::mpsc::{
18        self,
19        SendError,
20    },
21    ready,
22    sink::Sink,
23};
24use tracing::instrument;
25
26use crate::{
27    errors::Error,
28    frame::{
29        ErrorCode,
30        Frame,
31    },
32};
33
34pub struct StreamSender {
35    sink: mpsc::Sender<Frame>,
36    closer: SinkCloser,
37}
38
39impl StreamSender {
40    pub fn sink(&mut self) -> Pin<&mut mpsc::Sender<Frame>> {
41        Pin::new(&mut self.sink)
42    }
43
44    pub fn wrap(sink: mpsc::Sender<Frame>) -> StreamSender {
45        let code = Arc::new(AtomicU32::new(0));
46        StreamSender {
47            sink,
48            closer: SinkCloser { code },
49        }
50    }
51
52    pub fn closer(&self) -> SinkCloser {
53        self.closer.clone()
54    }
55}
56
57impl Sink<Frame> for StreamSender {
58    type Error = Error;
59
60    #[instrument(level = "trace", skip_all)]
61    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.closer.check_closed()?;
63        Poll::Ready(match ready!(self.as_mut().sink().poll_ready(cx)) {
64            Ok(()) => Ok(()),
65            Err(_) => {
66                // If there was an error here, it means the stream manager got
67                // dropped.
68                self.closer.close_with(Error::SessionClosed);
69                Err(Error::SessionClosed)
70            }
71        })
72    }
73    #[instrument(level = "trace", skip(self))]
74    fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
75        self.closer.check_closed()?;
76        match self.as_mut().sink().start_send(item) {
77            Ok(()) => Ok(()),
78            Err(_) => {
79                self.closer.close_with(Error::SessionClosed);
80                Err(Error::SessionClosed)
81            }
82        }
83    }
84    #[instrument(level = "trace", skip_all)]
85    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.closer.check_closed()?;
87        Poll::Ready(match ready!(self.as_mut().sink().poll_flush(cx)) {
88            Ok(()) => Ok(()),
89            Err(_) => {
90                self.closer.close_with(Error::SessionClosed);
91                Err(Error::SessionClosed)
92            }
93        })
94    }
95
96    // Note: This should never actually be called. The stream uses a sentinel
97    //       invalid frame to indicate closure rather than closing the channel.
98    #[instrument(level = "trace", skip_all)]
99    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        self.closer.check_closed()?;
101        // The receiving end of this expects that if the channel looks closed,
102        // there are no buffered messages to read. Make sure we're flushed
103        // before closing.
104        (|| -> Poll<Result<(), SendError>> {
105            ready!(self.as_mut().sink().poll_flush(cx))?;
106            ready!(self.as_mut().sink().poll_close(cx))?;
107            Ok(()).into()
108        })()
109        .map_ok(|_| self.closer.close_with(Error::StreamClosed))
110        .map_err(|_| {
111            self.closer.close_with(Error::SessionClosed);
112            Error::SessionClosed
113        })
114    }
115}
116
/// Shared close state for a stream's sending half.
///
/// Cloning yields another handle onto the same atomic, so a close recorded
/// through one handle is visible through all of them.
#[derive(Clone, Debug)]
pub struct SinkCloser {
    /// Raw close code; `0` means no close has been recorded.
    code: Arc<AtomicU32>,
}
121
122impl SinkCloser {
123    #[instrument(level = "trace")]
124    pub fn close_with(&self, ty: Error) {
125        // Only store an error if there wasn't already one.
126        // Discard the result since we don't really care to return it.
127        let _ = self.code.compare_exchange(
128            0,
129            *ErrorCode::from(ty),
130            Ordering::AcqRel,
131            Ordering::Relaxed,
132        );
133    }
134
135    pub fn is_closed(&self) -> bool {
136        self.check_closed().is_err()
137    }
138
139    pub fn check_closed(&self) -> Result<(), Error> {
140        let code = self.code.load(Ordering::Acquire);
141        if code != 0 {
142            Err(Error::from(ErrorCode::mask(code)))
143        } else {
144            Ok(())
145        }
146    }
147}