// ngrok/tunnel_ext.rs

#[cfg(not(target_os = "windows"))]
use std::borrow::Cow;
#[cfg(target_os = "windows")]
use std::time::Duration;
use std::{
    collections::HashMap,
    io,
    sync::Arc,
};
#[cfg(feature = "hyper")]
use std::{
    convert::Infallible,
    fmt,
};

use async_trait::async_trait;
use bitflags::bitflags;
use futures::stream::TryStreamExt;
use futures_rustls::rustls::{
    self,
    pki_types,
    ClientConfig,
};
#[cfg(feature = "hyper")]
use hyper::{
    server::conn::http1,
    service::service_fn,
    Response,
    StatusCode,
};
use once_cell::sync::Lazy;
use proxy_protocol::ProxyHeader;
// Needed for `pipe` forwarding on Windows whether or not the `hyper` feature
// is enabled, so it is gated on the target OS only.
#[cfg(target_os = "windows")]
use tokio::net::windows::named_pipe::ClientOptions;
#[cfg(not(target_os = "windows"))]
use tokio::net::UnixStream;
#[cfg(target_os = "windows")]
use tokio::time;
use tokio::{
    io::copy_bidirectional,
    net::TcpStream,
    task::JoinHandle,
};
use tokio_util::compat::{
    FuturesAsyncReadCompatExt,
    TokioAsyncReadCompatExt,
};
#[cfg(feature = "hyper")]
use tracing::debug;
use tracing::{
    field,
    warn,
    Instrument,
    Span,
};
use url::Url;
#[cfg(target_os = "windows")]
use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;

use crate::{
    prelude::*,
    proxy_proto,
    session::IoStream,
    EdgeConn,
    EndpointConn,
};
68
69#[allow(deprecated)]
70#[async_trait]
71impl<T> TunnelExt for T
72where
73    T: Tunnel + Send,
74    <T as Tunnel>::Conn: ConnExt,
75{
76    async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
77        forward_tunnel(self, url).await
78    }
79}
80
81/// Extension methods auto-implemented for all tunnel types
82#[async_trait]
83#[deprecated = "superceded by the `listen_and_forward` builder method"]
84pub trait TunnelExt: Tunnel + Send {
85    /// Forward incoming tunnel connections to the provided url based on its
86    /// scheme.
87    /// This currently supports http, https, tls, and tcp on all platforms, unix
88    /// sockets on unix platforms, and named pipes on Windows via the "pipe"
89    /// scheme.
90    ///
91    /// Unix socket URLs can be formatted as `unix://path/to/socket` or
92    /// `unix:path/to/socket` for relative paths or as `unix:///path/to/socket` or
93    /// `unix:/path/to/socket` for absolute paths.
94    ///
95    /// Windows named pipe URLs can be formatted as `pipe:mypipename` or
96    /// `pipe://host/mypipename`. If no host is provided, as with
97    /// `pipe:///mypipename` or `pipe:/mypipename`, the leading slash will be
98    /// preserved.
99    async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
100}
101
102pub(crate) trait ConnExt {
103    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
104}
105
106#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
107pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
108where
109    T: Tunnel + 'static + ?Sized,
110    <T as Tunnel>::Conn: ConnExt,
111{
112    loop {
113        let tunnel_conn = if let Some(conn) = tun
114            .try_next()
115            .await
116            .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
117        {
118            conn
119        } else {
120            return Ok(());
121        };
122
123        tunnel_conn.forward_to(&url);
124    }
125}
126
127impl ConnExt for EdgeConn {
128    fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
129        let url = url.clone();
130        tokio::spawn(async move {
131            let mut upstream = match connect(
132                self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
133                self.inner.info.verify_upstream_tls,
134                self.inner.info.app_protocol.clone(),
135                None, // Edges don't support proxyproto (afaik)
136                &url,
137            )
138            .await
139            {
140                Ok(conn) => conn,
141                Err(error) => {
142                    #[cfg(feature = "hyper")]
143                    if self.edge_type() == EdgeType::Https {
144                        serve_gateway_error(format!("{error}"), self);
145                    }
146                    warn!(%error, "error connecting to upstream");
147                    return Err(error);
148                }
149            };
150
151            copy_bidirectional(&mut self, &mut upstream).await?;
152            Ok(())
153        })
154    }
155}
156
157impl ConnExt for EndpointConn {
158    fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
159        let url = url.clone();
160        tokio::spawn(async move {
161            let proxy_proto = self.inner.info.proxy_proto;
162            let proto_tls = self.proto() == "tls";
163            #[cfg(feature = "hyper")]
164            let proto_http = matches!(self.proto(), "http" | "https");
165            let passthrough_tls = self.inner.info.passthrough_tls();
166            let app_protocol = self.inner.info.app_protocol.clone();
167            let verify_upstream_tls = self.inner.info.verify_upstream_tls;
168
169            let (mut stream, proxy_header) = match proxy_proto {
170                ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
171                _ => {
172                    let mut stream = crate::proxy_proto::Stream::incoming(self);
173                    let header = stream
174                        .proxy_header()
175                        .await?
176                        .map_err(|e| {
177                            io::Error::new(
178                                io::ErrorKind::InvalidData,
179                                format!("invalid proxy-protocol header: {}", e),
180                            )
181                        })?
182                        .cloned();
183                    (stream, header)
184                }
185            };
186
187            let mut upstream = match connect(
188                proto_tls && passthrough_tls,
189                verify_upstream_tls,
190                app_protocol,
191                proxy_header,
192                &url,
193            )
194            .await
195            {
196                Ok(conn) => conn,
197                Err(error) => {
198                    #[cfg(feature = "hyper")]
199                    if proto_http {
200                        serve_gateway_error(format!("{error}"), stream);
201                    }
202                    warn!(%error, "error connecting to upstream");
203                    return Err(error);
204                }
205            };
206
207            copy_bidirectional(&mut stream, &mut upstream).await?;
208            Ok(())
209        })
210    }
211}
212
213bitflags! {
214    struct TlsFlags: u8 {
215        const FLAG_HTTP2       = 0b01;
216        const FLAG_verify_upstream_tls       = 0b10;
217        const FLAG_MAX     = Self::FLAG_HTTP2.bits()
218                           | Self::FLAG_verify_upstream_tls.bits();
219    }
220}
221
222fn tls_config(
223    app_protocol: Option<String>,
224    verify_upstream_tls: bool,
225) -> Result<Arc<ClientConfig>, &'static io::Error> {
226    // A hashmap of tls client configs for different configurations.
227    // There won't need to be a lot of variation among these, and we'll want to
228    // reuse them as much as we can, which is why we initialize them all once
229    // and then pull out the one we need.
230    // Disabling the lint because this is a local static that doesn't escape the
231    // enclosing context. It fine.
232    #[allow(clippy::type_complexity)]
233    static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
234        Lazy::new(|| {
235            std::ops::Range {
236                start: 0,
237                end: TlsFlags::FLAG_MAX.bits() + 1,
238            }
239            .map(|p| {
240                let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
241                let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
242                let mut config = crate::session::host_certs_tls_config()?;
243                if !verify_upstream_tls {
244                    config.dangerous().set_certificate_verifier(Arc::new(
245                        danger::NoCertificateVerification::new(
246                            rustls::crypto::aws_lc_rs::default_provider(),
247                        ),
248                    ));
249                }
250
251                if http2 {
252                    config
253                        .alpn_protocols
254                        .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
255                }
256                Ok((p, Arc::new(config)))
257            })
258            .collect()
259        });
260
261    let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
262    let mut key = 0;
263    if Some("http2").eq(&app_protocol.as_deref()) {
264        key |= TlsFlags::FLAG_HTTP2.bits();
265    }
266    if verify_upstream_tls {
267        key |= TlsFlags::FLAG_verify_upstream_tls.bits();
268    }
269
270    Ok(configs
271        .get(&key)
272        .or_else(|| configs.get(&0))
273        .unwrap()
274        .clone())
275}
276
277// Establish the connection to forward the tunnel stream to.
278// Takes the tunnel and connection to make additional decisions on how to wrap
279// the forwarded connection, i.e. reordering tls termination and proxyproto.
280// Note: this additional wrapping logic currently unimplemented.
281async fn connect(
282    tunnel_tls: bool,
283    verify_upstream_tls: bool,
284    app_protocol: Option<String>,
285    proxy_proto_header: Option<ProxyHeader>,
286    url: &Url,
287) -> Result<Box<dyn IoStream>, io::Error> {
288    let host = url.host_str().unwrap_or("localhost");
289    let mut backend_tls: bool = false;
290    let mut conn: Box<dyn IoStream> = match url.scheme() {
291        "tcp" => {
292            let port = url.port().ok_or_else(|| {
293                io::Error::new(
294                    io::ErrorKind::InvalidInput,
295                    format!("missing port for tcp forwarding url {url}"),
296                )
297            })?;
298            let conn = connect_tcp(host, port).in_current_span().await?;
299            Box::new(conn)
300        }
301
302        "http" => {
303            let port = url.port().unwrap_or(80);
304            let conn = connect_tcp(host, port).in_current_span().await?;
305            Box::new(conn)
306        }
307
308        "https" | "tls" => {
309            let port = url.port().unwrap_or(443);
310            let conn = connect_tcp(host, port).in_current_span().await?;
311
312            backend_tls = true;
313            Box::new(conn)
314        }
315
316        #[cfg(not(target_os = "windows"))]
317        "unix" => {
318            //
319            let mut addr = Cow::Borrowed(url.path());
320            if let Some(host) = url.host_str() {
321                // note: if host exists, there should always be a leading / in
322                // the path, but we should consider it a relative path.
323                addr = Cow::Owned(format!("{host}{addr}"));
324            }
325            Box::new(UnixStream::connect(&*addr).await?)
326        }
327
328        #[cfg(target_os = "windows")]
329        "pipe" => {
330            let mut pipe_name = url.path();
331            if url.host_str().is_some() {
332                pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
333            }
334            if pipe_name.is_empty() {
335                return Err(io::Error::new(
336                    io::ErrorKind::InvalidInput,
337                    format!("missing pipe name in forwarding url {url}"),
338                ));
339            }
340            let host = url
341                .host_str()
342                // Consider localhost to mean "." for the pipe name
343                .map(|h| if h == "localhost" { "." } else { h })
344                .unwrap_or(".");
345            // Finally, assemble the full name.
346            let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
347            // loop behavior copied from docs
348            // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
349            let local_conn = loop {
350                match ClientOptions::new().open(&addr) {
351                    Ok(client) => break client,
352                    Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
353                    Err(error) => return Err(error),
354                }
355
356                time::sleep(Duration::from_millis(50)).await;
357            };
358            Box::new(local_conn)
359        }
360        _ => {
361            return Err(io::Error::new(
362                io::ErrorKind::InvalidInput,
363                format!("unrecognized scheme in forwarding url: {url}"),
364            ))
365        }
366    };
367
368    // We have to write the proxy header _before_ tls termination
369    if let Some(header) = proxy_proto_header {
370        conn = Box::new(
371            proxy_proto::Stream::outgoing(conn, header)
372                .expect("re-serializing proxy header should always succeed"),
373        )
374    };
375
376    if backend_tls && !tunnel_tls {
377        let domain = pki_types::ServerName::try_from(host)
378            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
379            .to_owned();
380        conn = Box::new(
381            futures_rustls::TlsConnector::from(
382                tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
383            )
384            .connect(domain, conn.compat())
385            .await?
386            .compat(),
387        )
388    }
389
390    // TODO: header rewrites?
391
392    Ok(conn)
393}
394
395async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
396    let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
397    if let Ok(addr) = conn.peer_addr() {
398        Span::current().record("forward_addr", field::display(addr));
399    }
400    Ok(conn)
401}
402
/// Answer HTTP/1 requests on `conn` with a `502 Bad Gateway` carrying `err`.
///
/// The serving runs in a spawned task attached to the caller's tracing span,
/// and the task's handle is returned. Keep-alive is disabled on the
/// connection.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let serve = async move {
        // Every request on this connection gets the same dial-failure reply.
        let bad_gateway = service_fn(move |_req| {
            debug!("serving bad gateway error");
            let body = format!("failed to dial backend: {err}");
            let mut response = Response::new(body);
            *response.status_mut() = StatusCode::BAD_GATEWAY;
            futures::future::ok::<_, Infallible>(response)
        });

        let outcome = http1::Builder::new()
            .keep_alive(false)
            .serve_connection(conn, bad_gateway)
            .await;
        debug!(res = ?outcome, "connection closed");
    };

    tokio::spawn(serve.in_current_span())
}
426
427// https://github.com/rustls/rustls/blob/main/examples/src/bin/tlsclient-mio.rs#L334
428mod danger {
429    use futures_rustls::rustls;
430    use rustls::{
431        client::danger::HandshakeSignatureValid,
432        crypto::{
433            verify_tls12_signature,
434            verify_tls13_signature,
435            CryptoProvider,
436        },
437        DigitallySignedStruct,
438    };
439
440    use super::pki_types::{
441        CertificateDer,
442        ServerName,
443        UnixTime,
444    };
445
446    #[derive(Debug)]
447    pub struct NoCertificateVerification(CryptoProvider);
448
449    impl NoCertificateVerification {
450        pub fn new(provider: CryptoProvider) -> Self {
451            Self(provider)
452        }
453    }
454
455    impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
456        fn verify_server_cert(
457            &self,
458            _end_entity: &CertificateDer<'_>,
459            _intermediates: &[CertificateDer<'_>],
460            _server_name: &ServerName<'_>,
461            _ocsp: &[u8],
462            _now: UnixTime,
463        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
464            Ok(rustls::client::danger::ServerCertVerified::assertion())
465        }
466
467        fn verify_tls12_signature(
468            &self,
469            message: &[u8],
470            cert: &CertificateDer<'_>,
471            dss: &DigitallySignedStruct,
472        ) -> Result<HandshakeSignatureValid, rustls::Error> {
473            verify_tls12_signature(
474                message,
475                cert,
476                dss,
477                &self.0.signature_verification_algorithms,
478            )
479        }
480
481        fn verify_tls13_signature(
482            &self,
483            message: &[u8],
484            cert: &CertificateDer<'_>,
485            dss: &DigitallySignedStruct,
486        ) -> Result<HandshakeSignatureValid, rustls::Error> {
487            verify_tls13_signature(
488                message,
489                cert,
490                dss,
491                &self.0.signature_verification_algorithms,
492            )
493        }
494
495        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
496            self.0.signature_verification_algorithms.supported_schemes()
497        }
498    }
499}