ngrok/
conn.rs

1use std::{
2    net::SocketAddr,
3    pin::Pin,
4    task::{
5        Context,
6        Poll,
7    },
8};
9
10// Support for axum's connection info trait.
11#[cfg(feature = "axum")]
12use axum::extract::connect_info::Connected;
13#[cfg(feature = "hyper")]
14use hyper::rt::{
15    Read as HyperRead,
16    Write as HyperWrite,
17};
18use muxado::typed::TypedStream;
19use tokio::io::{
20    AsyncRead,
21    AsyncWrite,
22};
23
24use crate::{
25    config::ProxyProto,
26    internals::proto::{
27        EdgeType,
28        ProxyHeader,
29    },
30};
31/// A connection from an ngrok tunnel.
32///
33/// This implements [AsyncRead]/[AsyncWrite], as well as providing access to the
34/// address from which the connection to the ngrok edge originated.
35pub(crate) struct ConnInner {
36    pub(crate) info: Info,
37    pub(crate) stream: TypedStream,
38}
39
40#[derive(Clone)]
41pub(crate) struct Info {
42    pub(crate) header: ProxyHeader,
43    pub(crate) remote_addr: SocketAddr,
44    pub(crate) proxy_proto: ProxyProto,
45    pub(crate) app_protocol: Option<String>,
46    pub(crate) verify_upstream_tls: bool,
47}
48
49impl ConnInfo for Info {
50    fn remote_addr(&self) -> SocketAddr {
51        self.remote_addr
52    }
53}
54
55impl EdgeConnInfo for Info {
56    fn edge_type(&self) -> EdgeType {
57        self.header.edge_type
58    }
59    fn passthrough_tls(&self) -> bool {
60        self.header.passthrough_tls
61    }
62}
63
64impl EndpointConnInfo for Info {
65    fn proto(&self) -> &str {
66        self.header.proto.as_str()
67    }
68}
69
70// This codgen indirect is required to make the hyper io trait bounds
71// dependent on the hyper feature. You can't put a #[cfg] on a single bound, so
72// we're putting the whole trait def in a macro. Gross, but gets the job done.
73macro_rules! conn_trait {
74    ($($hyper_bound:tt)*) => {
75		/// An incoming connection over an ngrok tunnel.
76		/// Effectively a trait alias for async read+write, plus connection info.
77		pub trait Conn: ConnInfo + AsyncRead + AsyncWrite $($hyper_bound)* + Unpin + Send + 'static {}
78	}
79}
80
81#[cfg(not(feature = "hyper"))]
82conn_trait!();
83
84#[cfg(feature = "hyper")]
85conn_trait! {
86    + hyper::rt::Read + hyper::rt::Write
87}
88
89/// Information common to all ngrok connections.
90pub trait ConnInfo {
91    /// Returns the client address that initiated the connection to the ngrok
92    /// edge.
93    fn remote_addr(&self) -> SocketAddr;
94}
95
96/// Information about connections via ngrok edges.
97pub trait EdgeConnInfo {
98    /// Returns the edge type for this connection.
99    fn edge_type(&self) -> EdgeType;
100    /// Returns whether the connection includes the tls handshake and encrypted
101    /// stream.
102    fn passthrough_tls(&self) -> bool;
103}
104
105/// Information about connections via ngrok endpoints.
106pub trait EndpointConnInfo {
107    /// Returns the endpoint protocol.
108    fn proto(&self) -> &str;
109}
110
// Generates a public connection wrapper type around `ConnInner`.
//
// Invocation: `make_conn_type! { /// docs ... Name, InfoTrait, ... }`. This
// declares `pub struct Name` and implements `Conn`, `ConnInfo`, and tokio's
// `AsyncRead`/`AsyncWrite` by forwarding to the inner stream/info. It also
// implements hyper's `Read`/`Write` when the `hyper` feature is enabled, and
// axum's `Connected` when the `axum` feature is enabled. Each listed extra
// info trait is then implemented through the internal `info` arms.
macro_rules! make_conn_type {
    // Internal arm: forward `EdgeConnInfo` to the wrapped `Info`.
    (info EdgeConnInfo, $wrapper:tt) => {
        impl EdgeConnInfo for $wrapper {
            fn edge_type(&self) -> EdgeType {
                self.inner.info.edge_type()
            }
            fn passthrough_tls(&self) -> bool {
                self.inner.info.passthrough_tls()
            }
        }
    };
    // Internal arm: forward `EndpointConnInfo` to the wrapped `Info`.
    (info EndpointConnInfo, $wrapper:tt) => {
        impl EndpointConnInfo for $wrapper {
            fn proto(&self) -> &str {
                self.inner.info.proto()
            }
        }
    };
    // Public arm: outer attributes (docs), the wrapper name, and the list of
    // extra info traits to implement.
    ($(#[$outer:meta])* $wrapper:ident, $($m:tt),*) => {
        $(#[$outer])*
        pub struct $wrapper {
            pub(crate) inner: ConnInner,
        }

        impl Conn for $wrapper {}

        impl ConnInfo for $wrapper {
            fn remote_addr(&self) -> SocketAddr {
                self.inner.info.remote_addr()
            }
        }

        impl AsyncRead for $wrapper {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut *self.inner.stream).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $wrapper {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_write(cx, buf)
            }
            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_flush(cx)
            }
            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_shutdown(cx)
            }
        }

        // Bridges hyper's cursor-based read API onto the tokio `AsyncRead`
        // impl above.
        #[cfg(feature = "hyper")]
        impl HyperRead for $wrapper {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                mut buf: hyper::rt::ReadBufCursor<'_>,
            ) -> Poll<std::io::Result<()>> {
                let (res, filled) = {
                    // SAFETY: `ReadBuf::uninit` only tracks which prefix of
                    // the slice is initialized/filled. Its safe API never
                    // de-initializes bytes, which is the contract
                    // `ReadBufCursor::as_mut` requires.
                    let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
                    let start = tokio_buf.filled().as_ptr();
                    let res = std::task::ready!(AsyncRead::poll_read(self, cx, &mut tokio_buf));
                    // The reader gets `&mut ReadBuf` and could swap in a buffer
                    // backed by other memory. If it did, `filled` would not
                    // describe bytes in `buf`, and the `advance` below would
                    // be unsound.
                    assert_eq!(
                        start,
                        tokio_buf.filled().as_ptr(),
                        "inner stream replaced the read buffer"
                    );
                    (res, tokio_buf.filled().len())
                };
                // SAFETY: the assertion above ensures the first `filled` bytes
                // of the cursor's unfilled region are exactly the bytes
                // `ReadBuf` reported as written (and hence initialized).
                unsafe { buf.advance(filled) };
                Poll::Ready(res)
            }
        }

        // hyper's `Write` has the same contract as tokio's `AsyncWrite`, so
        // delegate to the impl above instead of duplicating the forwarding.
        #[cfg(feature = "hyper")]
        impl HyperWrite for $wrapper {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                AsyncWrite::poll_write(self, cx, buf)
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                AsyncWrite::poll_flush(self, cx)
            }
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                AsyncWrite::poll_shutdown(self, cx)
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
        #[cfg(feature = "axum")]
        impl Connected<&$wrapper> for SocketAddr {
            fn connect_info(target: &$wrapper) -> Self {
                target.inner.info.remote_addr()
            }
        }

        $(
            make_conn_type!(info $m, $wrapper);
        )*
    };
}
226
227make_conn_type! {
228    /// A connection via an ngrok Edge.
229    EdgeConn, EdgeConnInfo
230}
231
232make_conn_type! {
233    /// A connection via an ngrok Endpoint.
234    EndpointConn, EndpointConnInfo
235}