#[cfg(not(target_os = "windows"))]
use std::borrow::Cow;
#[cfg(target_os = "windows")]
use std::time::Duration;
use std::{
    collections::HashMap,
    io,
    sync::Arc,
};
#[cfg(feature = "hyper")]
use std::{
    convert::Infallible,
    fmt,
};

use async_trait::async_trait;
use bitflags::bitflags;
use futures::stream::TryStreamExt;
use futures_rustls::rustls::{
    self,
    pki_types,
    ClientConfig,
};
#[cfg(feature = "hyper")]
use hyper::{
    server::conn::http1,
    service::service_fn,
    Response,
    StatusCode,
};
use once_cell::sync::Lazy;
use proxy_protocol::ProxyHeader;
// Named pipes back the Windows-only `pipe` forwarding scheme, which must build
// regardless of whether the `hyper` feature is enabled.
#[cfg(target_os = "windows")]
use tokio::net::windows::named_pipe::ClientOptions;
#[cfg(not(target_os = "windows"))]
use tokio::net::UnixStream;
#[cfg(target_os = "windows")]
use tokio::time;
use tokio::{
    io::copy_bidirectional,
    net::TcpStream,
    task::JoinHandle,
};
use tokio_util::compat::{
    FuturesAsyncReadCompatExt,
    TokioAsyncReadCompatExt,
};
#[cfg(feature = "hyper")]
use tracing::debug;
use tracing::{
    field,
    warn,
    Instrument,
    Span,
};
use url::Url;
#[cfg(target_os = "windows")]
use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;

use crate::{
    prelude::*,
    proxy_proto,
    session::IoStream,
    EdgeConn,
    EndpointConn,
};
68
69#[allow(deprecated)]
70#[async_trait]
71impl<T> TunnelExt for T
72where
73 T: Tunnel + Send,
74 <T as Tunnel>::Conn: ConnExt,
75{
76 async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
77 forward_tunnel(self, url).await
78 }
79}
80
81#[async_trait]
83#[deprecated = "superceded by the `listen_and_forward` builder method"]
84pub trait TunnelExt: Tunnel + Send {
85 async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
100}
101
/// Forwarding for a single connection accepted from a tunnel.
pub(crate) trait ConnExt {
    /// Spawn a task that dials `url` and then copies data in both directions
    /// between this connection and the upstream.
    ///
    /// The returned handle resolves to the task's result. Callers may drop it,
    /// in which case the task keeps running detached.
    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
}
105
106#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
107pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
108where
109 T: Tunnel + 'static + ?Sized,
110 <T as Tunnel>::Conn: ConnExt,
111{
112 loop {
113 let tunnel_conn = if let Some(conn) = tun
114 .try_next()
115 .await
116 .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
117 {
118 conn
119 } else {
120 return Ok(());
121 };
122
123 tunnel_conn.forward_to(&url);
124 }
125}
126
127impl ConnExt for EdgeConn {
128 fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
129 let url = url.clone();
130 tokio::spawn(async move {
131 let mut upstream = match connect(
132 self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
133 self.inner.info.verify_upstream_tls,
134 self.inner.info.app_protocol.clone(),
135 None, &url,
137 )
138 .await
139 {
140 Ok(conn) => conn,
141 Err(error) => {
142 #[cfg(feature = "hyper")]
143 if self.edge_type() == EdgeType::Https {
144 serve_gateway_error(format!("{error}"), self);
145 }
146 warn!(%error, "error connecting to upstream");
147 return Err(error);
148 }
149 };
150
151 copy_bidirectional(&mut self, &mut upstream).await?;
152 Ok(())
153 })
154 }
155}
156
157impl ConnExt for EndpointConn {
158 fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
159 let url = url.clone();
160 tokio::spawn(async move {
161 let proxy_proto = self.inner.info.proxy_proto;
162 let proto_tls = self.proto() == "tls";
163 #[cfg(feature = "hyper")]
164 let proto_http = matches!(self.proto(), "http" | "https");
165 let passthrough_tls = self.inner.info.passthrough_tls();
166 let app_protocol = self.inner.info.app_protocol.clone();
167 let verify_upstream_tls = self.inner.info.verify_upstream_tls;
168
169 let (mut stream, proxy_header) = match proxy_proto {
170 ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
171 _ => {
172 let mut stream = crate::proxy_proto::Stream::incoming(self);
173 let header = stream
174 .proxy_header()
175 .await?
176 .map_err(|e| {
177 io::Error::new(
178 io::ErrorKind::InvalidData,
179 format!("invalid proxy-protocol header: {}", e),
180 )
181 })?
182 .cloned();
183 (stream, header)
184 }
185 };
186
187 let mut upstream = match connect(
188 proto_tls && passthrough_tls,
189 verify_upstream_tls,
190 app_protocol,
191 proxy_header,
192 &url,
193 )
194 .await
195 {
196 Ok(conn) => conn,
197 Err(error) => {
198 #[cfg(feature = "hyper")]
199 if proto_http {
200 serve_gateway_error(format!("{error}"), stream);
201 }
202 warn!(%error, "error connecting to upstream");
203 return Err(error);
204 }
205 };
206
207 copy_bidirectional(&mut stream, &mut upstream).await?;
208 Ok(())
209 })
210 }
211}
212
213bitflags! {
214 struct TlsFlags: u8 {
215 const FLAG_HTTP2 = 0b01;
216 const FLAG_verify_upstream_tls = 0b10;
217 const FLAG_MAX = Self::FLAG_HTTP2.bits()
218 | Self::FLAG_verify_upstream_tls.bits();
219 }
220}
221
222fn tls_config(
223 app_protocol: Option<String>,
224 verify_upstream_tls: bool,
225) -> Result<Arc<ClientConfig>, &'static io::Error> {
226 #[allow(clippy::type_complexity)]
233 static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
234 Lazy::new(|| {
235 std::ops::Range {
236 start: 0,
237 end: TlsFlags::FLAG_MAX.bits() + 1,
238 }
239 .map(|p| {
240 let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
241 let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
242 let mut config = crate::session::host_certs_tls_config()?;
243 if !verify_upstream_tls {
244 config.dangerous().set_certificate_verifier(Arc::new(
245 danger::NoCertificateVerification::new(
246 rustls::crypto::aws_lc_rs::default_provider(),
247 ),
248 ));
249 }
250
251 if http2 {
252 config
253 .alpn_protocols
254 .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
255 }
256 Ok((p, Arc::new(config)))
257 })
258 .collect()
259 });
260
261 let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
262 let mut key = 0;
263 if Some("http2").eq(&app_protocol.as_deref()) {
264 key |= TlsFlags::FLAG_HTTP2.bits();
265 }
266 if verify_upstream_tls {
267 key |= TlsFlags::FLAG_verify_upstream_tls.bits();
268 }
269
270 Ok(configs
271 .get(&key)
272 .or_else(|| configs.get(&0))
273 .unwrap()
274 .clone())
275}
276
277async fn connect(
282 tunnel_tls: bool,
283 verify_upstream_tls: bool,
284 app_protocol: Option<String>,
285 proxy_proto_header: Option<ProxyHeader>,
286 url: &Url,
287) -> Result<Box<dyn IoStream>, io::Error> {
288 let host = url.host_str().unwrap_or("localhost");
289 let mut backend_tls: bool = false;
290 let mut conn: Box<dyn IoStream> = match url.scheme() {
291 "tcp" => {
292 let port = url.port().ok_or_else(|| {
293 io::Error::new(
294 io::ErrorKind::InvalidInput,
295 format!("missing port for tcp forwarding url {url}"),
296 )
297 })?;
298 let conn = connect_tcp(host, port).in_current_span().await?;
299 Box::new(conn)
300 }
301
302 "http" => {
303 let port = url.port().unwrap_or(80);
304 let conn = connect_tcp(host, port).in_current_span().await?;
305 Box::new(conn)
306 }
307
308 "https" | "tls" => {
309 let port = url.port().unwrap_or(443);
310 let conn = connect_tcp(host, port).in_current_span().await?;
311
312 backend_tls = true;
313 Box::new(conn)
314 }
315
316 #[cfg(not(target_os = "windows"))]
317 "unix" => {
318 let mut addr = Cow::Borrowed(url.path());
320 if let Some(host) = url.host_str() {
321 addr = Cow::Owned(format!("{host}{addr}"));
324 }
325 Box::new(UnixStream::connect(&*addr).await?)
326 }
327
328 #[cfg(target_os = "windows")]
329 "pipe" => {
330 let mut pipe_name = url.path();
331 if url.host_str().is_some() {
332 pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
333 }
334 if pipe_name.is_empty() {
335 return Err(io::Error::new(
336 io::ErrorKind::InvalidInput,
337 format!("missing pipe name in forwarding url {url}"),
338 ));
339 }
340 let host = url
341 .host_str()
342 .map(|h| if h == "localhost" { "." } else { h })
344 .unwrap_or(".");
345 let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
347 let local_conn = loop {
350 match ClientOptions::new().open(&addr) {
351 Ok(client) => break client,
352 Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
353 Err(error) => return Err(error),
354 }
355
356 time::sleep(Duration::from_millis(50)).await;
357 };
358 Box::new(local_conn)
359 }
360 _ => {
361 return Err(io::Error::new(
362 io::ErrorKind::InvalidInput,
363 format!("unrecognized scheme in forwarding url: {url}"),
364 ))
365 }
366 };
367
368 if let Some(header) = proxy_proto_header {
370 conn = Box::new(
371 proxy_proto::Stream::outgoing(conn, header)
372 .expect("re-serializing proxy header should always succeed"),
373 )
374 };
375
376 if backend_tls && !tunnel_tls {
377 let domain = pki_types::ServerName::try_from(host)
378 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
379 .to_owned();
380 conn = Box::new(
381 futures_rustls::TlsConnector::from(
382 tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
383 )
384 .connect(domain, conn.compat())
385 .await?
386 .compat(),
387 )
388 }
389
390 Ok(conn)
393}
394
395async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
396 let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
397 if let Ok(addr) = conn.peer_addr() {
398 Span::current().record("forward_addr", field::display(addr));
399 }
400 Ok(conn)
401}
402
/// Answer the next HTTP/1 request on `conn` with a 502 Bad Gateway response
/// that embeds `err`, then close the connection.
///
/// The work runs on a spawned task that is attached to the caller's current
/// span. The returned handle completes once the connection has been served.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let bad_gateway = service_fn(move |_req| {
        debug!("serving bad gateway error");
        let body = format!("failed to dial backend: {err}");
        let mut response = Response::new(body);
        *response.status_mut() = StatusCode::BAD_GATEWAY;
        futures::future::ok::<_, Infallible>(response)
    });

    let serve = async move {
        // Keep-alive is disabled so the connection closes after one response.
        let res = http1::Builder::new()
            .keep_alive(false)
            .serve_connection(conn, bad_gateway)
            .await;
        debug!(?res, "connection closed");
    };

    tokio::spawn(serve.in_current_span())
}
426
427mod danger {
429 use futures_rustls::rustls;
430 use rustls::{
431 client::danger::HandshakeSignatureValid,
432 crypto::{
433 verify_tls12_signature,
434 verify_tls13_signature,
435 CryptoProvider,
436 },
437 DigitallySignedStruct,
438 };
439
440 use super::pki_types::{
441 CertificateDer,
442 ServerName,
443 UnixTime,
444 };
445
446 #[derive(Debug)]
447 pub struct NoCertificateVerification(CryptoProvider);
448
449 impl NoCertificateVerification {
450 pub fn new(provider: CryptoProvider) -> Self {
451 Self(provider)
452 }
453 }
454
455 impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
456 fn verify_server_cert(
457 &self,
458 _end_entity: &CertificateDer<'_>,
459 _intermediates: &[CertificateDer<'_>],
460 _server_name: &ServerName<'_>,
461 _ocsp: &[u8],
462 _now: UnixTime,
463 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
464 Ok(rustls::client::danger::ServerCertVerified::assertion())
465 }
466
467 fn verify_tls12_signature(
468 &self,
469 message: &[u8],
470 cert: &CertificateDer<'_>,
471 dss: &DigitallySignedStruct,
472 ) -> Result<HandshakeSignatureValid, rustls::Error> {
473 verify_tls12_signature(
474 message,
475 cert,
476 dss,
477 &self.0.signature_verification_algorithms,
478 )
479 }
480
481 fn verify_tls13_signature(
482 &self,
483 message: &[u8],
484 cert: &CertificateDer<'_>,
485 dss: &DigitallySignedStruct,
486 ) -> Result<HandshakeSignatureValid, rustls::Error> {
487 verify_tls13_signature(
488 message,
489 cert,
490 dss,
491 &self.0.signature_verification_algorithms,
492 )
493 }
494
495 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
496 self.0.signature_verification_algorithms.supported_schemes()
497 }
498 }
499}