1use std::io::Write;
2
3use bytes::{
4 Buf,
5 BufMut,
6 BytesMut,
7};
8use tokio_util::codec::{
9 Decoder,
10 Encoder,
11};
12use tracing::instrument;
13
14use super::{
15 errors::InvalidHeader,
16 frame::*,
17};
18
19#[derive(Default, Debug)]
21pub struct FrameCodec {
22 input_header: Option<Header>,
27}
28
29#[instrument(level = "trace")]
30fn decode_header(mut bs: BytesMut) -> Header {
31 let length_type_flags = bs.get_u32();
32 let length = ((length_type_flags & 0xFFFFFF00) >> 8).try_into().unwrap();
33 let type_flags = length_type_flags as u8;
34
35 Header {
36 length,
37 typ: ((type_flags & 0xF0) >> 4).into(),
38 flags: Flags::from_bits_truncate(type_flags & 0x0F),
39 stream_id: StreamID::mask(bs.get_u32()),
40 }
41}
42
43fn expect_zero_stream_id(header: Header) -> Result<(), InvalidHeader> {
44 if header.stream_id != StreamID::clamp(0) {
45 Err(InvalidHeader::NonZeroStreamID(header.stream_id))
46 } else {
47 Ok(())
48 }
49}
50
51fn expect_non_zero_stream_id(header: Header) -> Result<(), InvalidHeader> {
52 if header.stream_id == StreamID::clamp(0) {
53 Err(InvalidHeader::ZeroStreamID)
54 } else {
55 Ok(())
56 }
57}
58
59fn expect_length(header: Header, length: Length) -> Result<(), InvalidHeader> {
60 if header.length != length {
61 Err(InvalidHeader::Length {
62 expected: length,
63 actual: header.length,
64 })
65 } else {
66 Ok(())
67 }
68}
69
70fn expect_min_length(header: Header, length: Length) -> Result<(), InvalidHeader> {
71 if header.length < length {
72 Err(InvalidHeader::MinLength {
73 expected: length,
74 actual: header.length,
75 })
76 } else {
77 Ok(())
78 }
79}
80
81#[instrument(level = "trace")]
82fn validate_header(header: Header) -> Result<(), InvalidHeader> {
83 match header.typ {
84 HeaderType::Rst => {
85 expect_non_zero_stream_id(header)?;
86 expect_length(header, Length::clamp(4))?;
87 }
88 HeaderType::Data => {
89 expect_non_zero_stream_id(header)?;
90 }
91 HeaderType::WndInc => {
92 expect_non_zero_stream_id(header)?;
93 expect_length(header, Length::clamp(4))?;
94 }
95 HeaderType::GoAway => {
96 expect_zero_stream_id(header)?;
97 expect_min_length(header, Length::clamp(8))?;
98 }
99 HeaderType::Invalid(t) => return Err(InvalidHeader::Type(t)),
100 }
101
102 Ok(())
103}
104
105#[instrument(level = "trace")]
106fn decode_frame(header: Header, mut body: BytesMut) -> Frame {
107 if let Err(error) = validate_header(header) {
108 return Frame {
109 header,
110 body: Body::Invalid {
111 error,
112 body: body.freeze(),
113 },
114 };
115 }
116
117 Frame {
118 header,
119 body: match header.typ {
120 HeaderType::Rst => Body::Rst(ErrorCode::mask(body.get_u32()).into()),
121 HeaderType::Data => Body::Data(body.freeze()),
122 HeaderType::WndInc => Body::WndInc(WndInc::mask(body.get_u32())),
123 HeaderType::GoAway => Body::GoAway {
124 last_stream_id: StreamID::mask(body.get_u32()),
125 error: ErrorCode::mask(body.get_u32()).into(),
126 message: body.freeze(),
127 },
128 HeaderType::Invalid(t) => Body::Invalid {
129 error: InvalidHeader::Type(t),
130 body: body.freeze(),
131 },
132 },
133 }
134}
135
136impl Decoder for FrameCodec {
137 type Item = Frame;
138 type Error = std::io::Error;
139
140 #[instrument(level = "trace")]
141 fn decode(&mut self, b: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
142 let header = if let Some(header) = self.input_header {
143 header
144 } else {
145 if b.len() < 8 {
146 return Ok(None);
147 }
148
149 let header = decode_header(b.split_to(8));
150 self.input_header = Some(header);
151 header
152 };
153
154 if b.len() < *header.length as usize {
155 return Ok(None);
156 }
157
158 let body_bytes = b.split_to(*header.length as usize);
159
160 self.input_header.take();
162
163 Ok(Some(decode_frame(header, body_bytes)))
164 }
165}
166
167#[instrument(level = "trace")]
168fn encode_header(header: Header, buf: &mut BytesMut) {
169 let type_flags: u8 = ((u8::from(header.typ) << 4) & 0xF0) | (header.flags.bits() & 0x0F);
171 let length_type_flags: u32 = (*header.length << 8 & 0xFFFFFF00) | type_flags as u32;
173
174 buf.put_u32(length_type_flags);
175 buf.put_u32(*header.stream_id);
176}
177
178#[instrument(level = "trace")]
179fn encode_body(body: Body, buf: &mut BytesMut) {
180 match body {
181 Body::Rst(err) => buf.put_u32(*ErrorCode::from(err)),
182 Body::Data(data) => buf.writer().write_all(&data).unwrap(),
183 Body::WndInc(inc) => buf.put_u32(*inc),
184 Body::GoAway {
185 last_stream_id,
186 error,
187 message,
188 } => {
189 buf.put_u32(*last_stream_id);
190 buf.put_u32(*ErrorCode::from(error));
191 buf.writer().write_all(&message).unwrap();
192 }
193 Body::Invalid { body, .. } => buf.writer().write_all(&body).unwrap(),
194 }
195}
196
197impl Encoder<Frame> for FrameCodec {
198 type Error = std::io::Error;
199
200 #[instrument(level = "trace")]
201 fn encode(&mut self, frame: Frame, buf: &mut BytesMut) -> Result<(), std::io::Error> {
202 validate_header(frame.header)?;
203 encode_header(frame.header, buf);
204 encode_body(frame.body, buf);
205 Ok(())
206 }
207}
208
#[cfg(test)]
mod test {
    use bytes::Bytes;

    use super::*;

    /// A small data frame on stream 5 shared by the tests below.
    fn sample_frame() -> Frame {
        Frame::from(Body::Data(Bytes::from_static(b"Hello, world!")))
            .stream_id(StreamID::clamp(5))
    }

    #[test]
    fn round_trip() {
        let frame = sample_frame();
        let mut buf = bytes::BytesMut::new();
        let mut codec = FrameCodec::default();

        codec
            .encode(frame.clone(), &mut buf)
            .expect("no encode error");

        let decoded = codec
            .decode(&mut buf)
            .expect("no decode error")
            .expect("decoded frame");

        assert_eq!(frame, decoded);
    }

    /// Feeds the encoded frame one byte at a time. This exercises the
    /// resumable path, where the header is cached while the body is incomplete.
    #[test]
    fn incremental_decode() {
        let frame = sample_frame();
        let mut codec = FrameCodec::default();

        let mut encoded = bytes::BytesMut::new();
        codec
            .encode(frame.clone(), &mut encoded)
            .expect("no encode error");

        let last = encoded.len() - 1;
        let mut buf = bytes::BytesMut::new();
        let mut decoded = None;
        for (i, byte) in encoded.iter().enumerate() {
            buf.extend_from_slice(&[*byte]);
            let result = codec.decode(&mut buf).expect("no decode error");
            if i < last {
                assert!(result.is_none(), "frame decoded early at byte {}", i);
            } else {
                decoded = result;
            }
        }

        assert_eq!(Some(frame), decoded);
        assert!(buf.is_empty(), "decoder left trailing bytes");
    }
}