ngrok/
proxy_proto.rs

1use std::{
2    io,
3    mem,
4    pin::{
5        pin,
6        Pin,
7    },
8    task::{
9        ready,
10        Context,
11        Poll,
12    },
13};
14
15use bytes::{
16    Buf,
17    BytesMut,
18};
19use proxy_protocol::{
20    ParseError,
21    ProxyHeader,
22};
23use tokio::io::{
24    AsyncRead,
25    AsyncWrite,
26    ReadBuf,
27};
28use tracing::instrument;
29
// Upper bound on how many bytes are buffered while searching for a header.
// 536 bytes is the default IPv4 TCP maximum segment size (RFC 879 / RFC 1122),
// i.e. the smallest segment every host must accept. That is enough room for
// any v1 header and for v2 headers without a large TLV section. A header that
// still fails to parse once this many bytes are buffered is treated as an
// error.
const MAX_HEADER_LEN: usize = 536;
// The fixed part of a v2 header is 16 bytes: a 12-byte signature, the
// version/command byte, the family/protocol byte, and a 2-byte length.
// Parsing is deferred until at least this many bytes are available, so a short
// prefix isn't mistaken for "not a proxy header".
const MIN_HEADER_LEN: usize = 16;
35
36#[derive(Debug)]
37enum ReadState {
38    Reading(Option<ParseError>, BytesMut),
39    Error(proxy_protocol::ParseError, BytesMut),
40    Header(Option<proxy_protocol::ProxyHeader>, BytesMut),
41    None,
42}
43
44impl ReadState {
45    fn new() -> ReadState {
46        ReadState::Reading(None, BytesMut::with_capacity(MAX_HEADER_LEN))
47    }
48
49    fn header(&self) -> Result<Option<&ProxyHeader>, &ParseError> {
50        match self {
51            ReadState::Error(err, _) | ReadState::Reading(Some(err), _) => Err(err),
52            ReadState::None | ReadState::Reading(None, _) => Ok(None),
53            ReadState::Header(hdr, _) => Ok(hdr.as_ref()),
54        }
55    }
56
57    /// Read the header from the stream *once*. Once a header has been read, or
58    /// it's been determined that no header is coming, this will be a no-op.
59    #[instrument(level = "trace", skip(reader))]
60    fn poll_read_header_once(
61        &mut self,
62        cx: &mut Context,
63        mut reader: Pin<&mut impl AsyncRead>,
64    ) -> Poll<io::Result<()>> {
65        loop {
66            let read_state = mem::replace(self, ReadState::None);
67            let (last_err, mut hdr_buf) = match read_state {
68                // End states
69                ReadState::None | ReadState::Header(_, _) | ReadState::Error(_, _) => {
70                    *self = read_state;
71                    return Poll::Ready(Ok(()));
72                }
73                ReadState::Reading(err, hdr_buf) => (err, hdr_buf),
74            };
75
76            if hdr_buf.len() < MAX_HEADER_LEN {
77                let mut tmp_buf = ReadBuf::uninit(hdr_buf.spare_capacity_mut());
78                let read_res = reader.as_mut().poll_read(cx, &mut tmp_buf);
79                // Regardless of error, make sure we track the read bytes
80                let filled = tmp_buf.filled().len();
81                if filled > 0 {
82                    let len = hdr_buf.len();
83                    // Safety: the tmp_buf is backed by the uninitialized
84                    // portion of hdr_buf. Advancing the len to len + filled is
85                    // guaranteed to only cover the bytes initialized by the
86                    // read.
87                    unsafe { hdr_buf.set_len(len + filled) }
88                }
89                match read_res {
90                    // If we hit the end of the stream due to either an EOF or
91                    // an error, set the state to a terminal one and return the
92                    // result.
93                    Poll::Ready(ref res) if res.is_err() || filled == 0 => {
94                        *self = match last_err {
95                            Some(err) => ReadState::Error(err, hdr_buf),
96                            None => ReadState::Header(None, hdr_buf),
97                        };
98                        return read_res;
99                    }
100                    // Pending leaves the last error and buffer unchanged.
101                    Poll::Pending => {
102                        *self = ReadState::Reading(last_err, hdr_buf);
103                        return read_res;
104                    }
105                    _ => {}
106                }
107            }
108
109            // Create a view into the header buffer so that failed parse
110            // attempts don't consume it.
111            let mut hdr_view = &*hdr_buf;
112
113            // Don't try to parse unless we have a minimum number of bytes to
114            // avoid spurious "NotProxyHeader" errors.
115            // Also hack around a bug in the proxy_protocol crate that results
116            // in panics when the input ends in \r without the \n.
117            if hdr_view.len() < MIN_HEADER_LEN || matches!(hdr_view.last(), Some(b'\r')) {
118                *self = ReadState::Reading(last_err, hdr_buf);
119                continue;
120            }
121
122            match proxy_protocol::parse(&mut hdr_view) {
123                Ok(hdr) => {
124                    hdr_buf.advance(hdr_buf.len() - hdr_view.len());
125                    *self = ReadState::Header(Some(hdr), hdr_buf);
126                    return Poll::Ready(Ok(()));
127                }
128                Err(ParseError::NotProxyHeader) => {
129                    *self = ReadState::Header(None, hdr_buf);
130                    return Poll::Ready(Ok(()));
131                }
132
133                // Keep track of the last error - it might not be fatal if we
134                // simply haven't read enough
135                Err(err) => {
136                    // If we've read too much, consider the error fatal.
137                    if hdr_buf.len() >= MAX_HEADER_LEN {
138                        *self = ReadState::Error(err, hdr_buf);
139                    } else {
140                        *self = ReadState::Reading(Some(err), hdr_buf);
141                    }
142                    continue;
143                }
144            }
145        }
146    }
147}
148
149#[derive(Debug)]
150enum WriteState {
151    Writing(BytesMut),
152    None,
153}
154
155impl WriteState {
156    fn new(hdr: proxy_protocol::ProxyHeader) -> Result<WriteState, proxy_protocol::EncodeError> {
157        proxy_protocol::encode(hdr).map(WriteState::Writing)
158    }
159
160    /// Write the header *once*. After its written to the stream, this will be a
161    /// no-op.
162    #[instrument(level = "trace", skip(writer))]
163    fn poll_write_header_once(
164        &mut self,
165        cx: &mut Context,
166        mut writer: Pin<&mut impl AsyncWrite>,
167    ) -> Poll<io::Result<()>> {
168        loop {
169            let state = mem::replace(self, WriteState::None);
170            match state {
171                WriteState::None => return Poll::Ready(Ok(())),
172                WriteState::Writing(mut buf) => {
173                    let write_res = writer.as_mut().poll_write(cx, &buf);
174                    match write_res {
175                        Poll::Pending | Poll::Ready(Err(_)) => {
176                            *self = WriteState::Writing(buf);
177                            ready!(write_res)?;
178                            unreachable!(
179                                "ready! will return for us on either Pending or Ready(Err)"
180                            );
181                        }
182                        Poll::Ready(Ok(written)) => {
183                            buf.advance(written);
184                            if !buf.is_empty() {
185                                *self = WriteState::Writing(buf);
186                                continue;
187                            } else {
188                                return Ok(()).into();
189                            }
190                        }
191                    }
192                }
193            }
194        }
195    }
196}
197
198#[derive(Debug)]
199#[pin_project::pin_project]
200pub struct Stream<S> {
201    read_state: ReadState,
202    write_state: WriteState,
203    #[pin]
204    inner: S,
205}
206
207impl<S> Stream<S> {
208    pub fn outgoing(stream: S, header: ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
209        Ok(Stream {
210            inner: stream,
211            write_state: WriteState::new(header)?,
212            read_state: ReadState::None,
213        })
214    }
215
216    pub fn incoming(stream: S) -> Self {
217        Stream {
218            inner: stream,
219            read_state: ReadState::new(),
220            write_state: WriteState::None,
221        }
222    }
223
224    pub fn disabled(stream: S) -> Self {
225        Stream {
226            inner: stream,
227            read_state: ReadState::None,
228            write_state: WriteState::None,
229        }
230    }
231}
232
233impl<S> Stream<S>
234where
235    S: AsyncRead,
236{
237    #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
238    pub fn poll_proxy_header(
239        self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241    ) -> Poll<io::Result<Result<Option<&ProxyHeader>, &ParseError>>> {
242        let this = self.project();
243
244        ready!(this.read_state.poll_read_header_once(cx, this.inner))?;
245
246        Ok(this.read_state.header()).into()
247    }
248
249    #[instrument(level = "debug", skip(self))]
250    pub async fn proxy_header(&mut self) -> io::Result<Result<Option<&ProxyHeader>, &ParseError>>
251    where
252        Self: Unpin,
253    {
254        let mut this = Pin::new(self);
255
256        futures::future::poll_fn(|cx| {
257            let this = this.as_mut().project();
258            this.read_state.poll_read_header_once(cx, this.inner)
259        })
260        .await?;
261
262        Ok(this.get_mut().read_state.header())
263    }
264}
265
266impl<S> AsyncRead for Stream<S>
267where
268    S: AsyncRead,
269{
270    #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
271    fn poll_read(
272        self: Pin<&mut Self>,
273        cx: &mut Context<'_>,
274        buf: &mut ReadBuf<'_>,
275    ) -> Poll<io::Result<()>> {
276        let mut this = self.project();
277
278        ready!(this
279            .read_state
280            .poll_read_header_once(cx, this.inner.as_mut()))?;
281
282        match this.read_state {
283            ReadState::Error(_, remainder) | ReadState::Header(_, remainder) => {
284                if !remainder.is_empty() {
285                    let available = std::cmp::min(remainder.len(), buf.remaining());
286                    buf.put_slice(&remainder.split_to(available));
287                    // Make sure Ready is returned regardless of inner's state
288                    return Poll::Ready(Ok(()));
289                }
290            }
291            ReadState::None => {}
292            _ => unreachable!(),
293        }
294
295        this.inner.poll_read(cx, buf)
296    }
297}
298
299impl<S> AsyncWrite for Stream<S>
300where
301    S: AsyncWrite,
302{
303    #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
304    fn poll_write(
305        self: Pin<&mut Self>,
306        cx: &mut Context<'_>,
307        buf: &[u8],
308    ) -> Poll<Result<usize, io::Error>> {
309        let mut this = self.project();
310
311        ready!(this
312            .write_state
313            .poll_write_header_once(cx, this.inner.as_mut()))?;
314
315        this.inner.poll_write(cx, buf)
316    }
317    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
318        self.project().inner.poll_flush(cx)
319    }
320    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
321        self.project().inner.poll_shutdown(cx)
322    }
323}
324
#[cfg(feature = "hyper")]
mod hyper {
    use ::hyper::rt::{
        Read as HyperRead,
        Write as HyperWrite,
    };

    use super::*;

    /// hyper read adapter: delegates to the tokio `AsyncRead` implementation.
    impl<S> HyperRead for Stream<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            mut buf: ::hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // SAFETY: `ReadBufCursor::as_mut` requires that no previously
            // initialized bytes are de-initialized. The tokio `ReadBuf` below
            // treats the slice as uninitialized and only fills it with
            // initialized bytes.
            let unfilled = unsafe { buf.as_mut() };
            let mut scratch = tokio::io::ReadBuf::uninit(unfilled);

            let outcome = ready!(AsyncRead::poll_read(self, cx, &mut scratch));

            let read_len = scratch.filled().len();
            // SAFETY: the first `read_len` bytes were written, and therefore
            // initialized, by the read above.
            unsafe { buf.advance(read_len) };

            Poll::Ready(outcome)
        }
    }

    /// hyper write adapter: delegates to the tokio `AsyncWrite` implementation.
    impl<S> HyperWrite for Stream<S>
    where
        S: AsyncWrite,
    {
        #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write(self, cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_flush(self, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            AsyncWrite::poll_shutdown(self, cx)
        }
    }
}
374
#[cfg(test)]
mod test {
    use std::{
        cmp,
        io,
        pin::Pin,
        task::{
            ready,
            Context,
            Poll,
        },
        time::Duration,
    };

    use bytes::{
        BufMut,
        BytesMut,
    };
    use proxy_protocol::{
        version2::{
            self,
            ProxyCommand,
        },
        ProxyHeader,
    };
    use tokio::io::{
        AsyncRead,
        AsyncReadExt,
        AsyncWriteExt,
        ReadBuf,
    };

    use super::Stream;

    /// Reader adapter that hands out data from `inner` in small chunks.
    ///
    /// Each read is a random size between `min` and `max` bytes (inclusive),
    /// capped by the caller's buffer, so that tests exercise header parsing
    /// across split reads.
    #[pin_project::pin_project]
    struct ShortReader<S> {
        #[pin]
        inner: S,
        min: usize,
        max: usize,
    }

    impl<S> AsyncRead for ShortReader<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let mut this = self.project();
            let span = (*this.max).saturating_sub(*this.min);
            let chunk = *this.min + rand::random::<usize>() % (span + 1);
            // Always ask for at least one byte, but never more than the caller
            // has room for: `ReadBuf::put_slice` panics on overflow.
            let max_bytes = cmp::min(cmp::max(chunk, 1), buf.remaining());
            let mut tmp = vec![0; max_bytes];
            let mut tmp_buf = ReadBuf::new(&mut tmp);
            let res = ready!(this.inner.as_mut().poll_read(cx, &mut tmp_buf));

            buf.put_slice(tmp_buf.filled());

            res?;

            Poll::Ready(Ok(()))
        }
    }

    impl<S> ShortReader<S> {
        fn new(inner: S, min: usize, max: usize) -> Self {
            ShortReader { inner, min, max }
        }
    }

    const INPUT: &str = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
    const PARTIAL_INPUT: &str = "PROXY TCP4 192.168.0.1";
    const FINAL_INPUT: &str = " 192.168.0.11 56324 443\r\n";

    // Smoke test to ensure that the proxy protocol parser works as expected.
    // Not actually testing our code.
    #[test]
    fn test_proxy_protocol() {
        let mut buf = BytesMut::from(INPUT);

        assert!(proxy_protocol::parse(&mut buf).is_ok());

        buf = BytesMut::from(PARTIAL_INPUT);

        assert!(proxy_protocol::parse(&mut &*buf).is_err());

        buf.put_slice(FINAL_INPUT.as_bytes());

        assert!(proxy_protocol::parse(&mut &*buf).is_ok());
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream_v2() {
        let (left, mut right) = tokio::io::duplex(1024);

        let header = ProxyHeader::Version2 {
            command: ProxyCommand::Proxy,
            transport_protocol: version2::ProxyTransportProtocol::Stream,
            addresses: version2::ProxyAddresses::Ipv4 {
                source: "127.0.0.1:1".parse().unwrap(),
                destination: "127.0.0.2:2".parse().unwrap(),
            },
        };

        let input = proxy_protocol::encode(header).unwrap();

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        // Chunk our writes to ensure that our reader is resilient across split inputs.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right.write_all(&input).await.expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // Get the header again - should be the same.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        // Chunk our writes to ensure that our reader is resilient across split inputs.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right
                .write_all(INPUT.as_bytes())
                .await
                .expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // Get the header again - should be the same.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_noheader() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(left);

        right
            .write_all(b"Hello, world!")
            .await
            .expect("write stream");

        right.shutdown().await.expect("shutdown");
        drop(right);

        assert!(proxy_stream
            .proxy_header()
            .await
            .unwrap()
            .unwrap()
            .is_none());

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read stream");

        assert_eq!(buf, "Hello, world!");
    }
}