// ngrok/proxy_proto.rs

1use std::{
2    io,
3    mem,
4    pin::{
5        pin,
6        Pin,
7    },
8    task::{
9        ready,
10        Context,
11        Poll,
12    },
13};
14
15use bytes::{
16    Buf,
17    BytesMut,
18};
19use proxy_protocol::{
20    ParseError,
21    ProxyHeader,
22};
23use tokio::io::{
24    AsyncRead,
25    AsyncWrite,
26    ReadBuf,
27};
28use tracing::instrument;
29
/// Upper bound on how many bytes we buffer while looking for a PROXY header.
///
/// 536 bytes is the smallest possible TCP segment, and both v1 and v2 headers
/// are guaranteed to fit into it.
const MAX_HEADER_LEN: usize = 536;

/// Fewest bytes needed before a parse is attempted.
///
/// A v2 header always opens with a 12-byte signature, one version/command
/// byte, one address-family byte and a 2-byte length: 16 bytes in total.
const MIN_HEADER_LEN: usize = 12 + 1 + 1 + 2;
35
36#[derive(Debug)]
37enum ReadState {
38    Reading(Option<ParseError>, BytesMut),
39    Error(proxy_protocol::ParseError, BytesMut),
40    Header(Option<proxy_protocol::ProxyHeader>, BytesMut),
41    None,
42}
43
44impl ReadState {
45    fn new() -> ReadState {
46        ReadState::Reading(None, BytesMut::with_capacity(MAX_HEADER_LEN))
47    }
48
49    fn header(&self) -> Result<Option<&ProxyHeader>, &ParseError> {
50        match self {
51            ReadState::Error(err, _) | ReadState::Reading(Some(err), _) => Err(err),
52            ReadState::None | ReadState::Reading(None, _) => Ok(None),
53            ReadState::Header(hdr, _) => Ok(hdr.as_ref()),
54        }
55    }
56
57    /// Read the header from the stream *once*. Once a header has been read, or
58    /// it's been determined that no header is coming, this will be a no-op.
59    #[instrument(level = "trace", skip(reader))]
60    fn poll_read_header_once(
61        &mut self,
62        cx: &mut Context,
63        mut reader: Pin<&mut impl AsyncRead>,
64    ) -> Poll<io::Result<()>> {
65        loop {
66            let read_state = mem::replace(self, ReadState::None);
67            let (last_err, mut hdr_buf) = match read_state {
68                // End states
69                ReadState::None | ReadState::Header(_, _) | ReadState::Error(_, _) => {
70                    *self = read_state;
71                    return Poll::Ready(Ok(()));
72                }
73                ReadState::Reading(err, hdr_buf) => (err, hdr_buf),
74            };
75
76            if hdr_buf.len() < MAX_HEADER_LEN {
77                let mut tmp_buf = ReadBuf::uninit(hdr_buf.spare_capacity_mut());
78                let read_res = reader.as_mut().poll_read(cx, &mut tmp_buf);
79                // Regardless of error, make sure we track the read bytes
80                let filled = tmp_buf.filled().len();
81                if filled > 0 {
82                    let len = hdr_buf.len();
83                    // Safety: the tmp_buf is backed by the uninitialized
84                    // portion of hdr_buf. Advancing the len to len + filled is
85                    // guaranteed to only cover the bytes initialized by the
86                    // read.
87                    unsafe { hdr_buf.set_len(len + filled) }
88                }
89                match read_res {
90                    // If we hit the end of the stream due to either an EOF or
91                    // an error, set the state to a terminal one and return the
92                    // result.
93                    Poll::Ready(ref res) if res.is_err() || filled == 0 => {
94                        *self = match last_err {
95                            Some(err) => ReadState::Error(err, hdr_buf),
96                            None => ReadState::Header(None, hdr_buf),
97                        };
98                        return read_res;
99                    }
100                    // Pending leaves the last error and buffer unchanged.
101                    Poll::Pending => {
102                        *self = ReadState::Reading(last_err, hdr_buf);
103                        return read_res;
104                    }
105                    _ => {}
106                }
107            }
108
109            // Create a view into the header buffer so that failed parse
110            // attempts don't consume it.
111            let mut hdr_view = &*hdr_buf;
112
113            // Don't try to parse unless we have a minimum number of bytes to
114            // avoid spurious "NotProxyHeader" errors.
115            // Also hack around a bug in the proxy_protocol crate that results
116            // in panics when the input ends in \r without the \n.
117            if hdr_view.len() < MIN_HEADER_LEN || matches!(hdr_view.last(), Some(b'\r')) {
118                *self = ReadState::Reading(last_err, hdr_buf);
119                continue;
120            }
121
122            match proxy_protocol::parse(&mut hdr_view) {
123                Ok(hdr) => {
124                    hdr_buf.advance(hdr_buf.len() - hdr_view.len());
125                    *self = ReadState::Header(Some(hdr), hdr_buf);
126                    return Poll::Ready(Ok(()));
127                }
128                Err(ParseError::NotProxyHeader) => {
129                    *self = ReadState::Header(None, hdr_buf);
130                    return Poll::Ready(Ok(()));
131                }
132
133                // Keep track of the last error - it might not be fatal if we
134                // simply haven't read enough
135                Err(err) => {
136                    // If we've read too much, consider the error fatal.
137                    if hdr_buf.len() >= MAX_HEADER_LEN {
138                        *self = ReadState::Error(err, hdr_buf);
139                    } else {
140                        *self = ReadState::Reading(Some(err), hdr_buf);
141                    }
142                    continue;
143                }
144            }
145        }
146    }
147}
148
149#[derive(Debug)]
150enum WriteState {
151    Writing(BytesMut),
152    None,
153}
154
155impl WriteState {
156    fn new(hdr: proxy_protocol::ProxyHeader) -> Result<WriteState, proxy_protocol::EncodeError> {
157        proxy_protocol::encode(hdr).map(WriteState::Writing)
158    }
159
160    /// Write the header *once*. After its written to the stream, this will be a
161    /// no-op.
162    #[instrument(level = "trace", skip(writer))]
163    fn poll_write_header_once(
164        &mut self,
165        cx: &mut Context,
166        mut writer: Pin<&mut impl AsyncWrite>,
167    ) -> Poll<io::Result<()>> {
168        loop {
169            let state = mem::replace(self, WriteState::None);
170            match state {
171                WriteState::None => return Poll::Ready(Ok(())),
172                WriteState::Writing(mut buf) => {
173                    let write_res = writer.as_mut().poll_write(cx, &buf);
174                    match write_res {
175                        Poll::Pending | Poll::Ready(Err(_)) => {
176                            *self = WriteState::Writing(buf);
177                            ready!(write_res)?;
178                            unreachable!(
179                                "ready! will return for us on either Pending or Ready(Err)"
180                            );
181                        }
182                        Poll::Ready(Ok(written)) => {
183                            buf.advance(written);
184                            if !buf.is_empty() {
185                                *self = WriteState::Writing(buf);
186                                continue;
187                            } else {
188                                return Ok(()).into();
189                            }
190                        }
191                    }
192                }
193            }
194        }
195    }
196}
197
/// Wrapper around a byte stream `S` that can consume an incoming PROXY
/// protocol header, emit an outgoing one, or do neither, depending on which
/// constructor built it.
#[derive(Debug)]
#[pin_project::pin_project]
pub struct Stream<S> {
    /// Progress of reading an incoming header; `ReadState::None` when no
    /// header is expected.
    read_state: ReadState,
    /// Outgoing header bytes still to be written; `WriteState::None` once sent
    /// or when no header is to be sent.
    write_state: WriteState,
    /// The wrapped transport. It is structurally pinned so the poll methods can
    /// drive it through `Pin<&mut S>`.
    #[pin]
    inner: S,
}
206
207impl<S> Stream<S> {
208    pub fn outgoing(stream: S, header: ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
209        Ok(Stream {
210            inner: stream,
211            write_state: WriteState::new(header)?,
212            read_state: ReadState::None,
213        })
214    }
215
216    pub fn incoming(stream: S) -> Self {
217        Stream {
218            inner: stream,
219            read_state: ReadState::new(),
220            write_state: WriteState::None,
221        }
222    }
223
224    pub fn disabled(stream: S) -> Self {
225        Stream {
226            inner: stream,
227            read_state: ReadState::None,
228            write_state: WriteState::None,
229        }
230    }
231}
232
233impl<S> Stream<S>
234where
235    S: AsyncRead,
236{
237    #[instrument(level = "debug", skip(self))]
238    pub async fn proxy_header(&mut self) -> io::Result<Result<Option<&ProxyHeader>, &ParseError>>
239    where
240        Self: Unpin,
241    {
242        let mut this = Pin::new(self);
243
244        futures::future::poll_fn(|cx| {
245            let this = this.as_mut().project();
246            this.read_state.poll_read_header_once(cx, this.inner)
247        })
248        .await?;
249
250        Ok(this.get_mut().read_state.header())
251    }
252}
253
254impl<S> AsyncRead for Stream<S>
255where
256    S: AsyncRead,
257{
258    #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
259    fn poll_read(
260        self: Pin<&mut Self>,
261        cx: &mut Context<'_>,
262        buf: &mut ReadBuf<'_>,
263    ) -> Poll<io::Result<()>> {
264        let mut this = self.project();
265
266        ready!(this
267            .read_state
268            .poll_read_header_once(cx, this.inner.as_mut()))?;
269
270        match this.read_state {
271            ReadState::Error(_, remainder) | ReadState::Header(_, remainder) => {
272                if !remainder.is_empty() {
273                    let available = std::cmp::min(remainder.len(), buf.remaining());
274                    buf.put_slice(&remainder.split_to(available));
275                    // Make sure Ready is returned regardless of inner's state
276                    return Poll::Ready(Ok(()));
277                }
278            }
279            ReadState::None => {}
280            _ => unreachable!(),
281        }
282
283        this.inner.poll_read(cx, buf)
284    }
285}
286
287impl<S> AsyncWrite for Stream<S>
288where
289    S: AsyncWrite,
290{
291    #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
292    fn poll_write(
293        self: Pin<&mut Self>,
294        cx: &mut Context<'_>,
295        buf: &[u8],
296    ) -> Poll<Result<usize, io::Error>> {
297        let mut this = self.project();
298
299        ready!(this
300            .write_state
301            .poll_write_header_once(cx, this.inner.as_mut()))?;
302
303        this.inner.poll_write(cx, buf)
304    }
305    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
306        self.project().inner.poll_flush(cx)
307    }
308    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
309        self.project().inner.poll_shutdown(cx)
310    }
311}
312
/// Adapters from the tokio `AsyncRead`/`AsyncWrite` impls to hyper's
/// `rt::Read`/`rt::Write` traits, so a `Stream` can be handed to hyper directly.
#[cfg(feature = "hyper")]
mod hyper {
    use ::hyper::rt::{
        Read as HyperRead,
        Write as HyperWrite,
    };

    use super::*;

    // Every method forwards to the tokio `AsyncWrite` impl, so the outgoing
    // PROXY header handling there applies to writes made through hyper too.
    impl<S> HyperWrite for Stream<S>
    where
        S: AsyncWrite,
    {
        #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_shutdown(self, cx)
        }
    }

    // Reads go through the tokio `AsyncRead` impl (and therefore through the
    // incoming header handling) by viewing hyper's cursor as a tokio `ReadBuf`.
    impl<S> HyperRead for Stream<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            mut buf: ::hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // SAFETY: `as_mut` exposes the cursor's unfilled, possibly
            // uninitialized tail. A tokio `ReadBuf` built with `uninit` only
            // writes initialized bytes into it and never de-initializes
            // memory, which is what `as_mut` requires of its caller.
            let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
            // On Pending, `ready!` returns early and the cursor is left untouched.
            let res = ready!(<Self as AsyncRead>::poll_read(self, cx, &mut tokio_buf));
            let filled = tokio_buf.filled().len();
            // SAFETY: the first `filled` bytes of the tail are tokio_buf's
            // filled region, which `ReadBuf` guarantees is initialized, so
            // marking them filled in the cursor is sound. Note this also runs
            // when `res` is an error, covering any bytes written before it.
            unsafe { buf.advance(filled) };
            Poll::Ready(res)
        }
    }
}
362
#[cfg(test)]
mod test {
    use std::{
        cmp,
        io,
        pin::Pin,
        task::{
            ready,
            Context,
            Poll,
        },
        time::Duration,
    };

    use bytes::{
        BufMut,
        BytesMut,
    };
    use proxy_protocol::{
        version2::{
            self,
            ProxyCommand,
        },
        ProxyHeader,
    };
    use tokio::io::{
        AsyncRead,
        AsyncReadExt,
        AsyncWriteExt,
        ReadBuf,
    };

    use super::Stream;

    /// Test reader that hands out the inner stream's data in small, randomly
    /// sized chunks, to exercise header parsing across split reads.
    #[pin_project::pin_project]
    struct ShortReader<S> {
        #[pin]
        inner: S,
        min: usize,
        max: usize,
    }

    impl<S> AsyncRead for ShortReader<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let mut this = self.project();
            // Chunk size is `min` plus a random extra of at least one byte.
            // The modulus is clamped to at least 1 so `min == max` can't divide
            // by zero. The chunk is capped at the caller's remaining space
            // because `put_slice` panics if handed more than fits.
            let span = cmp::max(1, (*this.max).saturating_sub(*this.min));
            let extra = cmp::max(1, rand::random::<usize>() % span);
            let max_bytes = cmp::min(*this.min + extra, buf.remaining());
            let mut tmp = vec![0; max_bytes];
            let mut tmp_buf = ReadBuf::new(&mut tmp);
            let res = ready!(this.inner.as_mut().poll_read(cx, &mut tmp_buf));

            buf.put_slice(tmp_buf.filled());

            res?;

            Poll::Ready(Ok(()))
        }
    }

    impl<S> ShortReader<S> {
        fn new(inner: S, min: usize, max: usize) -> Self {
            ShortReader { inner, min, max }
        }
    }

    const INPUT: &str = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
    const PARTIAL_INPUT: &str = "PROXY TCP4 192.168.0.1";
    const FINAL_INPUT: &str = " 192.168.0.11 56324 443\r\n";

    // Smoke test to ensure that the proxy protocol parser works as expected.
    // Not actually testing our code.
    #[test]
    fn test_proxy_protocol() {
        let mut buf = BytesMut::from(INPUT);

        assert!(proxy_protocol::parse(&mut buf).is_ok());

        buf = BytesMut::from(PARTIAL_INPUT);

        assert!(proxy_protocol::parse(&mut &*buf).is_err());

        buf.put_slice(FINAL_INPUT.as_bytes());

        assert!(proxy_protocol::parse(&mut &*buf).is_ok());
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream_v2() {
        let (left, mut right) = tokio::io::duplex(1024);

        let header = ProxyHeader::Version2 {
            command: ProxyCommand::Proxy,
            transport_protocol: version2::ProxyTransportProtocol::Stream,
            addresses: version2::ProxyAddresses::Ipv4 {
                source: "127.0.0.1:1".parse().unwrap(),
                destination: "127.0.0.2:2".parse().unwrap(),
            },
        };

        let input = proxy_protocol::encode(header).unwrap();

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        // Chunk our writes to ensure that our reader is resilient across split inputs.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right.write_all(&input).await.expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // Get the header again - should be the same.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        // Chunk our writes to ensure that our reader is resilient across split inputs.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right
                .write_all(INPUT.as_bytes())
                .await
                .expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // Get the header again - should be the same.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_noheader() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(left);

        right
            .write_all(b"Hello, world!")
            .await
            .expect("write stream");

        right.shutdown().await.expect("shutdown");
        drop(right);

        assert!(proxy_stream
            .proxy_header()
            .await
            .unwrap()
            .unwrap()
            .is_none());

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read stream");

        assert_eq!(buf, "Hello, world!");
    }
}