ngrok/
tunnel_ext.rs

1use std::{
2    collections::HashMap,
3    io,
4    sync::Arc,
5};
6#[cfg(feature = "hyper")]
7use std::{
8    convert::Infallible,
9    fmt,
10};
11
12use async_trait::async_trait;
13use bitflags::bitflags;
14use futures::stream::TryStreamExt;
15use futures_rustls::rustls::{
16    self,
17    pki_types,
18    ClientConfig,
19};
20#[cfg(feature = "hyper")]
21use hyper::{
22    server::conn::http1,
23    service::service_fn,
24    Response,
25    StatusCode,
26};
27use once_cell::sync::Lazy;
28use proxy_protocol::ProxyHeader;
29use tokio::{
30    io::copy_bidirectional,
31    net::TcpStream,
32    task::JoinHandle,
33};
34use tokio_util::compat::{
35    FuturesAsyncReadCompatExt,
36    TokioAsyncReadCompatExt,
37};
38#[cfg(feature = "hyper")]
39use tracing::debug;
40use tracing::{
41    field,
42    warn,
43    Instrument,
44    Span,
45};
46use url::Url;
47
48use crate::{
49    prelude::*,
50    proxy_proto,
51    session::IoStream,
52    EdgeConn,
53    EndpointConn,
54};
55
56#[allow(deprecated)]
57#[async_trait]
58impl<T> TunnelExt for T
59where
60    T: Tunnel + Send,
61    <T as Tunnel>::Conn: ConnExt,
62{
63    async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
64        forward_tunnel(self, url).await
65    }
66}
67
68/// Extension methods auto-implemented for all tunnel types
69#[async_trait]
70#[deprecated = "superceded by the `listen_and_forward` builder method"]
71pub trait TunnelExt: Tunnel + Send {
72    /// Forward incoming tunnel connections to the provided url based on its
73    /// scheme.
74    /// This currently supports http, https, tls, and tcp on all platforms, unix
75    /// sockets on unix platforms, and named pipes on Windows via the "pipe"
76    /// scheme.
77    ///
78    /// Unix socket URLs can be formatted as `unix://path/to/socket` or
79    /// `unix:path/to/socket` for relative paths or as `unix:///path/to/socket` or
80    /// `unix:/path/to/socket` for absolute paths.
81    ///
82    /// Windows named pipe URLs can be formatted as `pipe:mypipename` or
83    /// `pipe://host/mypipename`. If no host is provided, as with
84    /// `pipe:///mypipename` or `pipe:/mypipename`, the leading slash will be
85    /// preserved.
86    async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
87}
88
89pub(crate) trait ConnExt {
90    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
91}
92
93#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
94pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
95where
96    T: Tunnel + 'static + ?Sized,
97    <T as Tunnel>::Conn: ConnExt,
98{
99    loop {
100        let tunnel_conn = if let Some(conn) = tun
101            .try_next()
102            .await
103            .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
104        {
105            conn
106        } else {
107            return Ok(());
108        };
109
110        tunnel_conn.forward_to(&url);
111    }
112}
113
114impl ConnExt for EdgeConn {
115    fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
116        let url = url.clone();
117        tokio::spawn(async move {
118            let mut upstream = match connect(
119                self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
120                self.inner.info.verify_upstream_tls,
121                self.inner.info.app_protocol.clone(),
122                None, // Edges don't support proxyproto (afaik)
123                &url,
124            )
125            .await
126            {
127                Ok(conn) => conn,
128                Err(error) => {
129                    #[cfg(feature = "hyper")]
130                    if self.edge_type() == EdgeType::Https {
131                        serve_gateway_error(format!("{error}"), self);
132                    }
133                    warn!(%error, "error connecting to upstream");
134                    return Err(error);
135                }
136            };
137
138            copy_bidirectional(&mut self, &mut upstream).await?;
139            Ok(())
140        })
141    }
142}
143
144impl ConnExt for EndpointConn {
145    fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
146        let url = url.clone();
147        tokio::spawn(async move {
148            let proxy_proto = self.inner.info.proxy_proto;
149            let proto_tls = self.proto() == "tls";
150            #[cfg(feature = "hyper")]
151            let proto_http = matches!(self.proto(), "http" | "https");
152            let passthrough_tls = self.inner.info.passthrough_tls();
153            let app_protocol = self.inner.info.app_protocol.clone();
154            let verify_upstream_tls = self.inner.info.verify_upstream_tls;
155
156            let (mut stream, proxy_header) = match proxy_proto {
157                ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
158                _ => {
159                    let mut stream = crate::proxy_proto::Stream::incoming(self);
160                    let header = stream
161                        .proxy_header()
162                        .await?
163                        .map_err(|e| {
164                            io::Error::new(
165                                io::ErrorKind::InvalidData,
166                                format!("invalid proxy-protocol header: {}", e),
167                            )
168                        })?
169                        .cloned();
170                    (stream, header)
171                }
172            };
173
174            let mut upstream = match connect(
175                proto_tls && passthrough_tls,
176                verify_upstream_tls,
177                app_protocol,
178                proxy_header,
179                &url,
180            )
181            .await
182            {
183                Ok(conn) => conn,
184                Err(error) => {
185                    #[cfg(feature = "hyper")]
186                    if proto_http {
187                        serve_gateway_error(format!("{error}"), stream);
188                    }
189                    warn!(%error, "error connecting to upstream");
190                    return Err(error);
191                }
192            };
193
194            copy_bidirectional(&mut stream, &mut upstream).await?;
195            Ok(())
196        })
197    }
198}
199
200bitflags! {
201    struct TlsFlags: u8 {
202        const FLAG_HTTP2       = 0b01;
203        const FLAG_verify_upstream_tls       = 0b10;
204        const FLAG_MAX     = Self::FLAG_HTTP2.bits()
205                           | Self::FLAG_verify_upstream_tls.bits();
206    }
207}
208
209static NO_CRYPTO_PROVIDER_ERROR: Lazy<io::Error> = Lazy::new(|| {
210    io::Error::new(
211        io::ErrorKind::NotFound,
212        "no default CryptoProvider installed",
213    )
214});
215
216fn tls_config(
217    app_protocol: Option<String>,
218    verify_upstream_tls: bool,
219) -> Result<Arc<ClientConfig>, &'static io::Error> {
220    // A hashmap of tls client configs for different configurations.
221    // There won't need to be a lot of variation among these, and we'll want to
222    // reuse them as much as we can, which is why we initialize them all once
223    // and then pull out the one we need.
224    // Disabling the lint because this is a local static that doesn't escape the
225    // enclosing context. It fine.
226    #[allow(clippy::type_complexity)]
227    static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
228        Lazy::new(|| {
229            std::ops::Range {
230                start: 0,
231                end: TlsFlags::FLAG_MAX.bits() + 1,
232            }
233            .map(|p| {
234                let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
235                let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
236                let mut config = crate::session::host_certs_tls_config()?;
237                if !verify_upstream_tls {
238                    let provider = rustls::crypto::CryptoProvider::get_default()
239                        .ok_or(&*NO_CRYPTO_PROVIDER_ERROR)?
240                        .as_ref()
241                        .clone();
242                    config.dangerous().set_certificate_verifier(Arc::new(
243                        danger::NoCertificateVerification::new(provider),
244                    ));
245                }
246
247                if http2 {
248                    config
249                        .alpn_protocols
250                        .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
251                }
252                Ok((p, Arc::new(config)))
253            })
254            .collect()
255        });
256
257    let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
258    let mut key = 0;
259    if Some("http2").eq(&app_protocol.as_deref()) {
260        key |= TlsFlags::FLAG_HTTP2.bits();
261    }
262    if verify_upstream_tls {
263        key |= TlsFlags::FLAG_verify_upstream_tls.bits();
264    }
265
266    Ok(configs
267        .get(&key)
268        .or_else(|| configs.get(&0))
269        .unwrap()
270        .clone())
271}
272
273// Establish the connection to forward the tunnel stream to.
274// Takes the tunnel and connection to make additional decisions on how to wrap
275// the forwarded connection, i.e. reordering tls termination and proxyproto.
276// Note: this additional wrapping logic currently unimplemented.
277async fn connect(
278    tunnel_tls: bool,
279    verify_upstream_tls: bool,
280    app_protocol: Option<String>,
281    proxy_proto_header: Option<ProxyHeader>,
282    url: &Url,
283) -> Result<Box<dyn IoStream>, io::Error> {
284    let host = url.host_str().unwrap_or("localhost");
285    let mut backend_tls: bool = false;
286    let mut conn: Box<dyn IoStream> = match url.scheme() {
287        "tcp" => {
288            let port = url.port().ok_or_else(|| {
289                io::Error::new(
290                    io::ErrorKind::InvalidInput,
291                    format!("missing port for tcp forwarding url {url}"),
292                )
293            })?;
294            let conn = connect_tcp(host, port).in_current_span().await?;
295            Box::new(conn)
296        }
297
298        "http" => {
299            let port = url.port().unwrap_or(80);
300            let conn = connect_tcp(host, port).in_current_span().await?;
301            Box::new(conn)
302        }
303
304        "https" | "tls" => {
305            let port = url.port().unwrap_or(443);
306            let conn = connect_tcp(host, port).in_current_span().await?;
307
308            backend_tls = true;
309            Box::new(conn)
310        }
311
312        #[cfg(not(target_os = "windows"))]
313        "unix" => {
314            use std::borrow::Cow;
315
316            use tokio::net::UnixStream;
317
318            let mut addr = Cow::Borrowed(url.path());
319            if let Some(host) = url.host_str() {
320                // note: if host exists, there should always be a leading / in
321                // the path, but we should consider it a relative path.
322                addr = Cow::Owned(format!("{host}{addr}"));
323            }
324            Box::new(UnixStream::connect(&*addr).await?)
325        }
326
327        #[cfg(target_os = "windows")]
328        "pipe" => {
329            use std::time::Duration;
330
331            use tokio::net::windows::named_pipe::ClientOptions;
332            use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;
333
334            let mut pipe_name = url.path();
335            if url.host_str().is_some() {
336                pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
337            }
338            if pipe_name.is_empty() {
339                return Err(io::Error::new(
340                    io::ErrorKind::InvalidInput,
341                    format!("missing pipe name in forwarding url {url}"),
342                ));
343            }
344            let host = url
345                .host_str()
346                // Consider localhost to mean "." for the pipe name
347                .map(|h| if h == "localhost" { "." } else { h })
348                .unwrap_or(".");
349            // Finally, assemble the full name.
350            let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
351            // loop behavior copied from docs
352            // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
353            let local_conn = loop {
354                match ClientOptions::new().open(&addr) {
355                    Ok(client) => break client,
356                    Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
357                    Err(error) => return Err(error),
358                }
359
360                tokio::time::sleep(Duration::from_millis(50)).await;
361            };
362            Box::new(local_conn)
363        }
364        _ => {
365            return Err(io::Error::new(
366                io::ErrorKind::InvalidInput,
367                format!("unrecognized scheme in forwarding url: {url}"),
368            ))
369        }
370    };
371
372    // We have to write the proxy header _before_ tls termination
373    if let Some(header) = proxy_proto_header {
374        conn = Box::new(
375            proxy_proto::Stream::outgoing(conn, header)
376                .expect("re-serializing proxy header should always succeed"),
377        )
378    };
379
380    if backend_tls && !tunnel_tls {
381        let domain = pki_types::ServerName::try_from(host)
382            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
383            .to_owned();
384        conn = Box::new(
385            futures_rustls::TlsConnector::from(
386                tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
387            )
388            .connect(domain, conn.compat())
389            .await?
390            .compat(),
391        )
392    }
393
394    // TODO: header rewrites?
395
396    Ok(conn)
397}
398
399async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
400    let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
401    if let Ok(addr) = conn.peer_addr() {
402        Span::current().record("forward_addr", field::display(addr));
403    }
404    Ok(conn)
405}
406
/// Answer requests arriving on `conn` with a `502 Bad Gateway` whose body
/// carries `err`, on a freshly spawned task.
///
/// Keep-alive is disabled on the HTTP/1 connection. The outcome of serving
/// the connection is only logged at debug level.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let respond = async move {
        // The request itself is irrelevant; every request gets the same reply.
        let bad_gateway = service_fn(move |_req| {
            debug!("serving bad gateway error");
            let mut reply = Response::new(format!("failed to dial backend: {err}"));
            *reply.status_mut() = StatusCode::BAD_GATEWAY;
            futures::future::ok::<_, Infallible>(reply)
        });

        let mut server = http1::Builder::new();
        server.keep_alive(false);
        let res = server.serve_connection(conn, bad_gateway).await;
        debug!(?res, "connection closed");
    };

    tokio::spawn(respond.in_current_span())
}
430
431// https://github.com/rustls/rustls/blob/main/examples/src/bin/tlsclient-mio.rs#L334
432mod danger {
433    use futures_rustls::rustls;
434    use rustls::{
435        client::danger::HandshakeSignatureValid,
436        crypto::{
437            verify_tls12_signature,
438            verify_tls13_signature,
439            CryptoProvider,
440        },
441        DigitallySignedStruct,
442    };
443
444    use super::pki_types::{
445        CertificateDer,
446        ServerName,
447        UnixTime,
448    };
449
450    #[derive(Debug)]
451    pub struct NoCertificateVerification(CryptoProvider);
452
453    impl NoCertificateVerification {
454        pub fn new(provider: CryptoProvider) -> Self {
455            Self(provider)
456        }
457    }
458
459    impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
460        fn verify_server_cert(
461            &self,
462            _end_entity: &CertificateDer<'_>,
463            _intermediates: &[CertificateDer<'_>],
464            _server_name: &ServerName<'_>,
465            _ocsp: &[u8],
466            _now: UnixTime,
467        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
468            Ok(rustls::client::danger::ServerCertVerified::assertion())
469        }
470
471        fn verify_tls12_signature(
472            &self,
473            message: &[u8],
474            cert: &CertificateDer<'_>,
475            dss: &DigitallySignedStruct,
476        ) -> Result<HandshakeSignatureValid, rustls::Error> {
477            verify_tls12_signature(
478                message,
479                cert,
480                dss,
481                &self.0.signature_verification_algorithms,
482            )
483        }
484
485        fn verify_tls13_signature(
486            &self,
487            message: &[u8],
488            cert: &CertificateDer<'_>,
489            dss: &DigitallySignedStruct,
490        ) -> Result<HandshakeSignatureValid, rustls::Error> {
491            verify_tls13_signature(
492                message,
493                cert,
494                dss,
495                &self.0.signature_verification_algorithms,
496            )
497        }
498
499        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
500            self.0.signature_verification_algorithms.supported_schemes()
501        }
502    }
503}