1use std::{
2 collections::HashMap,
3 io,
4 sync::Arc,
5};
6#[cfg(feature = "hyper")]
7use std::{
8 convert::Infallible,
9 fmt,
10};
11
12use async_trait::async_trait;
13use bitflags::bitflags;
14use futures::stream::TryStreamExt;
15use futures_rustls::rustls::{
16 self,
17 pki_types,
18 ClientConfig,
19};
20#[cfg(feature = "hyper")]
21use hyper::{
22 server::conn::http1,
23 service::service_fn,
24 Response,
25 StatusCode,
26};
27use once_cell::sync::Lazy;
28use proxy_protocol::ProxyHeader;
29use tokio::{
30 io::copy_bidirectional,
31 net::TcpStream,
32 task::JoinHandle,
33};
34use tokio_util::compat::{
35 FuturesAsyncReadCompatExt,
36 TokioAsyncReadCompatExt,
37};
38#[cfg(feature = "hyper")]
39use tracing::debug;
40use tracing::{
41 field,
42 warn,
43 Instrument,
44 Span,
45};
46use url::Url;
47
48use crate::{
49 prelude::*,
50 proxy_proto,
51 session::IoStream,
52 EdgeConn,
53 EndpointConn,
54};
55
56#[allow(deprecated)]
57#[async_trait]
58impl<T> TunnelExt for T
59where
60 T: Tunnel + Send,
61 <T as Tunnel>::Conn: ConnExt,
62{
63 async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
64 forward_tunnel(self, url).await
65 }
66}
67
68#[async_trait]
70#[deprecated = "superceded by the `listen_and_forward` builder method"]
71pub trait TunnelExt: Tunnel + Send {
72 async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
87}
88
89pub(crate) trait ConnExt {
90 fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
91}
92
93#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
94pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
95where
96 T: Tunnel + 'static + ?Sized,
97 <T as Tunnel>::Conn: ConnExt,
98{
99 loop {
100 let tunnel_conn = if let Some(conn) = tun
101 .try_next()
102 .await
103 .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
104 {
105 conn
106 } else {
107 return Ok(());
108 };
109
110 tunnel_conn.forward_to(&url);
111 }
112}
113
114impl ConnExt for EdgeConn {
115 fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
116 let url = url.clone();
117 tokio::spawn(async move {
118 let mut upstream = match connect(
119 self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
120 self.inner.info.verify_upstream_tls,
121 self.inner.info.app_protocol.clone(),
122 None, &url,
124 )
125 .await
126 {
127 Ok(conn) => conn,
128 Err(error) => {
129 #[cfg(feature = "hyper")]
130 if self.edge_type() == EdgeType::Https {
131 serve_gateway_error(format!("{error}"), self);
132 }
133 warn!(%error, "error connecting to upstream");
134 return Err(error);
135 }
136 };
137
138 copy_bidirectional(&mut self, &mut upstream).await?;
139 Ok(())
140 })
141 }
142}
143
144impl ConnExt for EndpointConn {
145 fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
146 let url = url.clone();
147 tokio::spawn(async move {
148 let proxy_proto = self.inner.info.proxy_proto;
149 let proto_tls = self.proto() == "tls";
150 #[cfg(feature = "hyper")]
151 let proto_http = matches!(self.proto(), "http" | "https");
152 let passthrough_tls = self.inner.info.passthrough_tls();
153 let app_protocol = self.inner.info.app_protocol.clone();
154 let verify_upstream_tls = self.inner.info.verify_upstream_tls;
155
156 let (mut stream, proxy_header) = match proxy_proto {
157 ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
158 _ => {
159 let mut stream = crate::proxy_proto::Stream::incoming(self);
160 let header = stream
161 .proxy_header()
162 .await?
163 .map_err(|e| {
164 io::Error::new(
165 io::ErrorKind::InvalidData,
166 format!("invalid proxy-protocol header: {}", e),
167 )
168 })?
169 .cloned();
170 (stream, header)
171 }
172 };
173
174 let mut upstream = match connect(
175 proto_tls && passthrough_tls,
176 verify_upstream_tls,
177 app_protocol,
178 proxy_header,
179 &url,
180 )
181 .await
182 {
183 Ok(conn) => conn,
184 Err(error) => {
185 #[cfg(feature = "hyper")]
186 if proto_http {
187 serve_gateway_error(format!("{error}"), stream);
188 }
189 warn!(%error, "error connecting to upstream");
190 return Err(error);
191 }
192 };
193
194 copy_bidirectional(&mut stream, &mut upstream).await?;
195 Ok(())
196 })
197 }
198}
199
200bitflags! {
201 struct TlsFlags: u8 {
202 const FLAG_HTTP2 = 0b01;
203 const FLAG_verify_upstream_tls = 0b10;
204 const FLAG_MAX = Self::FLAG_HTTP2.bits()
205 | Self::FLAG_verify_upstream_tls.bits();
206 }
207}
208
209fn tls_config(
210 app_protocol: Option<String>,
211 verify_upstream_tls: bool,
212) -> Result<Arc<ClientConfig>, &'static io::Error> {
213 #[allow(clippy::type_complexity)]
220 static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
221 Lazy::new(|| {
222 std::ops::Range {
223 start: 0,
224 end: TlsFlags::FLAG_MAX.bits() + 1,
225 }
226 .map(|p| {
227 let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
228 let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
229 let mut config = crate::session::host_certs_tls_config()?;
230 if !verify_upstream_tls {
231 config.dangerous().set_certificate_verifier(Arc::new(
232 danger::NoCertificateVerification::new(
233 rustls::crypto::aws_lc_rs::default_provider(),
234 ),
235 ));
236 }
237
238 if http2 {
239 config
240 .alpn_protocols
241 .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
242 }
243 Ok((p, Arc::new(config)))
244 })
245 .collect()
246 });
247
248 let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
249 let mut key = 0;
250 if Some("http2").eq(&app_protocol.as_deref()) {
251 key |= TlsFlags::FLAG_HTTP2.bits();
252 }
253 if verify_upstream_tls {
254 key |= TlsFlags::FLAG_verify_upstream_tls.bits();
255 }
256
257 Ok(configs
258 .get(&key)
259 .or_else(|| configs.get(&0))
260 .unwrap()
261 .clone())
262}
263
264async fn connect(
269 tunnel_tls: bool,
270 verify_upstream_tls: bool,
271 app_protocol: Option<String>,
272 proxy_proto_header: Option<ProxyHeader>,
273 url: &Url,
274) -> Result<Box<dyn IoStream>, io::Error> {
275 let host = url.host_str().unwrap_or("localhost");
276 let mut backend_tls: bool = false;
277 let mut conn: Box<dyn IoStream> = match url.scheme() {
278 "tcp" => {
279 let port = url.port().ok_or_else(|| {
280 io::Error::new(
281 io::ErrorKind::InvalidInput,
282 format!("missing port for tcp forwarding url {url}"),
283 )
284 })?;
285 let conn = connect_tcp(host, port).in_current_span().await?;
286 Box::new(conn)
287 }
288
289 "http" => {
290 let port = url.port().unwrap_or(80);
291 let conn = connect_tcp(host, port).in_current_span().await?;
292 Box::new(conn)
293 }
294
295 "https" | "tls" => {
296 let port = url.port().unwrap_or(443);
297 let conn = connect_tcp(host, port).in_current_span().await?;
298
299 backend_tls = true;
300 Box::new(conn)
301 }
302
303 #[cfg(not(target_os = "windows"))]
304 "unix" => {
305 use std::borrow::Cow;
306
307 use tokio::net::UnixStream;
308
309 let mut addr = Cow::Borrowed(url.path());
310 if let Some(host) = url.host_str() {
311 addr = Cow::Owned(format!("{host}{addr}"));
314 }
315 Box::new(UnixStream::connect(&*addr).await?)
316 }
317
318 #[cfg(target_os = "windows")]
319 "pipe" => {
320 use std::time::Duration;
321
322 use tokio::net::windows::named_pipe::ClientOptions;
323 use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;
324
325 let mut pipe_name = url.path();
326 if url.host_str().is_some() {
327 pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
328 }
329 if pipe_name.is_empty() {
330 return Err(io::Error::new(
331 io::ErrorKind::InvalidInput,
332 format!("missing pipe name in forwarding url {url}"),
333 ));
334 }
335 let host = url
336 .host_str()
337 .map(|h| if h == "localhost" { "." } else { h })
339 .unwrap_or(".");
340 let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
342 let local_conn = loop {
345 match ClientOptions::new().open(&addr) {
346 Ok(client) => break client,
347 Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
348 Err(error) => return Err(error),
349 }
350
351 tokio::time::sleep(Duration::from_millis(50)).await;
352 };
353 Box::new(local_conn)
354 }
355 _ => {
356 return Err(io::Error::new(
357 io::ErrorKind::InvalidInput,
358 format!("unrecognized scheme in forwarding url: {url}"),
359 ))
360 }
361 };
362
363 if let Some(header) = proxy_proto_header {
365 conn = Box::new(
366 proxy_proto::Stream::outgoing(conn, header)
367 .expect("re-serializing proxy header should always succeed"),
368 )
369 };
370
371 if backend_tls && !tunnel_tls {
372 let domain = pki_types::ServerName::try_from(host)
373 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
374 .to_owned();
375 conn = Box::new(
376 futures_rustls::TlsConnector::from(
377 tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
378 )
379 .connect(domain, conn.compat())
380 .await?
381 .compat(),
382 )
383 }
384
385 Ok(conn)
388}
389
390async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
391 let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
392 if let Ok(addr) = conn.peer_addr() {
393 Span::current().record("forward_addr", field::display(addr));
394 }
395 Ok(conn)
396}
397
/// Answers an HTTP/1 request on `conn` with `502 Bad Gateway`.
///
/// The response body is `failed to dial backend: {err}`. Keep-alive is
/// disabled, so hyper closes the connection after responding. The work runs
/// on a spawned task attached to the caller's current tracing span.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let task = async move {
        let bad_gateway = service_fn(move |_req| {
            debug!("serving bad gateway error");
            let body = format!("failed to dial backend: {err}");
            let mut response = Response::new(body);
            *response.status_mut() = StatusCode::BAD_GATEWAY;
            futures::future::ok::<_, Infallible>(response)
        });

        let outcome = http1::Builder::new()
            .keep_alive(false)
            .serve_connection(conn, bad_gateway)
            .await;
        debug!(res = ?outcome, "connection closed");
    };

    tokio::spawn(task.in_current_span())
}
421
422mod danger {
424 use futures_rustls::rustls;
425 use rustls::{
426 client::danger::HandshakeSignatureValid,
427 crypto::{
428 verify_tls12_signature,
429 verify_tls13_signature,
430 CryptoProvider,
431 },
432 DigitallySignedStruct,
433 };
434
435 use super::pki_types::{
436 CertificateDer,
437 ServerName,
438 UnixTime,
439 };
440
441 #[derive(Debug)]
442 pub struct NoCertificateVerification(CryptoProvider);
443
444 impl NoCertificateVerification {
445 pub fn new(provider: CryptoProvider) -> Self {
446 Self(provider)
447 }
448 }
449
450 impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
451 fn verify_server_cert(
452 &self,
453 _end_entity: &CertificateDer<'_>,
454 _intermediates: &[CertificateDer<'_>],
455 _server_name: &ServerName<'_>,
456 _ocsp: &[u8],
457 _now: UnixTime,
458 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
459 Ok(rustls::client::danger::ServerCertVerified::assertion())
460 }
461
462 fn verify_tls12_signature(
463 &self,
464 message: &[u8],
465 cert: &CertificateDer<'_>,
466 dss: &DigitallySignedStruct,
467 ) -> Result<HandshakeSignatureValid, rustls::Error> {
468 verify_tls12_signature(
469 message,
470 cert,
471 dss,
472 &self.0.signature_verification_algorithms,
473 )
474 }
475
476 fn verify_tls13_signature(
477 &self,
478 message: &[u8],
479 cert: &CertificateDer<'_>,
480 dss: &DigitallySignedStruct,
481 ) -> Result<HandshakeSignatureValid, rustls::Error> {
482 verify_tls13_signature(
483 message,
484 cert,
485 dss,
486 &self.0.signature_verification_algorithms,
487 )
488 }
489
490 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
491 self.0.signature_verification_algorithms.supported_schemes()
492 }
493 }
494}