1use std::{
2 io,
3 mem,
4 pin::{
5 pin,
6 Pin,
7 },
8 task::{
9 ready,
10 Context,
11 Poll,
12 },
13};
14
15use bytes::{
16 Buf,
17 BytesMut,
18};
19use proxy_protocol::{
20 ParseError,
21 ProxyHeader,
22};
23use tokio::io::{
24 AsyncRead,
25 AsyncWrite,
26 ReadBuf,
27};
28use tracing::instrument;
29
/// Upper bound on the number of bytes buffered while looking for an incoming
/// header. Once this many bytes have been read and still fail to parse,
/// detection gives up and the stream is marked as errored.
const MAX_HEADER_LEN: usize = 536;

/// Minimum number of buffered bytes before a parse is attempted at all.
/// NOTE(review): presumably chosen to match the 16-byte fixed prefix of a
/// PROXY v2 header — confirm.
const MIN_HEADER_LEN: usize = 16;
35
36#[derive(Debug)]
37enum ReadState {
38 Reading(Option<ParseError>, BytesMut),
39 Error(proxy_protocol::ParseError, BytesMut),
40 Header(Option<proxy_protocol::ProxyHeader>, BytesMut),
41 None,
42}
43
44impl ReadState {
45 fn new() -> ReadState {
46 ReadState::Reading(None, BytesMut::with_capacity(MAX_HEADER_LEN))
47 }
48
49 fn header(&self) -> Result<Option<&ProxyHeader>, &ParseError> {
50 match self {
51 ReadState::Error(err, _) | ReadState::Reading(Some(err), _) => Err(err),
52 ReadState::None | ReadState::Reading(None, _) => Ok(None),
53 ReadState::Header(hdr, _) => Ok(hdr.as_ref()),
54 }
55 }
56
57 #[instrument(level = "trace", skip(reader))]
60 fn poll_read_header_once(
61 &mut self,
62 cx: &mut Context,
63 mut reader: Pin<&mut impl AsyncRead>,
64 ) -> Poll<io::Result<()>> {
65 loop {
66 let read_state = mem::replace(self, ReadState::None);
67 let (last_err, mut hdr_buf) = match read_state {
68 ReadState::None | ReadState::Header(_, _) | ReadState::Error(_, _) => {
70 *self = read_state;
71 return Poll::Ready(Ok(()));
72 }
73 ReadState::Reading(err, hdr_buf) => (err, hdr_buf),
74 };
75
76 if hdr_buf.len() < MAX_HEADER_LEN {
77 let mut tmp_buf = ReadBuf::uninit(hdr_buf.spare_capacity_mut());
78 let read_res = reader.as_mut().poll_read(cx, &mut tmp_buf);
79 let filled = tmp_buf.filled().len();
81 if filled > 0 {
82 let len = hdr_buf.len();
83 unsafe { hdr_buf.set_len(len + filled) }
88 }
89 match read_res {
90 Poll::Ready(ref res) if res.is_err() || filled == 0 => {
94 *self = match last_err {
95 Some(err) => ReadState::Error(err, hdr_buf),
96 None => ReadState::Header(None, hdr_buf),
97 };
98 return read_res;
99 }
100 Poll::Pending => {
102 *self = ReadState::Reading(last_err, hdr_buf);
103 return read_res;
104 }
105 _ => {}
106 }
107 }
108
109 let mut hdr_view = &*hdr_buf;
112
113 if hdr_view.len() < MIN_HEADER_LEN || matches!(hdr_view.last(), Some(b'\r')) {
118 *self = ReadState::Reading(last_err, hdr_buf);
119 continue;
120 }
121
122 match proxy_protocol::parse(&mut hdr_view) {
123 Ok(hdr) => {
124 hdr_buf.advance(hdr_buf.len() - hdr_view.len());
125 *self = ReadState::Header(Some(hdr), hdr_buf);
126 return Poll::Ready(Ok(()));
127 }
128 Err(ParseError::NotProxyHeader) => {
129 *self = ReadState::Header(None, hdr_buf);
130 return Poll::Ready(Ok(()));
131 }
132
133 Err(err) => {
136 if hdr_buf.len() >= MAX_HEADER_LEN {
138 *self = ReadState::Error(err, hdr_buf);
139 } else {
140 *self = ReadState::Reading(Some(err), hdr_buf);
141 }
142 continue;
143 }
144 }
145 }
146 }
147}
148
149#[derive(Debug)]
150enum WriteState {
151 Writing(BytesMut),
152 None,
153}
154
155impl WriteState {
156 fn new(hdr: proxy_protocol::ProxyHeader) -> Result<WriteState, proxy_protocol::EncodeError> {
157 proxy_protocol::encode(hdr).map(WriteState::Writing)
158 }
159
160 #[instrument(level = "trace", skip(writer))]
163 fn poll_write_header_once(
164 &mut self,
165 cx: &mut Context,
166 mut writer: Pin<&mut impl AsyncWrite>,
167 ) -> Poll<io::Result<()>> {
168 loop {
169 let state = mem::replace(self, WriteState::None);
170 match state {
171 WriteState::None => return Poll::Ready(Ok(())),
172 WriteState::Writing(mut buf) => {
173 let write_res = writer.as_mut().poll_write(cx, &buf);
174 match write_res {
175 Poll::Pending | Poll::Ready(Err(_)) => {
176 *self = WriteState::Writing(buf);
177 ready!(write_res)?;
178 unreachable!(
179 "ready! will return for us on either Pending or Ready(Err)"
180 );
181 }
182 Poll::Ready(Ok(written)) => {
183 buf.advance(written);
184 if !buf.is_empty() {
185 *self = WriteState::Writing(buf);
186 continue;
187 } else {
188 return Ok(()).into();
189 }
190 }
191 }
192 }
193 }
194 }
195 }
196}
197
/// Byte-stream wrapper that transparently handles a PROXY protocol header.
///
/// Depending on the constructor, it either reads and strips an incoming header,
/// writes an outgoing header ahead of the first payload write, or does neither.
#[derive(Debug)]
#[pin_project::pin_project]
pub struct Stream<S> {
    /// Incoming-header detection state, plus bytes read past the header that
    /// have not yet been returned to the caller.
    read_state: ReadState,
    /// Encoded outgoing header bytes that have not yet been written.
    write_state: WriteState,
    /// The wrapped transport.
    #[pin]
    inner: S,
}
206
207impl<S> Stream<S> {
208 pub fn outgoing(stream: S, header: ProxyHeader) -> Result<Self, proxy_protocol::EncodeError> {
209 Ok(Stream {
210 inner: stream,
211 write_state: WriteState::new(header)?,
212 read_state: ReadState::None,
213 })
214 }
215
216 pub fn incoming(stream: S) -> Self {
217 Stream {
218 inner: stream,
219 read_state: ReadState::new(),
220 write_state: WriteState::None,
221 }
222 }
223
224 pub fn disabled(stream: S) -> Self {
225 Stream {
226 inner: stream,
227 read_state: ReadState::None,
228 write_state: WriteState::None,
229 }
230 }
231}
232
233impl<S> Stream<S>
234where
235 S: AsyncRead,
236{
237 #[instrument(level = "debug", skip(self))]
238 pub async fn proxy_header(&mut self) -> io::Result<Result<Option<&ProxyHeader>, &ParseError>>
239 where
240 Self: Unpin,
241 {
242 let mut this = Pin::new(self);
243
244 futures::future::poll_fn(|cx| {
245 let this = this.as_mut().project();
246 this.read_state.poll_read_header_once(cx, this.inner)
247 })
248 .await?;
249
250 Ok(this.get_mut().read_state.header())
251 }
252}
253
254impl<S> AsyncRead for Stream<S>
255where
256 S: AsyncRead,
257{
258 #[instrument(level = "trace", skip(self), fields(read_state = ?self.read_state))]
259 fn poll_read(
260 self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 buf: &mut ReadBuf<'_>,
263 ) -> Poll<io::Result<()>> {
264 let mut this = self.project();
265
266 ready!(this
267 .read_state
268 .poll_read_header_once(cx, this.inner.as_mut()))?;
269
270 match this.read_state {
271 ReadState::Error(_, remainder) | ReadState::Header(_, remainder) => {
272 if !remainder.is_empty() {
273 let available = std::cmp::min(remainder.len(), buf.remaining());
274 buf.put_slice(&remainder.split_to(available));
275 return Poll::Ready(Ok(()));
277 }
278 }
279 ReadState::None => {}
280 _ => unreachable!(),
281 }
282
283 this.inner.poll_read(cx, buf)
284 }
285}
286
287impl<S> AsyncWrite for Stream<S>
288where
289 S: AsyncWrite,
290{
291 #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
292 fn poll_write(
293 self: Pin<&mut Self>,
294 cx: &mut Context<'_>,
295 buf: &[u8],
296 ) -> Poll<Result<usize, io::Error>> {
297 let mut this = self.project();
298
299 ready!(this
300 .write_state
301 .poll_write_header_once(cx, this.inner.as_mut()))?;
302
303 this.inner.poll_write(cx, buf)
304 }
305 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
306 self.project().inner.poll_flush(cx)
307 }
308 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
309 self.project().inner.poll_shutdown(cx)
310 }
311}
312
#[cfg(feature = "hyper")]
mod hyper {
    // Implements hyper's runtime I/O traits for `Stream` by delegating to the
    // tokio `AsyncRead`/`AsyncWrite` implementations in the parent module.
    use ::hyper::rt::{
        Read as HyperRead,
        Write as HyperWrite,
    };

    use super::*;

    /// Forwards each hyper write operation to the tokio `AsyncWrite` impl of
    /// `Stream`, so outgoing-header handling there also applies to hyper writes.
    impl<S> HyperWrite for Stream<S>
    where
        S: AsyncWrite,
    {
        #[instrument(level = "trace", skip(self), fields(write_state = ?self.write_state))]
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            <Self as AsyncWrite>::poll_shutdown(self, cx)
        }
    }

    /// Forwards hyper reads to the tokio `AsyncRead` impl of `Stream`, bridging
    /// hyper's `ReadBufCursor` to a tokio `ReadBuf` over the same memory.
    impl<S> HyperRead for Stream<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            mut buf: ::hyper::rt::ReadBufCursor<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // SAFETY: the cursor's memory is wrapped as *uninitialized*. A tokio
            // `ReadBuf` only writes initialized bytes into it through its safe API,
            // so no previously-initialized byte is de-initialized. That is the
            // caller obligation of `ReadBufCursor::as_mut`.
            let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
            let res = ready!(<Self as AsyncRead>::poll_read(self, cx, &mut tokio_buf));
            let filled = tokio_buf.filled().len();
            // SAFETY: `ReadBuf::filled` reports that the first `filled` bytes were
            // written by the reader and are therefore initialized.
            unsafe { buf.advance(filled) };
            Poll::Ready(res)
        }
    }
}
362
#[cfg(test)]
mod test {
    use std::{
        cmp,
        io,
        pin::Pin,
        task::{
            ready,
            Context,
            Poll,
        },
        time::Duration,
    };

    use bytes::{
        BufMut,
        BytesMut,
    };
    use proxy_protocol::{
        version2::{
            self,
            ProxyCommand,
        },
        ProxyHeader,
    };
    use tokio::io::{
        AsyncRead,
        AsyncReadExt,
        AsyncWriteExt,
        ReadBuf,
    };

    use super::Stream;

    /// Test reader that yields data from `inner` in small, randomly sized chunks.
    /// This exercises header parsing across many partial reads.
    #[pin_project::pin_project]
    struct ShortReader<S> {
        #[pin]
        inner: S,
        min: usize,
        max: usize,
    }

    impl<S> AsyncRead for ShortReader<S>
    where
        S: AsyncRead,
    {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.project();
            // Guard against `max <= min`, which would otherwise divide by zero.
            let span = cmp::max(1, this.max.saturating_sub(*this.min));
            let chunk = *this.min + cmp::max(1, rand::random::<usize>() % span);
            // Never read more than the caller can accept. `put_slice` panics if
            // the destination is too small, and the excess bytes would be lost.
            let chunk = cmp::min(chunk, buf.remaining());
            let mut tmp = vec![0; chunk];
            let mut tmp_buf = ReadBuf::new(&mut tmp);
            let res = ready!(this.inner.poll_read(cx, &mut tmp_buf));

            buf.put_slice(tmp_buf.filled());

            res?;

            Poll::Ready(Ok(()))
        }
    }

    impl<S> ShortReader<S> {
        fn new(inner: S, min: usize, max: usize) -> Self {
            ShortReader { inner, min, max }
        }
    }

    const INPUT: &str = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n";
    const PARTIAL_INPUT: &str = "PROXY TCP4 192.168.0.1";
    const FINAL_INPUT: &str = " 192.168.0.11 56324 443\r\n";

    #[test]
    fn test_proxy_protocol() {
        let mut buf = BytesMut::from(INPUT);

        assert!(proxy_protocol::parse(&mut buf).is_ok());

        buf = BytesMut::from(PARTIAL_INPUT);

        assert!(proxy_protocol::parse(&mut &*buf).is_err());

        buf.put_slice(FINAL_INPUT.as_bytes());

        assert!(proxy_protocol::parse(&mut &*buf).is_ok());
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream_v2() {
        let (left, mut right) = tokio::io::duplex(1024);

        let header = ProxyHeader::Version2 {
            command: ProxyCommand::Proxy,
            transport_protocol: version2::ProxyTransportProtocol::Stream,
            addresses: version2::ProxyAddresses::Ipv4 {
                source: "127.0.0.1:1".parse().unwrap(),
                destination: "127.0.0.2:2".parse().unwrap(),
            },
        };

        let input = proxy_protocol::encode(header).unwrap();

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right.write_all(&input).await.expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version2 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_header_stream() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(ShortReader::new(left, 2, 5));

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;

            right
                .write_all(INPUT.as_bytes())
                .await
                .expect("write header");

            right
                .write_all(b"Hello, world!")
                .await
                .expect("write hello");

            right.shutdown().await.expect("shutdown");
        });

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read rest");

        assert_eq!(buf, "Hello, world!");

        let hdr = proxy_stream
            .proxy_header()
            .await
            .expect("read header")
            .expect("decode header")
            .expect("header exists");

        assert!(matches!(hdr, ProxyHeader::Version1 { .. }));
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_noheader() {
        let (left, mut right) = tokio::io::duplex(1024);

        let mut proxy_stream = Stream::incoming(left);

        right
            .write_all(b"Hello, world!")
            .await
            .expect("write stream");

        right.shutdown().await.expect("shutdown");
        drop(right);

        assert!(proxy_stream
            .proxy_header()
            .await
            .unwrap()
            .unwrap()
            .is_none());

        let mut buf = String::new();

        proxy_stream
            .read_to_string(&mut buf)
            .await
            .expect("read stream");

        assert_eq!(buf, "Hello, world!");
    }
}