1use std::{
4 fmt,
5 ops::{
6 Deref,
7 DerefMut,
8 RangeInclusive,
9 },
10};
11
12use async_trait::async_trait;
13use bytes::{
14 Buf,
15 BufMut,
16};
17use tokio::io::{
18 AsyncReadExt,
19 AsyncWriteExt,
20};
21use tracing::debug;
22
23use crate::{
24 constrained::*,
25 errors::Error,
26 session::{
27 Accept,
28 OpenClose,
29 Session,
30 },
31 stream::Stream,
32};
33
// `StreamType` is the tag carried at the start of every typed stream: `open_typed`
// writes it as four bytes before handing the stream out, and `accept_typed` reads
// those four bytes back. The `clamp` constructor (used by `accept_typed`) maps a raw
// `u32` into the type; since the declared range is the full `0..=u32::MAX`, every
// value is accepted unchanged.
// NOTE(review): the exact items generated by `constrained_num!` (Copy, Debug, Deref
// to u32, ...) are inferred from how this file uses `StreamType` — confirm against
// the macro definition.
constrained_num! {
    StreamType, u32, 0..=u32::MAX, clamp
}
38
/// Wrapper around a session (or one half of a split session) that exchanges a
/// 4-byte [`StreamType`] tag at the start of every stream it opens or accepts.
///
/// The wrapped value is reachable through `Deref`/`DerefMut`, so the untyped
/// session API remains available on a `Typed<S>`.
#[derive(Clone)]
pub struct Typed<S> {
    // The wrapped session or session half.
    inner: S,
}
45
46impl<S> DerefMut for Typed<S> {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 &mut self.inner
49 }
50}
51
52impl<S> Deref for Typed<S> {
53 type Target = S;
54 fn deref(&self) -> &Self::Target {
55 &self.inner
56 }
57}
58
59impl<S> Typed<S>
60where
61 S: Session,
62{
63 pub fn new(inner: S) -> Self {
65 Typed { inner }
66 }
67}
68
/// A [`Stream`] together with the [`StreamType`] tag that was written (when
/// opened) or read (when accepted) at its start.
///
/// Dereferences to the underlying [`Stream`] for I/O.
pub struct TypedStream {
    // Tag exchanged at the beginning of the stream.
    typ: StreamType,
    // The underlying stream; the tag bytes have already been consumed/written.
    inner: Stream,
}
74
75impl DerefMut for TypedStream {
76 fn deref_mut(&mut self) -> &mut Self::Target {
77 &mut self.inner
78 }
79}
80
81impl Deref for TypedStream {
82 type Target = Stream;
83 fn deref(&self) -> &Self::Target {
84 &self.inner
85 }
86}
87
88impl TypedStream {
89 pub fn typ(&self) -> StreamType {
91 self.typ
92 }
93}
94
/// A session that can both accept and open typed streams, and that can be split
/// into independent opening and accepting halves.
pub trait TypedSession: TypedAccept + TypedOpenClose {
    /// The half of the session used to accept typed streams.
    type TypedAccept: TypedAccept;
    /// The half of the session used to open typed streams and close the session.
    type TypedOpen: TypedOpenClose;

    /// Consumes the session and returns its `(opening half, accepting half)`.
    fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept);
}
105
#[async_trait]
/// Accepts incoming streams together with their [`StreamType`] tag.
pub trait TypedAccept {
    /// Waits for the next incoming stream and reads its type tag.
    ///
    /// # Errors
    /// Returns an [`Error`] if no stream can be accepted or its tag cannot be
    /// read; the specific variants are up to the implementation.
    async fn accept_typed(&mut self) -> Result<TypedStream, Error>;
}
116
#[async_trait]
/// Opens outgoing typed streams and closes the session.
pub trait TypedOpenClose {
    /// Opens a new stream tagged with `typ`.
    ///
    /// # Errors
    /// Returns an [`Error`] if the stream cannot be opened or tagged.
    async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error>;
    /// Closes the session, reporting `error` and `msg`.
    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error>;
}
125
126#[async_trait]
127impl<S> TypedAccept for Typed<S>
128where
129 S: Accept + Send,
130{
131 async fn accept_typed(&mut self) -> Result<TypedStream, Error> {
132 let mut stream = self.accept().await.ok_or(Error::SessionClosed)?;
133
134 let mut buf = [0u8; 4];
135
136 stream
137 .read_exact(&mut buf[..])
138 .await
139 .map_err(|_| Error::StreamClosed)?;
140
141 let typ = StreamType::clamp((&buf[..]).get_u32());
142
143 debug!(?typ, "read stream type");
144
145 Ok(TypedStream { typ, inner: stream })
146 }
147}
148#[async_trait]
149impl<S> TypedOpenClose for Typed<S>
150where
151 S: OpenClose + Send,
152{
153 async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error> {
154 let mut stream = self.open().await?;
155
156 let mut bytes = [0u8; 4];
157 (&mut bytes[..]).put_u32(*typ);
158
159 stream
160 .write(&bytes[..])
161 .await
162 .map_err(|_| Error::StreamReset)?;
163
164 Ok(TypedStream { inner: stream, typ })
165 }
166
167 async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
168 self.inner.close(error, msg).await
169 }
170}
171
172impl<S> TypedSession for Typed<S>
173where
174 S: Session + Send,
175 S::Accept: Send,
176 S::OpenClose: Send,
177{
178 type TypedAccept = Typed<S::Accept>;
179 type TypedOpen = Typed<S::OpenClose>;
180 fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept) {
181 let (open, accept) = self.inner.split();
182 (Typed { inner: open }, Typed { inner: accept })
183 }
184}