1use std::{
2 io,
3 mem,
4 pin::{
5 pin,
6 Pin,
7 },
8 task::{
9 ready,
10 Context,
11 Poll,
12 },
13};
14
15use bytes::{
16 Buf,
17 BytesMut,
18};
19use proxy_protocol::{
20 ParseError,
21 ProxyHeader,
22};
23use tokio::io::{
24 AsyncRead,
25 AsyncWrite,
26 ReadBuf,
27};
28use tracing::instrument;
29
30const MAX_HEADER_LEN: usize = 536;
33const MIN_HEADER_LEN: usize = 16;
35
36#[derive(Debug)]
37enum ReadState {
38 Reading(Option<ParseError>, BytesMut),
39 Error(proxy_protocol::ParseError, BytesMut),
40 Header(Option<proxy_protocol::ProxyHeader>, BytesMut),
41 None,
42}
43
44impl ReadState {
45 fn new() -> ReadState {
46 ReadState::Reading(None, BytesMut::with_capacity(MAX_HEADER_LEN))
47 }
48
49 fn header(&self) -> Result<Option<&ProxyHeader>, &ParseError> {
50 match self {
51 ReadState::Error(err, _) | ReadState::Reading(Some(err), _) => Err(err),
52 ReadState::None | ReadState::Reading(None, _) => Ok(None),
53 ReadState::Header(hdr, _) => Ok(hdr.as_ref()),
54 }
55 }
56
57 #[instrument(level = "trace", skip(reader))]
60 fn poll_read_header_once(
61 &mut self,
62 cx: &mut Context,
63 mut reader: Pin<&mut impl AsyncRead>,
64 ) -> Poll<io::Result<()>> {
65 loop {
66 let read_state = mem::replace(self, ReadState::None);
67 let (last_err, mut hdr_buf) = match read_state {
68 ReadState::None | ReadState::Header(_, _) | ReadState::Error(_, _) => {
70 *self = read_state;
71 return Poll::Ready(Ok(()));
72 }
73 ReadState::Reading(err, hdr_buf) => (err, hdr_buf),
74 };
75
76 if hdr_buf.len() < MAX_HEADER_LEN {
77 let mut tmp_buf = ReadBuf::uninit(hdr_buf.spare_capacity_mut());
78 let read_res = reader.as_mut().poll_read(cx, &mut tmp_buf);
79 let filled = tmp_buf.filled().len();
81 if filled > 0 {
82 let len = hdr_buf.len();
83 unsafe { hdr_buf.set_len(len + filled) }
88 }
89 match read_res {
90 Poll::Ready(ref res) if res.is_err() || filled == 0 => {
94 *self = match last_err {
95 Some(err) => ReadState::Error(err, hdr_buf),
96 None => ReadState::Header(None, hdr_buf),
97 };
98 return read_res;
99 }
100 Poll::Pending => {
102 *self = ReadState::Reading(last_err, hdr_buf);
103 return read_res;
104 }
105 _ => {}
106 }
107 }
108
109 let mut hdr_view = &*hdr_buf;
112
113 if hdr_view.len() < MIN_HEADER_LEN || matches!(hdr_view.last(), Some(b'\r')) {
118 *self = ReadState::Reading(last_err, hdr_buf);
119 continue;
120 }
121
122 match proxy_protocol::parse(&mut hdr_view) {
123 Ok(hdr) => {
124 hdr_buf.advance(hdr_buf.len() - hdr_view.len());
125 *self = ReadState::Header(Some(hdr), hdr_buf);
126 return Poll::Ready(Ok(()));
127 }
128 Err(ParseError::NotProxyHeader) => {
129 *self = ReadState::Header(None, hdr_buf);
130 return Poll::Ready(Ok(()));
131 }
132
133 Err(err) => {
136 if hdr_buf.len() >= MAX_HEADER_LEN {
138 *self = ReadState::Error(err, hdr_buf);
139 } else {
140 *self = ReadState::Reading(Some(err), hdr_buf);
141 }
142 continue;
143 }
144 }
145 }
146 }
147}
148
149#[derive(Debug)]
150enum WriteState {
151 Writing(BytesMut),
152 None,
153}
154
155impl WriteState {
156 fn new(hdr: proxy_protocol::ProxyHeader) -> Result<WriteState, proxy_protocol::EncodeError> {
157 proxy_protocol::encode(hdr).map(WriteState::Writing)
158 }
159
160 #[instrument(level = "trace", skip(writer))]
163 fn poll_write_header_once(
164 &mut self,
165 cx: &mut Context,
166 mut writer: Pin<&mut impl AsyncWrite>,
167 ) -> Poll<io::Result<()>> {
168 loop {
169 let state = mem::replace(self, WriteState::None);
170 match state {
171 WriteState::None => return Poll::Ready(Ok(())),
172 WriteState::Writing(mut buf) => {
173 let write_res = writer.as_mut().poll_write(cx, &buf);
174 match write_res {
175 Poll::Pending | Poll::Ready(Err(_)) => {
176 *self = WriteState::Writing(buf);
177 ready!(write_res)?;
178 unreachable!(
179 "ready! will return for us on either Pending or Ready(Err)"
180 );
181 }
182 Poll::Ready(Ok(written)) => {
183 buf.advance(written);
184 if !buf.is_empty() {
185 *self = WriteState::Writing(buf);
186 continue;
187 } else {
188 return Ok(()).into();
189 }
190 }
191 }
192 }
193 }
194 }
195 }
196}
197
198#[derive(Debug)]
199#[pin_project::pin_project]
200pub struct Stream<S> {
201 read_state: ReadState,
202 write_state: WriteState,
203 #[pin]
204 inner: S,
205}
206
207impl<S> Stream<S> {
208 pub fn outgoing(stream: S, header: ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
209 Ok(Stream {
210 inner: stream,
211 write_state: WriteState::new(header)?,
212 read_state: ReadState::None,
213 })
214 }
215
216 pub fn incoming(stream: S) -> Self {
217 Stream {
218 inner: stream,
219 read_state: ReadState::new(),
220 write_state: WriteState::None,
221 }
222 }
223
224 pub fn disabled(stream: S) -> Self {
225 Stream {
226 inner: stream,
227 read_state: ReadState::None,
228 write_state: WriteState::None,
229 }
230 }
231}
232
233impl<S> Stream<S>
234where
235 S: AsyncRead,
236{
237 #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
238 pub fn poll_proxy_header(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 ) -> Poll<io::Result<Result<Option<&ProxyHeader>, &ParseError>>> {
242 let this = self.project();
243
244 ready!(this.read_state.poll_read_header_once(cx, this.inner))?;
245
246 Ok(this.read_state.header()).into()
247 }
248
249 #[instrument(level = "debug", skip(self))]
250 pub async fn proxy_header(&mut self) -> io::Result<Result<Option<&ProxyHeader>, &ParseError>>
251 where
252 Self: Unpin,
253 {
254 let mut this = Pin::new(self);
255
256 futures::future::poll_fn(|cx| {
257 let this = this.as_mut().project();
258 this.read_state.poll_read_header_once(cx, this.inner)
259 })
260 .await?;
261
262 Ok(this.get_mut().read_state.header())
263 }
264}
265
266impl<S> AsyncRead for Stream<S>
267where
268 S: AsyncRead,
269{
270 #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
271 fn poll_read(
272 self: Pin<&mut Self>,
273 cx: &mut Context<'_>,
274 buf: &mut ReadBuf<'_>,
275 ) -> Poll<io::Result<()>> {
276 let mut this = self.project();
277
278 ready!(this
279 .read_state
280 .poll_read_header_once(cx, this.inner.as_mut()))?;
281
282 match this.read_state {
283 ReadState::Error(_, remainder) | ReadState::Header(_, remainder) => {
284 if !remainder.is_empty() {
285 let available = std::cmp::min(remainder.len(), buf.remaining());
286 buf.put_slice(&remainder.split_to(available));
287 return Poll::Ready(Ok(()));
289 }
290 }
291 ReadState::None => {}
292 _ => unreachable!(),
293 }
294
295 this.inner.poll_read(cx, buf)
296 }
297}
298
299impl<S> AsyncWrite for Stream<S>
300where
301 S: AsyncWrite,
302{
303 #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
304 fn poll_write(
305 self: Pin<&mut Self>,
306 cx: &mut Context<'_>,
307 buf: &[u8],
308 ) -> Poll<Result<usize, io::Error>> {
309 let mut this = self.project();
310
311 ready!(this
312 .write_state
313 .poll_write_header_once(cx, this.inner.as_mut()))?;
314
315 this.inner.poll_write(cx, buf)
316 }
317 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
318 self.project().inner.poll_flush(cx)
319 }
320 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
321 self.project().inner.poll_shutdown(cx)
322 }
323}
324
#[cfg(feature = "hyper")]
mod hyper {
    //! Adapters that let [`Stream`] be used directly as a hyper connection by forwarding to
    //! its tokio `AsyncRead`/`AsyncWrite` implementations.

    use ::hyper::rt::{
        Read as HyperRead,
        Write as HyperWrite,
    };

    use super::*;

    impl<S> HyperWrite for Stream<S>
    where
        S: AsyncWrite,
    {
        #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_shutdown(self, cx)
        }
    }

    impl<S> HyperRead for Stream<S>
    where
        S: AsyncRead,
    {
        /// Reads into hyper's cursor through a tokio `ReadBuf` over the same memory, then
        /// advances the cursor by the number of bytes filled.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            mut buf: ::hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // SAFETY: `ReadBuf::uninit` only ever writes initialized bytes into the slice and
            // never de-initializes memory, as `ReadBufCursor::as_mut` requires.
            let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
            let start = tokio_buf.filled().as_ptr();
            let res = ready!(<Self as AsyncRead>::poll_read(self, cx, &mut tokio_buf));
            // `advance` below vouches for bytes in hyper's buffer; make sure the reader did
            // not fill a different buffer instead.
            assert_eq!(
                start,
                tokio_buf.filled().as_ptr(),
                "AsyncRead implementation replaced the read buffer"
            );
            let filled = tokio_buf.filled().len();
            // SAFETY: the first `filled` bytes of the cursor were initialized by the read
            // above, and the assertion confirms they are the cursor's own memory.
            unsafe { buf.advance(filled) };
            Poll::Ready(res)
        }
    }
}
374
#[cfg(test)]
mod test {
    use std::{
        cmp,
        io,
        pin::Pin,
        task::{
            ready,
            Context,
            Poll,
        },
        time::Duration,
    };

    use bytes::{
        BufMut,
        BytesMut,
    };
    use proxy_protocol::{
        version2::{
            self,
            ProxyCommand,
        },
        ProxyHeader,
    };
    use tokio::io::{
        AsyncRead,
        AsyncReadExt,
        AsyncWriteExt,
        ReadBuf,
    };

    use super::Stream;

    /// Reader adapter that returns the inner stream's data in small, randomly sized chunks
    /// (between `min` and `max` bytes), exercising header detection across partial reads.
    #[pin_project::pin_project]
    struct ShortReader<S> {
        #[pin]
        inner: S,
        min: usize,
        max: usize,
    }

    impl<S> AsyncRead for ShortReader<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let mut this = self.project();
            let spread = (*this.max).saturating_sub(*this.min);
            // Chunk size is `min + max(1, r % (max - min))`; when `max <= min` it is just
            // `min`, avoiding a division by zero.
            let extra = if spread == 0 {
                0
            } else {
                cmp::max(1, rand::random::<usize>() % spread)
            };
            // Never pull more than the caller can accept, or `put_slice` below would panic.
            let max_bytes = cmp::min(*this.min + extra, buf.remaining());
            let mut tmp = vec![0; max_bytes];
            let mut tmp_buf = ReadBuf::new(&mut tmp);
            let res = ready!(this.inner.as_mut().poll_read(cx, &mut tmp_buf));

            buf.put_slice(tmp_buf.filled());

            res?;

            Poll::Ready(Ok(()))
        }
    }

    impl<S> ShortReader<S> {
        fn new(inner: S, min: usize, max: usize) -> Self {
            ShortReader { inner, min, max }
        }
    }

    const INPUT: &str = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
    const PARTIAL_INPUT: &str = "PROXY TCP4 192.168.0.1";
    const FINAL_INPUT: &str = " 192.168.0.11 56324 443\r\n";

    /// Sanity-checks the parser's behaviour on complete and incomplete v1 headers.
    #[test]
    fn test_proxy_protocol() {
        let mut buf = BytesMut::from(INPUT);

        assert!(proxy_protocol::parse(&mut buf).is_ok());

        buf = BytesMut::from(PARTIAL_INPUT);

        assert!(proxy_protocol::parse(&mut &*buf).is_err());

        buf.put_slice(FINAL_INPUT.as_bytes());

        assert!(proxy_protocol::parse(&mut &*buf).is_ok());
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream_v2() {
        let (left, mut right) = tokio::io::duplex(1024);

        let header = ProxyHeader::Version2 {
            command: ProxyCommand::Proxy,
            transport_protocol: version2::ProxyTransportProtocol::Stream,
            addresses: version2::ProxyAddresses::Ipv4 {
                source: "127.0.0.1:1".parse().unwrap(),
                destination: "127.0.0.2:2".parse().unwrap(),
            },
        };

        let input = proxy_protocol::encode(header).unwrap();

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        let writer = tokio::spawn(async move {
            // Delay the writes so the reader is (very likely) polled before data arrives.
            tokio::time::sleep(Duration::from_millis(50)).await;

            right.write_all(&input).await.expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // The header stays available after the payload has been consumed.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));

        // Surface any panic from the writer task instead of silently dropping it.
        writer.await.expect("writer task");
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        let writer = tokio::spawn(async move {
            // Delay the writes so the reader is (very likely) polled before data arrives.
            tokio::time::sleep(Duration::from_millis(50)).await;

            right
                .write_all(INPUT.as_bytes())
                .await
                .expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        // The header stays available after the payload has been consumed.
        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));

        // Surface any panic from the writer task instead of silently dropping it.
        writer.await.expect("writer task");
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_noheader() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(left);

        right
            .write_all(b"Hello, world!")
            .await
            .expect("write stream");

        right.shutdown().await.expect("shutdown");
        drop(right);

        assert!(proxy_stream
            .proxy_header()
            .await
            .unwrap()
            .unwrap()
            .is_none());

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read stream");

        assert_eq!(buf, "Hello, world!");
    }
}