1use std::{
2 net::SocketAddr,
3 pin::Pin,
4 task::{
5 Context,
6 Poll,
7 },
8};
9
10#[cfg(feature = "axum")]
12use axum::extract::connect_info::Connected;
13#[cfg(feature = "hyper")]
14use hyper::rt::{
15 Read as HyperRead,
16 Write as HyperWrite,
17};
18use muxado::typed::TypedStream;
19use tokio::io::{
20 AsyncRead,
21 AsyncWrite,
22};
23
24use crate::{
25 config::ProxyProto,
26 internals::proto::{
27 EdgeType,
28 ProxyHeader,
29 },
30};
/// State shared by every connection wrapper that `make_conn_type!` generates:
/// the connection's metadata plus the stream its I/O is forwarded to.
pub(crate) struct ConnInner {
    /// Metadata for this connection; the wrappers' info-trait methods read it.
    pub(crate) info: Info,
    /// Underlying muxado stream; the wrappers' read/write impls delegate here.
    pub(crate) stream: TypedStream,
}
39
/// Per-connection metadata. Backs this type's [`ConnInfo`], [`EdgeConnInfo`],
/// and [`EndpointConnInfo`] implementations.
#[derive(Clone)]
pub(crate) struct Info {
    /// Proxy header for the connection; `edge_type`, `passthrough_tls`, and
    /// `proto` are read from it.
    pub(crate) header: ProxyHeader,
    /// Address returned by [`ConnInfo::remote_addr`].
    pub(crate) remote_addr: SocketAddr,
    /// Proxy-protocol configuration (`config::ProxyProto`).
    /// NOTE(review): not read in this file; confirm which consumer uses it.
    pub(crate) proxy_proto: ProxyProto,
    /// Optional application protocol name.
    /// NOTE(review): not read in this file; presumably a negotiated protocol
    /// identifier, so confirm against the code that populates it.
    pub(crate) app_protocol: Option<String>,
    /// Flag controlling upstream TLS verification.
    /// NOTE(review): not read in this file; confirm semantics with its consumer.
    pub(crate) verify_upstream_tls: bool,
}
48
49impl ConnInfo for Info {
50 fn remote_addr(&self) -> SocketAddr {
51 self.remote_addr
52 }
53}
54
55impl EdgeConnInfo for Info {
56 fn edge_type(&self) -> EdgeType {
57 self.header.edge_type
58 }
59 fn passthrough_tls(&self) -> bool {
60 self.header.passthrough_tls
61 }
62}
63
64impl EndpointConnInfo for Info {
65 fn proto(&self) -> &str {
66 self.header.proto.as_str()
67 }
68}
69
// Declares the public `Conn` trait. Whatever tokens are passed in are appended
// to its supertrait list, so a feature can require additional traits.
macro_rules! conn_trait {
    ($($extra:tt)*) => {
        /// A connection: readable and writable through [`AsyncRead`] /
        /// [`AsyncWrite`], exposing [`ConnInfo`], and `Unpin + Send + 'static`.
        /// Any extra bounds supplied when the trait is declared (the `hyper`
        /// feature adds hyper's `Read` and `Write`) are required as well.
        pub trait Conn: ConnInfo + AsyncRead + AsyncWrite + Unpin + Send + 'static $($extra)* {}
    };
}
80
81#[cfg(not(feature = "hyper"))]
82conn_trait!();
83
84#[cfg(feature = "hyper")]
85conn_trait! {
86 + hyper::rt::Read + hyper::rt::Write
87}
88
/// Metadata available for every connection; a supertrait of [`Conn`].
pub trait ConnInfo {
    /// Returns the remote address associated with this connection.
    fn remote_addr(&self) -> SocketAddr;
}
95
/// Additional metadata for connections that carry edge information.
pub trait EdgeConnInfo {
    /// Returns the [`EdgeType`] associated with this connection.
    fn edge_type(&self) -> EdgeType;
    /// Returns this connection's TLS-passthrough flag.
    fn passthrough_tls(&self) -> bool;
}
104
/// Additional metadata for endpoint connections.
pub trait EndpointConnInfo {
    /// Returns the protocol string associated with this connection.
    fn proto(&self) -> &str;
}
110
// Generates a public connection wrapper around `ConnInner`.
//
// Usage: `make_conn_type! { $(#[attr])* Wrapper, InfoTraitA, InfoTraitB, ... }`.
// The main arm defines `pub struct Wrapper` and implements `Conn`, `ConnInfo`,
// tokio's `AsyncRead`/`AsyncWrite` (plus hyper's `Read`/`Write` with the
// `hyper` feature and axum's `Connected` with the `axum` feature) by
// delegating to the inner stream. It then implements each listed info trait
// through the `(info Trait, Wrapper)` helper arms.
macro_rules! make_conn_type {
    // Helper arm: forward `EdgeConnInfo` to the wrapped `Info`.
    (info EdgeConnInfo, $wrapper:tt) => {
        impl EdgeConnInfo for $wrapper {
            fn edge_type(&self) -> EdgeType {
                self.inner.info.edge_type()
            }
            fn passthrough_tls(&self) -> bool {
                self.inner.info.passthrough_tls()
            }
        }
    };
    // Helper arm: forward `EndpointConnInfo` to the wrapped `Info`.
    (info EndpointConnInfo, $wrapper:tt) => {
        impl EndpointConnInfo for $wrapper {
            fn proto(&self) -> &str {
                self.inner.info.proto()
            }
        }
    };
    // Main arm: define the wrapper and its I/O impls, then expand the helper
    // arm for every listed info trait.
    ($(#[$outer:meta])* $wrapper:ident, $($m:tt),*) => {
        $(#[$outer])*
        pub struct $wrapper {
            pub(crate) inner: ConnInner,
        }

        impl Conn for $wrapper {}

        impl ConnInfo for $wrapper {
            fn remote_addr(&self) -> SocketAddr {
                self.inner.info.remote_addr()
            }
        }

        impl AsyncRead for $wrapper {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut *self.inner.stream).poll_read(cx, buf)
            }
        }

        impl AsyncWrite for $wrapper {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_write(cx, buf)
            }
            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_flush(cx)
            }
            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_shutdown(cx)
            }
        }

        // Adapts hyper's cursor-based read API onto the tokio `AsyncRead`
        // impl above.
        #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
        #[cfg(feature = "hyper")]
        impl HyperRead for $wrapper {
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                mut buf: hyper::rt::ReadBufCursor<'_>,
            ) -> Poll<std::io::Result<()>> {
                // SAFETY: `ReadBuf::uninit` only writes into the slice and tracks
                // how much of it has been initialized; it never de-initializes
                // bytes, which is the requirement `ReadBufCursor::as_mut` places
                // on its caller.
                let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
                let filled = match AsyncRead::poll_read(self, cx, &mut tokio_buf) {
                    Poll::Ready(Ok(())) => tokio_buf.filled().len(),
                    // On error or `Pending`, report it without moving the cursor.
                    other => return other,
                };
                // SAFETY: the first `filled` bytes of the cursor's slice were
                // written, and therefore initialized, by the read above.
                unsafe { buf.advance(filled) };
                Poll::Ready(Ok(()))
            }
        }

        // hyper's write API has the same shape as tokio's, so delegate to the
        // `AsyncWrite` impl above.
        #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))]
        #[cfg(feature = "hyper")]
        impl HyperWrite for $wrapper {
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                AsyncWrite::poll_write(self, cx, buf)
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                AsyncWrite::poll_flush(self, cx)
            }
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                AsyncWrite::poll_shutdown(self, cx)
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
        #[cfg(feature = "axum")]
        impl Connected<&$wrapper> for SocketAddr {
            fn connect_info(target: &$wrapper) -> Self {
                target.inner.info.remote_addr()
            }
        }

        $(
            make_conn_type!(info $m, $wrapper);
        )*
    };
}
226
227make_conn_type! {
228 EdgeConn, EdgeConnInfo
230}
231
232make_conn_type! {
233 EndpointConn, EndpointConnInfo
235}