1use std::{
2 collections::HashMap,
3 io,
4 sync::Arc,
5};
6#[cfg(feature = "hyper")]
7use std::{
8 convert::Infallible,
9 fmt,
10};
11
12use async_trait::async_trait;
13use bitflags::bitflags;
14use futures::stream::TryStreamExt;
15use futures_rustls::rustls::{
16 self,
17 pki_types,
18 ClientConfig,
19};
20#[cfg(feature = "hyper")]
21use hyper::{
22 server::conn::http1,
23 service::service_fn,
24 Response,
25 StatusCode,
26};
27use once_cell::sync::Lazy;
28use proxy_protocol::ProxyHeader;
29use tokio::{
30 io::copy_bidirectional,
31 net::TcpStream,
32 task::JoinHandle,
33};
34use tokio_util::compat::{
35 FuturesAsyncReadCompatExt,
36 TokioAsyncReadCompatExt,
37};
38#[cfg(feature = "hyper")]
39use tracing::debug;
40use tracing::{
41 field,
42 warn,
43 Instrument,
44 Span,
45};
46use url::Url;
47
48use crate::{
49 prelude::*,
50 proxy_proto,
51 session::IoStream,
52 EdgeConn,
53 EndpointConn,
54};
55
56#[allow(deprecated)]
57#[async_trait]
58impl<T> TunnelExt for T
59where
60 T: Tunnel + Send,
61 <T as Tunnel>::Conn: ConnExt,
62{
63 async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
64 forward_tunnel(self, url).await
65 }
66}
67
68#[async_trait]
70#[deprecated = "superceded by the `listen_and_forward` builder method"]
71pub trait TunnelExt: Tunnel + Send {
72 async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
87}
88
89pub(crate) trait ConnExt {
90 fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
91}
92
93#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
94pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
95where
96 T: Tunnel + 'static + ?Sized,
97 <T as Tunnel>::Conn: ConnExt,
98{
99 loop {
100 let tunnel_conn = if let Some(conn) = tun
101 .try_next()
102 .await
103 .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
104 {
105 conn
106 } else {
107 return Ok(());
108 };
109
110 tunnel_conn.forward_to(&url);
111 }
112}
113
114impl ConnExt for EdgeConn {
115 fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
116 let url = url.clone();
117 tokio::spawn(async move {
118 let mut upstream = match connect(
119 self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
120 self.inner.info.verify_upstream_tls,
121 self.inner.info.app_protocol.clone(),
122 None, &url,
124 )
125 .await
126 {
127 Ok(conn) => conn,
128 Err(error) => {
129 #[cfg(feature = "hyper")]
130 if self.edge_type() == EdgeType::Https {
131 serve_gateway_error(format!("{error}"), self);
132 }
133 warn!(%error, "error connecting to upstream");
134 return Err(error);
135 }
136 };
137
138 copy_bidirectional(&mut self, &mut upstream).await?;
139 Ok(())
140 })
141 }
142}
143
144impl ConnExt for EndpointConn {
145 fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
146 let url = url.clone();
147 tokio::spawn(async move {
148 let proxy_proto = self.inner.info.proxy_proto;
149 let proto_tls = self.proto() == "tls";
150 #[cfg(feature = "hyper")]
151 let proto_http = matches!(self.proto(), "http" | "https");
152 let passthrough_tls = self.inner.info.passthrough_tls();
153 let app_protocol = self.inner.info.app_protocol.clone();
154 let verify_upstream_tls = self.inner.info.verify_upstream_tls;
155
156 let (mut stream, proxy_header) = match proxy_proto {
157 ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
158 _ => {
159 let mut stream = crate::proxy_proto::Stream::incoming(self);
160 let header = stream
161 .proxy_header()
162 .await?
163 .map_err(|e| {
164 io::Error::new(
165 io::ErrorKind::InvalidData,
166 format!("invalid proxy-protocol header: {}", e),
167 )
168 })?
169 .cloned();
170 (stream, header)
171 }
172 };
173
174 let mut upstream = match connect(
175 proto_tls && passthrough_tls,
176 verify_upstream_tls,
177 app_protocol,
178 proxy_header,
179 &url,
180 )
181 .await
182 {
183 Ok(conn) => conn,
184 Err(error) => {
185 #[cfg(feature = "hyper")]
186 if proto_http {
187 serve_gateway_error(format!("{error}"), stream);
188 }
189 warn!(%error, "error connecting to upstream");
190 return Err(error);
191 }
192 };
193
194 copy_bidirectional(&mut stream, &mut upstream).await?;
195 Ok(())
196 })
197 }
198}
199
200bitflags! {
201 struct TlsFlags: u8 {
202 const FLAG_HTTP2 = 0b01;
203 const FLAG_verify_upstream_tls = 0b10;
204 const FLAG_MAX = Self::FLAG_HTTP2.bits()
205 | Self::FLAG_verify_upstream_tls.bits();
206 }
207}
208
209static NO_CRYPTO_PROVIDER_ERROR: Lazy<io::Error> = Lazy::new(|| {
210 io::Error::new(
211 io::ErrorKind::NotFound,
212 "no default CryptoProvider installed",
213 )
214});
215
216fn tls_config(
217 app_protocol: Option<String>,
218 verify_upstream_tls: bool,
219) -> Result<Arc<ClientConfig>, &'static io::Error> {
220 #[allow(clippy::type_complexity)]
227 static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
228 Lazy::new(|| {
229 std::ops::Range {
230 start: 0,
231 end: TlsFlags::FLAG_MAX.bits() + 1,
232 }
233 .map(|p| {
234 let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
235 let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
236 let mut config = crate::session::host_certs_tls_config()?;
237 if !verify_upstream_tls {
238 let provider = rustls::crypto::CryptoProvider::get_default()
239 .ok_or(&*NO_CRYPTO_PROVIDER_ERROR)?
240 .as_ref()
241 .clone();
242 config.dangerous().set_certificate_verifier(Arc::new(
243 danger::NoCertificateVerification::new(provider),
244 ));
245 }
246
247 if http2 {
248 config
249 .alpn_protocols
250 .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
251 }
252 Ok((p, Arc::new(config)))
253 })
254 .collect()
255 });
256
257 let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
258 let mut key = 0;
259 if Some("http2").eq(&app_protocol.as_deref()) {
260 key |= TlsFlags::FLAG_HTTP2.bits();
261 }
262 if verify_upstream_tls {
263 key |= TlsFlags::FLAG_verify_upstream_tls.bits();
264 }
265
266 Ok(configs
267 .get(&key)
268 .or_else(|| configs.get(&0))
269 .unwrap()
270 .clone())
271}
272
273async fn connect(
278 tunnel_tls: bool,
279 verify_upstream_tls: bool,
280 app_protocol: Option<String>,
281 proxy_proto_header: Option<ProxyHeader>,
282 url: &Url,
283) -> Result<Box<dyn IoStream>, io::Error> {
284 let host = url.host_str().unwrap_or("localhost");
285 let mut backend_tls: bool = false;
286 let mut conn: Box<dyn IoStream> = match url.scheme() {
287 "tcp" => {
288 let port = url.port().ok_or_else(|| {
289 io::Error::new(
290 io::ErrorKind::InvalidInput,
291 format!("missing port for tcp forwarding url {url}"),
292 )
293 })?;
294 let conn = connect_tcp(host, port).in_current_span().await?;
295 Box::new(conn)
296 }
297
298 "http" => {
299 let port = url.port().unwrap_or(80);
300 let conn = connect_tcp(host, port).in_current_span().await?;
301 Box::new(conn)
302 }
303
304 "https" | "tls" => {
305 let port = url.port().unwrap_or(443);
306 let conn = connect_tcp(host, port).in_current_span().await?;
307
308 backend_tls = true;
309 Box::new(conn)
310 }
311
312 #[cfg(not(target_os = "windows"))]
313 "unix" => {
314 use std::borrow::Cow;
315
316 use tokio::net::UnixStream;
317
318 let mut addr = Cow::Borrowed(url.path());
319 if let Some(host) = url.host_str() {
320 addr = Cow::Owned(format!("{host}{addr}"));
323 }
324 Box::new(UnixStream::connect(&*addr).await?)
325 }
326
327 #[cfg(target_os = "windows")]
328 "pipe" => {
329 use std::time::Duration;
330
331 use tokio::net::windows::named_pipe::ClientOptions;
332 use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;
333
334 let mut pipe_name = url.path();
335 if url.host_str().is_some() {
336 pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
337 }
338 if pipe_name.is_empty() {
339 return Err(io::Error::new(
340 io::ErrorKind::InvalidInput,
341 format!("missing pipe name in forwarding url {url}"),
342 ));
343 }
344 let host = url
345 .host_str()
346 .map(|h| if h == "localhost" { "." } else { h })
348 .unwrap_or(".");
349 let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
351 let local_conn = loop {
354 match ClientOptions::new().open(&addr) {
355 Ok(client) => break client,
356 Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
357 Err(error) => return Err(error),
358 }
359
360 tokio::time::sleep(Duration::from_millis(50)).await;
361 };
362 Box::new(local_conn)
363 }
364 _ => {
365 return Err(io::Error::new(
366 io::ErrorKind::InvalidInput,
367 format!("unrecognized scheme in forwarding url: {url}"),
368 ))
369 }
370 };
371
372 if let Some(header) = proxy_proto_header {
374 conn = Box::new(
375 proxy_proto::Stream::outgoing(conn, header)
376 .expect("re-serializing proxy header should always succeed"),
377 )
378 };
379
380 if backend_tls && !tunnel_tls {
381 let domain = pki_types::ServerName::try_from(host)
382 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
383 .to_owned();
384 conn = Box::new(
385 futures_rustls::TlsConnector::from(
386 tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
387 )
388 .connect(domain, conn.compat())
389 .await?
390 .compat(),
391 )
392 }
393
394 Ok(conn)
397}
398
399async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
400 let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
401 if let Ok(addr) = conn.peer_addr() {
402 Span::current().record("forward_addr", field::display(addr));
403 }
404 Ok(conn)
405}
406
/// Answers HTTP/1.1 requests on `conn` with `502 Bad Gateway`.
///
/// The response body is `failed to dial backend: {err}`, and keep-alive is
/// disabled. The work runs on a spawned task inside the caller's current span.
/// Its handle is returned and may simply be dropped.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let serve = async move {
        let bad_gateway = service_fn(move |_req| {
            debug!("serving bad gateway error");
            let body = format!("failed to dial backend: {err}");
            let mut response = Response::new(body);
            *response.status_mut() = StatusCode::BAD_GATEWAY;
            futures::future::ok::<_, Infallible>(response)
        });

        let outcome = http1::Builder::new()
            .keep_alive(false)
            .serve_connection(conn, bad_gateway)
            .await;
        debug!(res = ?outcome, "connection closed");
    };
    tokio::spawn(serve.in_current_span())
}
430
431mod danger {
433 use futures_rustls::rustls;
434 use rustls::{
435 client::danger::HandshakeSignatureValid,
436 crypto::{
437 verify_tls12_signature,
438 verify_tls13_signature,
439 CryptoProvider,
440 },
441 DigitallySignedStruct,
442 };
443
444 use super::pki_types::{
445 CertificateDer,
446 ServerName,
447 UnixTime,
448 };
449
450 #[derive(Debug)]
451 pub struct NoCertificateVerification(CryptoProvider);
452
453 impl NoCertificateVerification {
454 pub fn new(provider: CryptoProvider) -> Self {
455 Self(provider)
456 }
457 }
458
459 impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
460 fn verify_server_cert(
461 &self,
462 _end_entity: &CertificateDer<'_>,
463 _intermediates: &[CertificateDer<'_>],
464 _server_name: &ServerName<'_>,
465 _ocsp: &[u8],
466 _now: UnixTime,
467 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
468 Ok(rustls::client::danger::ServerCertVerified::assertion())
469 }
470
471 fn verify_tls12_signature(
472 &self,
473 message: &[u8],
474 cert: &CertificateDer<'_>,
475 dss: &DigitallySignedStruct,
476 ) -> Result<HandshakeSignatureValid, rustls::Error> {
477 verify_tls12_signature(
478 message,
479 cert,
480 dss,
481 &self.0.signature_verification_algorithms,
482 )
483 }
484
485 fn verify_tls13_signature(
486 &self,
487 message: &[u8],
488 cert: &CertificateDer<'_>,
489 dss: &DigitallySignedStruct,
490 ) -> Result<HandshakeSignatureValid, rustls::Error> {
491 verify_tls13_signature(
492 message,
493 cert,
494 dss,
495 &self.0.signature_verification_algorithms,
496 )
497 }
498
499 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
500 self.0.signature_verification_algorithms.supported_schemes()
501 }
502 }
503}