1use std::{
2 pin::Pin,
3 sync::{
4 atomic::{
5 AtomicU32,
6 Ordering,
7 },
8 Arc,
9 },
10 task::{
11 Context,
12 Poll,
13 },
14};
15
16use futures::{
17 channel::mpsc::{
18 self,
19 SendError,
20 },
21 ready,
22 sink::Sink,
23};
24use tracing::instrument;
25
26use crate::{
27 errors::Error,
28 frame::{
29 ErrorCode,
30 Frame,
31 },
32};
33
34pub struct StreamSender {
35 sink: mpsc::Sender<Frame>,
36 closer: SinkCloser,
37}
38
39impl StreamSender {
40 pub fn sink(&mut self) -> Pin<&mut mpsc::Sender<Frame>> {
41 Pin::new(&mut self.sink)
42 }
43
44 pub fn wrap(sink: mpsc::Sender<Frame>) -> StreamSender {
45 let code = Arc::new(AtomicU32::new(0));
46 StreamSender {
47 sink,
48 closer: SinkCloser { code },
49 }
50 }
51
52 pub fn closer(&self) -> SinkCloser {
53 self.closer.clone()
54 }
55}
56
57impl Sink<Frame> for StreamSender {
58 type Error = Error;
59
60 #[instrument(level = "trace", skip_all)]
61 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.closer.check_closed()?;
63 Poll::Ready(match ready!(self.as_mut().sink().poll_ready(cx)) {
64 Ok(()) => Ok(()),
65 Err(_) => {
66 self.closer.close_with(Error::SessionClosed);
69 Err(Error::SessionClosed)
70 }
71 })
72 }
73 #[instrument(level = "trace", skip(self))]
74 fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
75 self.closer.check_closed()?;
76 match self.as_mut().sink().start_send(item) {
77 Ok(()) => Ok(()),
78 Err(_) => {
79 self.closer.close_with(Error::SessionClosed);
80 Err(Error::SessionClosed)
81 }
82 }
83 }
84 #[instrument(level = "trace", skip_all)]
85 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.closer.check_closed()?;
87 Poll::Ready(match ready!(self.as_mut().sink().poll_flush(cx)) {
88 Ok(()) => Ok(()),
89 Err(_) => {
90 self.closer.close_with(Error::SessionClosed);
91 Err(Error::SessionClosed)
92 }
93 })
94 }
95
96 #[instrument(level = "trace", skip_all)]
99 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 self.closer.check_closed()?;
101 (|| -> Poll<Result<(), SendError>> {
105 ready!(self.as_mut().sink().poll_flush(cx))?;
106 ready!(self.as_mut().sink().poll_close(cx))?;
107 Ok(()).into()
108 })()
109 .map_ok(|_| self.closer.close_with(Error::StreamClosed))
110 .map_err(|_| {
111 self.closer.close_with(Error::SessionClosed);
112 Error::SessionClosed
113 })
114 }
115}
116
/// Cloneable, shared record of whether (and why) a sink was closed.
///
/// All clones share one atomic code. The value `0` means still open. Any
/// other value is the [`ErrorCode`] of the first reason recorded via
/// [`SinkCloser::close_with`].
#[derive(Clone, Debug)]
pub struct SinkCloser {
    /// Raw close code; `0` while open.
    code: Arc<AtomicU32>,
}
121
122impl SinkCloser {
123 #[instrument(level = "trace")]
124 pub fn close_with(&self, ty: Error) {
125 let _ = self.code.compare_exchange(
128 0,
129 *ErrorCode::from(ty),
130 Ordering::AcqRel,
131 Ordering::Relaxed,
132 );
133 }
134
135 pub fn is_closed(&self) -> bool {
136 self.check_closed().is_err()
137 }
138
139 pub fn check_closed(&self) -> Result<(), Error> {
140 let code = self.code.load(Ordering::Acquire);
141 if code != 0 {
142 Err(Error::from(ErrorCode::mask(code)))
143 } else {
144 Ok(())
145 }
146 }
147}