use std::{
net::SocketAddr,
pin::Pin,
task::{
Context,
Poll,
},
};
#[cfg(feature = "axum")]
use axum::extract::connect_info::Connected;
#[cfg(feature = "hyper")]
use hyper::rt::{
Read as HyperRead,
Write as HyperWrite,
};
use muxado::typed::TypedStream;
use tokio::io::{
AsyncRead,
AsyncWrite,
};
use crate::{
config::ProxyProto,
internals::proto::{
EdgeType,
ProxyHeader,
},
};
/// Metadata describing a single connection.
///
/// The three connection-info traits implemented for this type read from it:
/// `remote_addr` directly, and the edge type, TLS-passthrough flag and
/// protocol from `header`.
#[derive(Clone)]
pub(crate) struct Info {
    /// Proxy header for this connection; source of `edge_type`,
    /// `passthrough_tls` and `proto`.
    pub(crate) header: ProxyHeader,
    /// Remote address reported through `ConnInfo::remote_addr`.
    pub(crate) remote_addr: SocketAddr,
    /// PROXY-protocol setting from the crate config. NOTE(review): how it is
    /// applied is not visible in this file — confirm at the use site.
    pub(crate) proxy_proto: ProxyProto,
    /// Optional application protocol name. NOTE(review): presumably an
    /// app-level/ALPN protocol hint — confirm against callers.
    pub(crate) app_protocol: Option<String>,
    /// Whether upstream TLS should be verified. NOTE(review): the consumer of
    /// this flag is outside this file — confirm semantics there.
    pub(crate) verify_upstream_tls: bool,
}

/// State shared by every generated connection wrapper: the connection's
/// metadata plus the typed muxado stream that all reads and writes go through.
pub(crate) struct ConnInner {
    /// Metadata served by the wrapper's info-trait implementations.
    pub(crate) info: Info,
    /// Underlying stream; the wrappers forward all async I/O to it.
    pub(crate) stream: TypedStream,
}
impl ConnInfo for Info {
    /// Returns the remote address stored on this metadata record.
    fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = self;
        *remote_addr
    }
}

impl EdgeConnInfo for Info {
    /// Returns the edge type carried in the proxy header.
    fn edge_type(&self) -> EdgeType {
        let header = &self.header;
        header.edge_type
    }

    /// Returns the proxy header's TLS-passthrough flag.
    fn passthrough_tls(&self) -> bool {
        let header = &self.header;
        header.passthrough_tls
    }
}

impl EndpointConnInfo for Info {
    /// Returns the protocol from the proxy header, borrowed from `self`.
    fn proto(&self) -> &str {
        let header = &self.header;
        header.proto.as_str()
    }
}
// Declares the public `Conn` trait. Any tokens passed in are spliced into the
// supertrait list right after `AsyncWrite`. The `hyper` feature uses this to add
// hyper's I/O traits, so a single definition serves both feature configurations.
macro_rules! conn_trait {
    ($($extra_bounds:tt)*) => {
        /// A connection that exposes [`ConnInfo`] metadata and supports tokio
        /// `AsyncRead`/`AsyncWrite`. It is also `Unpin + Send + 'static`.
        ///
        /// When the `hyper` feature is enabled, implementors must also implement
        /// `hyper::rt::Read` and `hyper::rt::Write`.
        pub trait Conn:
            ConnInfo + AsyncRead + AsyncWrite $($extra_bounds)* + Unpin + Send + 'static
        {
        }
    };
}
#[cfg(feature = "hyper")]
conn_trait!(+ hyper::rt::Read + hyper::rt::Write);
#[cfg(not(feature = "hyper"))]
conn_trait!();
/// Metadata that every connection type exposes.
pub trait ConnInfo {
    /// Returns the remote address associated with this connection.
    fn remote_addr(&self) -> SocketAddr;
}
/// Extra metadata available on edge connections.
pub trait EdgeConnInfo {
    /// Returns the type of edge associated with this connection.
    fn edge_type(&self) -> EdgeType;

    /// Returns whether TLS is passed through for this connection.
    /// NOTE(review): presumably means TLS was not terminated before reaching
    /// this side — confirm against the proxy-header producer.
    fn passthrough_tls(&self) -> bool;
}
/// Extra metadata available on endpoint connections.
pub trait EndpointConnInfo {
    /// Returns the connection's protocol name, borrowed from `self`.
    fn proto(&self) -> &str;
}
// Generates a public connection wrapper type around `ConnInner`.
//
// Main form: `make_conn_type!(#[attrs]* Wrapper, InfoTrait, ...)`
// - defines `pub struct Wrapper { pub(crate) inner: ConnInner }`;
// - implements `Conn`, `ConnInfo`, and tokio's `AsyncRead`/`AsyncWrite`;
// - implements hyper's `rt::Read`/`rt::Write` (feature `hyper`);
// - implements axum's `Connected<&Wrapper>` for `SocketAddr` (feature `axum`);
// - re-invokes itself as `make_conn_type!(info InfoTrait, Wrapper)` once per
//   listed trait.
//
// The `info` arms exist only for `EdgeConnInfo` and `EndpointConnInfo`. Both
// delegate to `self.inner.info`.
//
// All I/O is forwarded to the deref target of `self.inner.stream` via
// `Pin::new`, which relies on that target being `Unpin`.
macro_rules! make_conn_type {
    (info EdgeConnInfo, $wrapper:tt) => {
        impl EdgeConnInfo for $wrapper {
            fn edge_type(&self) -> EdgeType {
                self.inner.info.edge_type()
            }
            fn passthrough_tls(&self) -> bool {
                self.inner.info.passthrough_tls()
            }
        }
    };
    (info EndpointConnInfo, $wrapper:tt) => {
        impl EndpointConnInfo for $wrapper {
            fn proto(&self) -> &str {
                self.inner.info.proto()
            }
        }
    };
    ($(#[$outer:meta])* $wrapper:ident, $($m:tt),*) => {
        $(#[$outer])*
        pub struct $wrapper {
            pub(crate) inner: ConnInner,
        }

        impl Conn for $wrapper {}

        impl ConnInfo for $wrapper {
            fn remote_addr(&self) -> SocketAddr {
                self.inner.info.remote_addr()
            }
        }

        impl AsyncRead for $wrapper {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut tokio::io::ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut *self.inner.stream).poll_read(cx, buf)
            }
        }

        #[cfg(feature = "hyper")]
        impl HyperRead for $wrapper {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                mut buf: hyper::rt::ReadBufCursor<'_>,
            ) -> Poll<std::io::Result<()>> {
                let filled = {
                    // SAFETY: the slice is only accessed through tokio's
                    // `ReadBuf`, which never de-initializes bytes. Hyper's
                    // initialized region therefore stays initialized.
                    let mut tokio_buf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
                    match Pin::new(&mut *self.inner.stream).poll_read(cx, &mut tokio_buf) {
                        Poll::Ready(Ok(())) => tokio_buf.filled().len(),
                        // Pending or an error: leave hyper's cursor untouched.
                        other => return other,
                    }
                };
                // SAFETY: `ReadBuf` guarantees its first `filled` bytes were
                // initialized by the read above, which is what `advance` requires.
                unsafe { buf.advance(filled) };
                Poll::Ready(Ok(()))
            }
        }

        #[cfg(feature = "hyper")]
        impl HyperWrite for $wrapper {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_write(cx, buf)
            }
            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_flush(cx)
            }
            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_shutdown(cx)
            }
            // Forward vectored writes so hyper sees the inner stream's real
            // capability instead of the trait's non-vectored default.
            fn poll_write_vectored(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> Poll<Result<usize, std::io::Error>> {
                AsyncWrite::poll_write_vectored(Pin::new(&mut *self.inner.stream), cx, bufs)
            }
            fn is_write_vectored(&self) -> bool {
                AsyncWrite::is_write_vectored(&*self.inner.stream)
            }
        }

        impl AsyncWrite for $wrapper {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_write(cx, buf)
            }
            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_flush(cx)
            }
            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                Pin::new(&mut *self.inner.stream).poll_shutdown(cx)
            }
            // Forward vectored writes. Without these, the wrapper always reports
            // `is_write_vectored() == false` and the default `poll_write_vectored`
            // writes only the first non-empty slice, even when the inner stream
            // supports vectored I/O.
            fn poll_write_vectored(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                bufs: &[std::io::IoSlice<'_>],
            ) -> Poll<Result<usize, std::io::Error>> {
                AsyncWrite::poll_write_vectored(Pin::new(&mut *self.inner.stream), cx, bufs)
            }
            fn is_write_vectored(&self) -> bool {
                AsyncWrite::is_write_vectored(&*self.inner.stream)
            }
        }

        #[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
        #[cfg(feature = "axum")]
        impl Connected<&$wrapper> for SocketAddr {
            fn connect_info(target: &$wrapper) -> Self {
                target.inner.info.remote_addr()
            }
        }

        $(
            make_conn_type!(info $m, $wrapper);
        )*
    };
}
make_conn_type!(
    /// Connection wrapper exposing [`ConnInfo`] and [`EdgeConnInfo`] metadata
    /// together with asynchronous read/write access to the underlying stream.
    EdgeConn, EdgeConnInfo
);
make_conn_type!(
    /// Connection wrapper exposing [`ConnInfo`] and [`EndpointConnInfo`] metadata
    /// together with asynchronous read/write access to the underlying stream.
    EndpointConn, EndpointConnInfo
);