muxado/
codec.rs

1use std::io::Write;
2
3use bytes::{
4    Buf,
5    BufMut,
6    BytesMut,
7};
8use tokio_util::codec::{
9    Decoder,
10    Encoder,
11};
12use tracing::instrument;
13
14use super::{
15    errors::InvalidHeader,
16    frame::*,
17};
18
19/// Codec for muxado frames.
20#[derive(Default, Debug)]
21pub struct FrameCodec {
22    // the header has to be read to know how big a frame is.
23    // We'll decode it once when we have enough bytes, and then wait for the
24    // rest, keeping the already-decoded header around in the meantime to avoid
25    // decoding it repeatedly.
26    input_header: Option<Header>,
27}
28
29#[instrument(level = "trace")]
30fn decode_header(mut bs: BytesMut) -> Header {
31    let length_type_flags = bs.get_u32();
32    let length = ((length_type_flags & 0xFFFFFF00) >> 8).try_into().unwrap();
33    let type_flags = length_type_flags as u8;
34
35    Header {
36        length,
37        typ: ((type_flags & 0xF0) >> 4).into(),
38        flags: Flags::from_bits_truncate(type_flags & 0x0F),
39        stream_id: StreamID::mask(bs.get_u32()),
40    }
41}
42
43fn expect_zero_stream_id(header: Header) -> Result<(), InvalidHeader> {
44    if header.stream_id != StreamID::clamp(0) {
45        Err(InvalidHeader::NonZeroStreamID(header.stream_id))
46    } else {
47        Ok(())
48    }
49}
50
51fn expect_non_zero_stream_id(header: Header) -> Result<(), InvalidHeader> {
52    if header.stream_id == StreamID::clamp(0) {
53        Err(InvalidHeader::ZeroStreamID)
54    } else {
55        Ok(())
56    }
57}
58
59fn expect_length(header: Header, length: Length) -> Result<(), InvalidHeader> {
60    if header.length != length {
61        Err(InvalidHeader::Length {
62            expected: length,
63            actual: header.length,
64        })
65    } else {
66        Ok(())
67    }
68}
69
70fn expect_min_length(header: Header, length: Length) -> Result<(), InvalidHeader> {
71    if header.length < length {
72        Err(InvalidHeader::MinLength {
73            expected: length,
74            actual: header.length,
75        })
76    } else {
77        Ok(())
78    }
79}
80
81#[instrument(level = "trace")]
82fn validate_header(header: Header) -> Result<(), InvalidHeader> {
83    match header.typ {
84        HeaderType::Rst => {
85            expect_non_zero_stream_id(header)?;
86            expect_length(header, Length::clamp(4))?;
87        }
88        HeaderType::Data => {
89            expect_non_zero_stream_id(header)?;
90        }
91        HeaderType::WndInc => {
92            expect_non_zero_stream_id(header)?;
93            expect_length(header, Length::clamp(4))?;
94        }
95        HeaderType::GoAway => {
96            expect_zero_stream_id(header)?;
97            expect_min_length(header, Length::clamp(8))?;
98        }
99        HeaderType::Invalid(t) => return Err(InvalidHeader::Type(t)),
100    }
101
102    Ok(())
103}
104
105#[instrument(level = "trace")]
106fn decode_frame(header: Header, mut body: BytesMut) -> Frame {
107    if let Err(error) = validate_header(header) {
108        return Frame {
109            header,
110            body: Body::Invalid {
111                error,
112                body: body.freeze(),
113            },
114        };
115    }
116
117    Frame {
118        header,
119        body: match header.typ {
120            HeaderType::Rst => Body::Rst(ErrorCode::mask(body.get_u32()).into()),
121            HeaderType::Data => Body::Data(body.freeze()),
122            HeaderType::WndInc => Body::WndInc(WndInc::mask(body.get_u32())),
123            HeaderType::GoAway => Body::GoAway {
124                last_stream_id: StreamID::mask(body.get_u32()),
125                error: ErrorCode::mask(body.get_u32()).into(),
126                message: body.freeze(),
127            },
128            HeaderType::Invalid(t) => Body::Invalid {
129                error: InvalidHeader::Type(t),
130                body: body.freeze(),
131            },
132        },
133    }
134}
135
136impl Decoder for FrameCodec {
137    type Item = Frame;
138    type Error = std::io::Error;
139
140    #[instrument(level = "trace")]
141    fn decode(&mut self, b: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
142        let header = if let Some(header) = self.input_header {
143            header
144        } else {
145            if b.len() < 8 {
146                return Ok(None);
147            }
148
149            let header = decode_header(b.split_to(8));
150            self.input_header = Some(header);
151            header
152        };
153
154        if b.len() < *header.length as usize {
155            return Ok(None);
156        }
157
158        let body_bytes = b.split_to(*header.length as usize);
159
160        // Drop the header to get ready for the next frame.
161        self.input_header.take();
162
163        Ok(Some(decode_frame(header, body_bytes)))
164    }
165}
166
167#[instrument(level = "trace")]
168fn encode_header(header: Header, buf: &mut BytesMut) {
169    // Pack the type into the upper nibble and flags into the lower.
170    let type_flags: u8 = ((u8::from(header.typ) << 4) & 0xF0) | (header.flags.bits() & 0x0F);
171    // Pack the 24-bit length and packed type & flags into a u32
172    let length_type_flags: u32 = (*header.length << 8 & 0xFFFFFF00) | type_flags as u32;
173
174    buf.put_u32(length_type_flags);
175    buf.put_u32(*header.stream_id);
176}
177
178#[instrument(level = "trace")]
179fn encode_body(body: Body, buf: &mut BytesMut) {
180    match body {
181        Body::Rst(err) => buf.put_u32(*ErrorCode::from(err)),
182        Body::Data(data) => buf.writer().write_all(&data).unwrap(),
183        Body::WndInc(inc) => buf.put_u32(*inc),
184        Body::GoAway {
185            last_stream_id,
186            error,
187            message,
188        } => {
189            buf.put_u32(*last_stream_id);
190            buf.put_u32(*ErrorCode::from(error));
191            buf.writer().write_all(&message).unwrap();
192        }
193        Body::Invalid { body, .. } => buf.writer().write_all(&body).unwrap(),
194    }
195}
196
197impl Encoder<Frame> for FrameCodec {
198    type Error = std::io::Error;
199
200    #[instrument(level = "trace")]
201    fn encode(&mut self, frame: Frame, buf: &mut BytesMut) -> Result<(), std::io::Error> {
202        validate_header(frame.header)?;
203        encode_header(frame.header, buf);
204        encode_body(frame.body, buf);
205        Ok(())
206    }
207}
208
#[cfg(test)]
mod test {
    use bytes::Bytes;

    use super::*;

    /// A data frame that is encoded and then decoded by the same codec comes
    /// back equal to the frame that went in.
    #[test]
    fn round_trip() {
        let mut codec = FrameCodec::default();
        let mut wire = bytes::BytesMut::new();

        let payload = Bytes::from_static(b"Hello, world!");
        let original = Frame::from(Body::Data(payload)).stream_id(StreamID::clamp(5));

        let encoded = codec.encode(original.clone(), &mut wire);
        encoded.expect("no encode error");

        let decoded = codec
            .decode(&mut wire)
            .expect("no decode error")
            .expect("decoded frame");

        assert_eq!(original, decoded);
    }
}