ngrok/
tunnel_ext.rs

1use std::{
2    collections::HashMap,
3    io,
4    sync::Arc,
5};
6#[cfg(feature = "hyper")]
7use std::{
8    convert::Infallible,
9    fmt,
10};
11
12use async_trait::async_trait;
13use bitflags::bitflags;
14use futures::stream::TryStreamExt;
15use futures_rustls::rustls::{
16    self,
17    pki_types,
18    ClientConfig,
19};
20#[cfg(feature = "hyper")]
21use hyper::{
22    server::conn::http1,
23    service::service_fn,
24    Response,
25    StatusCode,
26};
27use once_cell::sync::Lazy;
28use proxy_protocol::ProxyHeader;
29use tokio::{
30    io::copy_bidirectional,
31    net::TcpStream,
32    task::JoinHandle,
33};
34use tokio_util::compat::{
35    FuturesAsyncReadCompatExt,
36    TokioAsyncReadCompatExt,
37};
38#[cfg(feature = "hyper")]
39use tracing::debug;
40use tracing::{
41    field,
42    warn,
43    Instrument,
44    Span,
45};
46use url::Url;
47
48use crate::{
49    prelude::*,
50    proxy_proto,
51    session::IoStream,
52    EdgeConn,
53    EndpointConn,
54};
55
56#[allow(deprecated)]
57#[async_trait]
58impl<T> TunnelExt for T
59where
60    T: Tunnel + Send,
61    <T as Tunnel>::Conn: ConnExt,
62{
63    async fn forward(&mut self, url: Url) -> Result<(), io::Error> {
64        forward_tunnel(self, url).await
65    }
66}
67
68/// Extension methods auto-implemented for all tunnel types
69#[async_trait]
70#[deprecated = "superceded by the `listen_and_forward` builder method"]
71pub trait TunnelExt: Tunnel + Send {
72    /// Forward incoming tunnel connections to the provided url based on its
73    /// scheme.
74    /// This currently supports http, https, tls, and tcp on all platforms, unix
75    /// sockets on unix platforms, and named pipes on Windows via the "pipe"
76    /// scheme.
77    ///
78    /// Unix socket URLs can be formatted as `unix://path/to/socket` or
79    /// `unix:path/to/socket` for relative paths or as `unix:///path/to/socket` or
80    /// `unix:/path/to/socket` for absolute paths.
81    ///
82    /// Windows named pipe URLs can be formatted as `pipe:mypipename` or
83    /// `pipe://host/mypipename`. If no host is provided, as with
84    /// `pipe:///mypipename` or `pipe:/mypipename`, the leading slash will be
85    /// preserved.
86    async fn forward(&mut self, url: Url) -> Result<(), io::Error>;
87}
88
89pub(crate) trait ConnExt {
90    fn forward_to(self, url: &Url) -> JoinHandle<io::Result<()>>;
91}
92
93#[tracing::instrument(skip_all, fields(tunnel_id = tun.id(), url = %url))]
94pub(crate) async fn forward_tunnel<T>(tun: &mut T, url: Url) -> Result<(), io::Error>
95where
96    T: Tunnel + 'static + ?Sized,
97    <T as Tunnel>::Conn: ConnExt,
98{
99    loop {
100        let tunnel_conn = if let Some(conn) = tun
101            .try_next()
102            .await
103            .map_err(|err| io::Error::new(io::ErrorKind::NotConnected, err))?
104        {
105            conn
106        } else {
107            return Ok(());
108        };
109
110        tunnel_conn.forward_to(&url);
111    }
112}
113
114impl ConnExt for EdgeConn {
115    fn forward_to(mut self, url: &Url) -> JoinHandle<io::Result<()>> {
116        let url = url.clone();
117        tokio::spawn(async move {
118            let mut upstream = match connect(
119                self.edge_type() == EdgeType::Tls && self.passthrough_tls(),
120                self.inner.info.verify_upstream_tls,
121                self.inner.info.app_protocol.clone(),
122                None, // Edges don't support proxyproto (afaik)
123                &url,
124            )
125            .await
126            {
127                Ok(conn) => conn,
128                Err(error) => {
129                    #[cfg(feature = "hyper")]
130                    if self.edge_type() == EdgeType::Https {
131                        serve_gateway_error(format!("{error}"), self);
132                    }
133                    warn!(%error, "error connecting to upstream");
134                    return Err(error);
135                }
136            };
137
138            copy_bidirectional(&mut self, &mut upstream).await?;
139            Ok(())
140        })
141    }
142}
143
144impl ConnExt for EndpointConn {
145    fn forward_to(self, url: &Url) -> JoinHandle<Result<(), io::Error>> {
146        let url = url.clone();
147        tokio::spawn(async move {
148            let proxy_proto = self.inner.info.proxy_proto;
149            let proto_tls = self.proto() == "tls";
150            #[cfg(feature = "hyper")]
151            let proto_http = matches!(self.proto(), "http" | "https");
152            let passthrough_tls = self.inner.info.passthrough_tls();
153            let app_protocol = self.inner.info.app_protocol.clone();
154            let verify_upstream_tls = self.inner.info.verify_upstream_tls;
155
156            let (mut stream, proxy_header) = match proxy_proto {
157                ProxyProto::None => (crate::proxy_proto::Stream::disabled(self), None),
158                _ => {
159                    let mut stream = crate::proxy_proto::Stream::incoming(self);
160                    let header = stream
161                        .proxy_header()
162                        .await?
163                        .map_err(|e| {
164                            io::Error::new(
165                                io::ErrorKind::InvalidData,
166                                format!("invalid proxy-protocol header: {}", e),
167                            )
168                        })?
169                        .cloned();
170                    (stream, header)
171                }
172            };
173
174            let mut upstream = match connect(
175                proto_tls && passthrough_tls,
176                verify_upstream_tls,
177                app_protocol,
178                proxy_header,
179                &url,
180            )
181            .await
182            {
183                Ok(conn) => conn,
184                Err(error) => {
185                    #[cfg(feature = "hyper")]
186                    if proto_http {
187                        serve_gateway_error(format!("{error}"), stream);
188                    }
189                    warn!(%error, "error connecting to upstream");
190                    return Err(error);
191                }
192            };
193
194            copy_bidirectional(&mut stream, &mut upstream).await?;
195            Ok(())
196        })
197    }
198}
199
200bitflags! {
201    struct TlsFlags: u8 {
202        const FLAG_HTTP2       = 0b01;
203        const FLAG_verify_upstream_tls       = 0b10;
204        const FLAG_MAX     = Self::FLAG_HTTP2.bits()
205                           | Self::FLAG_verify_upstream_tls.bits();
206    }
207}
208
209fn tls_config(
210    app_protocol: Option<String>,
211    verify_upstream_tls: bool,
212) -> Result<Arc<ClientConfig>, &'static io::Error> {
213    // A hashmap of tls client configs for different configurations.
214    // There won't need to be a lot of variation among these, and we'll want to
215    // reuse them as much as we can, which is why we initialize them all once
216    // and then pull out the one we need.
217    // Disabling the lint because this is a local static that doesn't escape the
218    // enclosing context. It fine.
219    #[allow(clippy::type_complexity)]
220    static CONFIGS: Lazy<Result<HashMap<u8, Arc<ClientConfig>>, &'static io::Error>> =
221        Lazy::new(|| {
222            std::ops::Range {
223                start: 0,
224                end: TlsFlags::FLAG_MAX.bits() + 1,
225            }
226            .map(|p| {
227                let http2 = (p & TlsFlags::FLAG_HTTP2.bits()) != 0;
228                let verify_upstream_tls = (p & TlsFlags::FLAG_verify_upstream_tls.bits()) != 0;
229                let mut config = crate::session::host_certs_tls_config()?;
230                if !verify_upstream_tls {
231                    config.dangerous().set_certificate_verifier(Arc::new(
232                        danger::NoCertificateVerification::new(
233                            rustls::crypto::aws_lc_rs::default_provider(),
234                        ),
235                    ));
236                }
237
238                if http2 {
239                    config
240                        .alpn_protocols
241                        .extend(["h2", "http/1.1"].iter().map(|s| s.as_bytes().to_vec()));
242                }
243                Ok((p, Arc::new(config)))
244            })
245            .collect()
246        });
247
248    let configs: &HashMap<u8, Arc<ClientConfig>> = CONFIGS.as_ref().map_err(|e| *e)?;
249    let mut key = 0;
250    if Some("http2").eq(&app_protocol.as_deref()) {
251        key |= TlsFlags::FLAG_HTTP2.bits();
252    }
253    if verify_upstream_tls {
254        key |= TlsFlags::FLAG_verify_upstream_tls.bits();
255    }
256
257    Ok(configs
258        .get(&key)
259        .or_else(|| configs.get(&0))
260        .unwrap()
261        .clone())
262}
263
264// Establish the connection to forward the tunnel stream to.
265// Takes the tunnel and connection to make additional decisions on how to wrap
266// the forwarded connection, i.e. reordering tls termination and proxyproto.
267// Note: this additional wrapping logic currently unimplemented.
268async fn connect(
269    tunnel_tls: bool,
270    verify_upstream_tls: bool,
271    app_protocol: Option<String>,
272    proxy_proto_header: Option<ProxyHeader>,
273    url: &Url,
274) -> Result<Box<dyn IoStream>, io::Error> {
275    let host = url.host_str().unwrap_or("localhost");
276    let mut backend_tls: bool = false;
277    let mut conn: Box<dyn IoStream> = match url.scheme() {
278        "tcp" => {
279            let port = url.port().ok_or_else(|| {
280                io::Error::new(
281                    io::ErrorKind::InvalidInput,
282                    format!("missing port for tcp forwarding url {url}"),
283                )
284            })?;
285            let conn = connect_tcp(host, port).in_current_span().await?;
286            Box::new(conn)
287        }
288
289        "http" => {
290            let port = url.port().unwrap_or(80);
291            let conn = connect_tcp(host, port).in_current_span().await?;
292            Box::new(conn)
293        }
294
295        "https" | "tls" => {
296            let port = url.port().unwrap_or(443);
297            let conn = connect_tcp(host, port).in_current_span().await?;
298
299            backend_tls = true;
300            Box::new(conn)
301        }
302
303        #[cfg(not(target_os = "windows"))]
304        "unix" => {
305            use std::borrow::Cow;
306
307            use tokio::net::UnixStream;
308
309            let mut addr = Cow::Borrowed(url.path());
310            if let Some(host) = url.host_str() {
311                // note: if host exists, there should always be a leading / in
312                // the path, but we should consider it a relative path.
313                addr = Cow::Owned(format!("{host}{addr}"));
314            }
315            Box::new(UnixStream::connect(&*addr).await?)
316        }
317
318        #[cfg(target_os = "windows")]
319        "pipe" => {
320            use std::time::Duration;
321
322            use tokio::net::windows::named_pipe::ClientOptions;
323            use windows_sys::Win32::Foundation::ERROR_PIPE_BUSY;
324
325            let mut pipe_name = url.path();
326            if url.host_str().is_some() {
327                pipe_name = pipe_name.strip_prefix('/').unwrap_or(pipe_name);
328            }
329            if pipe_name.is_empty() {
330                return Err(io::Error::new(
331                    io::ErrorKind::InvalidInput,
332                    format!("missing pipe name in forwarding url {url}"),
333                ));
334            }
335            let host = url
336                .host_str()
337                // Consider localhost to mean "." for the pipe name
338                .map(|h| if h == "localhost" { "." } else { h })
339                .unwrap_or(".");
340            // Finally, assemble the full name.
341            let addr = format!("\\\\{host}\\pipe\\{pipe_name}");
342            // loop behavior copied from docs
343            // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
344            let local_conn = loop {
345                match ClientOptions::new().open(&addr) {
346                    Ok(client) => break client,
347                    Err(error) if error.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
348                    Err(error) => return Err(error),
349                }
350
351                tokio::time::sleep(Duration::from_millis(50)).await;
352            };
353            Box::new(local_conn)
354        }
355        _ => {
356            return Err(io::Error::new(
357                io::ErrorKind::InvalidInput,
358                format!("unrecognized scheme in forwarding url: {url}"),
359            ))
360        }
361    };
362
363    // We have to write the proxy header _before_ tls termination
364    if let Some(header) = proxy_proto_header {
365        conn = Box::new(
366            proxy_proto::Stream::outgoing(conn, header)
367                .expect("re-serializing proxy header should always succeed"),
368        )
369    };
370
371    if backend_tls && !tunnel_tls {
372        let domain = pki_types::ServerName::try_from(host)
373            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
374            .to_owned();
375        conn = Box::new(
376            futures_rustls::TlsConnector::from(
377                tls_config(app_protocol, verify_upstream_tls).map_err(|e| e.kind())?,
378            )
379            .connect(domain, conn.compat())
380            .await?
381            .compat(),
382        )
383    }
384
385    // TODO: header rewrites?
386
387    Ok(conn)
388}
389
390async fn connect_tcp(host: &str, port: u16) -> Result<TcpStream, io::Error> {
391    let conn = TcpStream::connect(&format!("{}:{}", host, port)).await?;
392    if let Ok(addr) = conn.peer_addr() {
393        Span::current().record("forward_addr", field::display(addr));
394    }
395    Ok(conn)
396}
397
/// Spawn a task that serves a single HTTP/1 response on `conn`: a
/// `502 Bad Gateway` whose body is `failed to dial backend: {err}`.
///
/// Keep-alive is disabled, and the task runs in the caller's current span.
#[cfg(feature = "hyper")]
fn serve_gateway_error(
    err: impl fmt::Display + Send + 'static,
    conn: impl hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
) -> JoinHandle<()> {
    let task = async move {
        let bad_gateway = service_fn(move |_req| {
            debug!("serving bad gateway error");
            let body = format!("failed to dial backend: {err}");
            let mut response = Response::new(body);
            *response.status_mut() = StatusCode::BAD_GATEWAY;
            futures::future::ok::<_, Infallible>(response)
        });

        let served = http1::Builder::new()
            .keep_alive(false)
            .serve_connection(conn, bad_gateway)
            .await;
        debug!(res = ?served, "connection closed");
    };
    tokio::spawn(task.in_current_span())
}
421
422// https://github.com/rustls/rustls/blob/main/examples/src/bin/tlsclient-mio.rs#L334
423mod danger {
424    use futures_rustls::rustls;
425    use rustls::{
426        client::danger::HandshakeSignatureValid,
427        crypto::{
428            verify_tls12_signature,
429            verify_tls13_signature,
430            CryptoProvider,
431        },
432        DigitallySignedStruct,
433    };
434
435    use super::pki_types::{
436        CertificateDer,
437        ServerName,
438        UnixTime,
439    };
440
441    #[derive(Debug)]
442    pub struct NoCertificateVerification(CryptoProvider);
443
444    impl NoCertificateVerification {
445        pub fn new(provider: CryptoProvider) -> Self {
446            Self(provider)
447        }
448    }
449
450    impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
451        fn verify_server_cert(
452            &self,
453            _end_entity: &CertificateDer<'_>,
454            _intermediates: &[CertificateDer<'_>],
455            _server_name: &ServerName<'_>,
456            _ocsp: &[u8],
457            _now: UnixTime,
458        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
459            Ok(rustls::client::danger::ServerCertVerified::assertion())
460        }
461
462        fn verify_tls12_signature(
463            &self,
464            message: &[u8],
465            cert: &CertificateDer<'_>,
466            dss: &DigitallySignedStruct,
467        ) -> Result<HandshakeSignatureValid, rustls::Error> {
468            verify_tls12_signature(
469                message,
470                cert,
471                dss,
472                &self.0.signature_verification_algorithms,
473            )
474        }
475
476        fn verify_tls13_signature(
477            &self,
478            message: &[u8],
479            cert: &CertificateDer<'_>,
480            dss: &DigitallySignedStruct,
481        ) -> Result<HandshakeSignatureValid, rustls::Error> {
482            verify_tls13_signature(
483                message,
484                cert,
485                dss,
486                &self.0.signature_verification_algorithms,
487            )
488        }
489
490        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
491            self.0.signature_verification_algorithms.supported_schemes()
492        }
493    }
494}