// muxado/typed.rs

//! Wrappers that add a `u32` type prefix to muxado streams.
2
3use std::{
4    fmt,
5    ops::{
6        Deref,
7        DerefMut,
8        RangeInclusive,
9    },
10};
11
12use async_trait::async_trait;
13use bytes::{
14    Buf,
15    BufMut,
16};
17use tokio::io::{
18    AsyncReadExt,
19    AsyncWriteExt,
20};
21use tracing::debug;
22
23use crate::{
24    constrained::*,
25    errors::Error,
26    session::{
27        Accept,
28        OpenClose,
29        Session,
30    },
31    stream::Stream,
32};
33
34constrained_num! {
35    /// A muxado stream type.
36    StreamType, u32, 0..=u32::MAX, clamp
37}
38
/// Session wrapper whose streams are opened and accepted with a leading
/// `u32` type id.
///
/// Everything the wrapped session offers remains reachable through
/// [`Deref`] / [`DerefMut`].
#[derive(Clone)]
pub struct Typed<S> {
    inner: S,
}

impl<S> Deref for Typed<S> {
    type Target = S;

    fn deref(&self) -> &S {
        &self.inner
    }
}

impl<S> DerefMut for Typed<S> {
    fn deref_mut(&mut self) -> &mut S {
        &mut self.inner
    }
}
58
59impl<S> Typed<S>
60where
61    S: Session,
62{
63    /// Wrap a session for use with typed streams.
64    pub fn new(inner: S) -> Self {
65        Typed { inner }
66    }
67}
68
69/// A typed muxado stream.
70pub struct TypedStream {
71    typ: StreamType,
72    inner: Stream,
73}
74
75impl DerefMut for TypedStream {
76    fn deref_mut(&mut self) -> &mut Self::Target {
77        &mut self.inner
78    }
79}
80
81impl Deref for TypedStream {
82    type Target = Stream;
83    fn deref(&self) -> &Self::Target {
84        &self.inner
85    }
86}
87
88impl TypedStream {
89    /// Get the type ID for this stream.
90    pub fn typ(&self) -> StreamType {
91        self.typ
92    }
93}
94
95/// Typed analogue to the [Session] trait.
96pub trait TypedSession: TypedAccept + TypedOpenClose {
97    /// The component implementing [TypedAccept].
98    type TypedAccept: TypedAccept;
99    /// The component implementing [TypedOpen].
100    type TypedOpen: TypedOpenClose;
101
102    /// Split the typed session into open/accept components.
103    fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept);
104}
105
106/// Typed analogue to the [Accept] trait.
107#[async_trait]
108pub trait TypedAccept {
109    /// Accept a typed stream.
110    ///
111    /// Because typed streams are indistinguishable from untyped streams, if the
112    /// remote isn't sending a type, then the first 4 bytes of data will be
113    /// misinterpreted as the stream type.
114    async fn accept_typed(&mut self) -> Result<TypedStream, Error>;
115}
116
117/// Typed analogue to the [Open] trait.
118#[async_trait]
119pub trait TypedOpenClose {
120    /// Open a typed stream with the given type.
121    async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error>;
122    /// Close the session by sending a GOAWAY
123    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error>;
124}
125
126#[async_trait]
127impl<S> TypedAccept for Typed<S>
128where
129    S: Accept + Send,
130{
131    async fn accept_typed(&mut self) -> Result<TypedStream, Error> {
132        let mut stream = self.accept().await.ok_or(Error::SessionClosed)?;
133
134        let mut buf = [0u8; 4];
135
136        stream
137            .read_exact(&mut buf[..])
138            .await
139            .map_err(|_| Error::StreamClosed)?;
140
141        let typ = StreamType::clamp((&buf[..]).get_u32());
142
143        debug!(?typ, "read stream type");
144
145        Ok(TypedStream { typ, inner: stream })
146    }
147}
148#[async_trait]
149impl<S> TypedOpenClose for Typed<S>
150where
151    S: OpenClose + Send,
152{
153    async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error> {
154        let mut stream = self.open().await?;
155
156        let mut bytes = [0u8; 4];
157        (&mut bytes[..]).put_u32(*typ);
158
159        stream
160            .write(&bytes[..])
161            .await
162            .map_err(|_| Error::StreamReset)?;
163
164        Ok(TypedStream { inner: stream, typ })
165    }
166
167    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
168        self.inner.close(error, msg).await
169    }
170}
171
172impl<S> TypedSession for Typed<S>
173where
174    S: Session + Send,
175    S::Accept: Send,
176    S::OpenClose: Send,
177{
178    type TypedAccept = Typed<S::Accept>;
179    type TypedOpen = Typed<S::OpenClose>;
180    fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept) {
181        let (open, accept) = self.inner.split();
182        (Typed { inner: open }, Typed { inner: accept })
183    }
184}