ngrok/
session.rs

1use std::{
2    collections::{
3        HashMap,
4        VecDeque,
5    },
6    env,
7    io,
8    sync::{
9        atomic::{
10            AtomicBool,
11            Ordering,
12        },
13        Arc,
14    },
15    time::Duration,
16};
17
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use bytes::Bytes;
21use futures::{
22    prelude::*,
23    FutureExt,
24};
25use futures_rustls::rustls::{
26    self,
27    pki_types,
28    RootCertStore,
29};
30use hyper_http_proxy::{
31    Intercept,
32    Proxy,
33    ProxyConnector,
34};
35use hyper_util::client::legacy::connect::HttpConnector;
36use muxado::heartbeat::HeartbeatConfig;
37pub use muxado::heartbeat::HeartbeatHandler;
38use once_cell::sync::{
39    Lazy,
40    OnceCell,
41};
42use regex::Regex;
43use rustls_pemfile::Item;
44use thiserror::Error;
45use tokio::{
46    io::{
47        AsyncRead,
48        AsyncWrite,
49    },
50    runtime::Handle,
51    sync::{
52        mpsc::{
53            channel,
54            Sender,
55        },
56        Mutex,
57        RwLock,
58    },
59};
60use tokio_retry::{
61    strategy::ExponentialBackoff,
62    RetryIf,
63};
64use tokio_util::compat::{
65    FuturesAsyncReadCompatExt,
66    TokioAsyncReadCompatExt,
67};
68use tower_service::Service;
69use tracing::{
70    debug,
71    warn,
72};
73use url::Url;
74
75pub use crate::internals::{
76    proto::{
77        CommandResp,
78        Restart,
79        Stop,
80        StopTunnel,
81        Update,
82    },
83    raw_session::{
84        CommandHandler,
85        RpcError,
86    },
87};
88use crate::{
89    config::{
90        HttpTunnelBuilder,
91        LabeledTunnelBuilder,
92        ProxyProto,
93        TcpTunnelBuilder,
94        TlsTunnelBuilder,
95        TunnelConfig,
96    },
97    conn::ConnInner,
98    internals::{
99        proto::{
100            AuthExtra,
101            BindExtra,
102            BindOpts,
103            Error,
104            HttpEndpoint,
105            SecretString,
106            TcpEndpoint,
107            TlsEndpoint,
108        },
109        raw_session::{
110            AcceptError as RawAcceptError,
111            CommandHandlers,
112            IncomingStreams,
113            RawSession,
114            RpcClient,
115            StartSessionError,
116            NOT_IMPLEMENTED,
117        },
118    },
119    tunnel::{
120        AcceptError,
121        TunnelInner,
122        TunnelInnerInfo,
123    },
124};
125
/// PEM data embedded at compile time from `../assets/ngrok.ca.crt`. Used as the
/// root certificate source when the builder has no custom CA certificate.
pub(crate) const CERT_BYTES: &[u8] = include_bytes!("../assets/ngrok.ca.crt");
/// The "base" client type that this library always reports.
const CLIENT_TYPE: &str = "ngrok-rust";
/// The version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
129
/// Per-tunnel state tracked by a session.
///
/// NOTE(review): presumably this is everything needed to re-bind a tunnel
/// after a reconnect (see `ConnectError::Rebind`) — confirm against the code
/// that populates it.
#[derive(Clone)]
struct BoundTunnel {
    // Tunnel protocol name.
    proto: String,
    // Protocol-specific bind options, if any.
    opts: Option<BindOpts>,
    // Extra bind parameters.
    extra: BindExtra,
    // Label key/value pairs; presumably only meaningful for labeled tunnels.
    labels: HashMap<String, String>,
    forwards_to: String,
    forwards_proto: String,
    verify_upstream_tls: bool,
    // Sending half of the channel carrying accepted connections (or accept
    // errors) for this tunnel.
    tx: Sender<Result<ConnInner, AcceptError>>,
}
141
/// The set of tunnels bound on a session, keyed by string.
// NOTE(review): the key is presumably the tunnel ID — confirm at the insert site.
type TunnelConns = HashMap<String, BoundTunnel>;
143
/// An ngrok session.
///
/// Encapsulates an established session with the ngrok service. Sessions recover
/// from network failures by automatically reconnecting.
///
/// Cloning is cheap: clones share the same underlying state.
#[derive(Clone)]
pub struct Session {
    // Note: this is implicitly used to detect when the session (and its
    // tunnels) have been dropped in order to shut down the accept loop.
    _dropref: awaitdrop::Ref,
    // The current session state. Held in an `ArcSwap` so the whole
    // `SessionInner` can be replaced atomically (presumably on reconnect).
    inner: Arc<ArcSwap<SessionInner>>,
}
155
/// The state behind a [Session] for one established connection.
struct SessionInner {
    // Handle to the Tokio runtime that was current when the connection was made.
    runtime: Handle,
    // RPC client half of the raw session.
    client: Mutex<RpcClient>,
    // Closed flag; starts out `false`.
    closed: AtomicBool,
    // Tunnels currently bound on this session.
    tunnels: RwLock<TunnelConns>,
    // Copy of the builder that created this connection, updated with the
    // cookie and client ID returned by the server during authentication.
    builder: SessionBuilder,
}
163
/// A trait alias for types that can provide the base ngrok transport, i.e.
/// bidirectional byte streams.
///
/// It is blanket-implemented for all types that satisfy its bounds. Most
/// commonly, it will be a tls-wrapped tcp stream.
pub trait IoStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
// Blanket impl: anything meeting the bounds is an `IoStream`, so callers never
// implement this trait by hand.
impl<T> IoStream for T where T: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
171
/// Trait for establishing the connection to the ngrok server.
#[async_trait]
pub trait Connector: Sync + Send + 'static {
    /// The function used to establish the connection to the ngrok server.
    ///
    /// This is effectively
    /// `async |host, port, tls_config, err| -> Result<IoStream>`.
    ///
    /// If it is being called due to a disconnect, the [AcceptError] argument will
    /// be populated.
    ///
    /// If it returns `Err(ConnectError::Canceled)`, reconnecting will be canceled
    /// and the session will be terminated. Note that this error will never be
    /// returned from the [default_connect] function.
    async fn connect(
        &self,
        host: String,
        port: u16,
        tls_config: Arc<rustls::ClientConfig>,
        err: Option<AcceptError>,
    ) -> Result<Box<dyn IoStream>, ConnectError>;
}
193
/// Blanket implementation that lets any suitable closure or `async fn` (such
/// as [default_connect]) be used directly as a [Connector].
#[async_trait]
impl<F, U> Connector for F
where
    F: Fn(String, u16, Arc<rustls::ClientConfig>, Option<AcceptError>) -> U + Send + Sync + 'static,
    U: Future<Output = Result<Box<dyn IoStream>, ConnectError>> + Send,
{
    async fn connect(
        &self,
        host: String,
        port: u16,
        tls_config: Arc<rustls::ClientConfig>,
        err: Option<AcceptError>,
    ) -> Result<Box<dyn IoStream>, ConnectError> {
        // Call the function and await the future it returns.
        self(host, port, tls_config, err).await
    }
}
210
211/// The default ngrok connector.
212///
213/// Establishes a TCP connection to `addr`, and then performs a TLS handshake
214/// using the `tls_config`.
215///
216/// Discards any errors during reconnect, allowing attempts to recur
217/// indefinitely.
218pub async fn default_connect(
219    host: String,
220    port: u16,
221    tls_config: Arc<rustls::ClientConfig>,
222    _: Option<AcceptError>,
223) -> Result<Box<dyn IoStream>, ConnectError> {
224    let stream = tokio::net::TcpStream::connect(&(host.as_str(), port))
225        .await
226        .map_err(ConnectError::Tcp)?
227        .compat();
228
229    let domain = pki_types::ServerName::try_from(host)
230        .expect("host should have been validated by SessionBuilder::server_addr");
231
232    let tls_conn = futures_rustls::TlsConnector::from(tls_config)
233        .connect(domain, stream)
234        .await
235        .map_err(ConnectError::Tls)?;
236    Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
237}
238
/// An unsupported proxy address was provided.
///
/// Carries the rejected URL, which is included in the error message.
#[derive(Debug, Clone, Error)]
#[error("unsupported proxy address: {0}")]
pub struct ProxyUnsupportedError(Url);
243
244fn connect_proxy(url: Url) -> Result<Arc<dyn Connector>, ProxyUnsupportedError> {
245    Ok(match url.scheme() {
246        "http" | "https" => Arc::new(connect_http_proxy(url)),
247        "socks5" => {
248            let host = url.host_str().unwrap_or_default();
249            let port = url.port().unwrap_or(1080);
250            Arc::new(connect_socks_proxy(format!("{host}:{port}")))
251        }
252        _ => return Err(ProxyUnsupportedError(url)),
253    })
254}
255
256fn connect_http_proxy(url: Url) -> impl Connector {
257    move |host: String, port, tls_config, _| {
258        let mut proxy = Proxy::new(
259            Intercept::All,
260            url.as_str().try_into().expect("urls should be valid uris"),
261        );
262        proxy.force_connect();
263        let mut connector = HttpConnector::new();
264        connector.enforce_http(false);
265        async move {
266            let mut connector = ProxyConnector::from_proxy(connector, proxy)
267                .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?;
268
269            let server_uri = format!("http://{host}:{port}")
270                .parse()
271                .expect("host should have been validated by SessionBuilder::server_addr");
272
273            let conn = connector
274                .call(server_uri)
275                .await
276                .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?;
277
278            let tls_conn = futures_rustls::TlsConnector::from(tls_config)
279                .connect(
280                    pki_types::ServerName::try_from(host)
281                        .expect("host should have been validated by SessionBuilder::server_addr"),
282                    hyper_util::rt::TokioIo::new(conn).compat(),
283                )
284                .await
285                .map_err(ConnectError::Tls)?;
286
287            Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
288        }
289    }
290}
291
292fn connect_socks_proxy(proxy_addr: String) -> impl Connector {
293    move |server_host: String, server_port, tls_config, _| {
294        let proxy_addr = proxy_addr.clone();
295        async move {
296            let conn = tokio_socks::tcp::Socks5Stream::connect(
297                proxy_addr.as_str(),
298                format!("{server_host}:{server_port}"),
299            )
300            .await
301            .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?
302            .compat();
303
304            let tls_conn = futures_rustls::TlsConnector::from(tls_config)
305                .connect(
306                    pki_types::ServerName::try_from(server_host)
307                        .expect("host should have been validated by SessionBuilder::server_addr"),
308                    conn,
309                )
310                .await
311                .map_err(ConnectError::Tls)?;
312
313            Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
314        }
315    }
316}
317
/// The builder for an ngrok [Session].
#[derive(Clone)]
pub struct SessionBuilder {
    // Consuming libraries and applications can add a client type and version on
    // top of the "base" type and version declared by this library.
    // Each entry is (client type, version, optional comment); the front entry
    // is the most significant one.
    versions: VecDeque<(String, String, Option<String>)>,
    authtoken: Option<SecretString>,
    metadata: Option<String>,
    // Heartbeat settings in nanoseconds; `None` keeps the heartbeat defaults.
    heartbeat_interval: Option<i64>,
    heartbeat_tolerance: Option<i64>,
    heartbeat_handler: Option<Arc<dyn HeartbeatHandler>>,
    server_host: String,
    server_port: u16,
    // PEM-encoded root certificate(s); only consulted when `tls_config` is None.
    ca_cert: Option<bytes::Bytes>,
    // A complete TLS client config; takes precedence over `ca_cert`.
    tls_config: Option<rustls::ClientConfig>,
    connector: Arc<dyn Connector>,
    handlers: CommandHandlers,
    // Cookie and client ID returned by the server on authentication. They are
    // sent back on the next authentication performed with this builder.
    cookie: Option<SecretString>,
    id: Option<String>,
}
338
/// Errors arising at [SessionBuilder::connect] time.
///
/// Custom [Connector] implementations return these as well.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ConnectError {
    /// An error occurred when establishing a TCP connection to the ngrok
    /// server.
    #[error("failed to establish tcp connection")]
    Tcp(#[source] io::Error),
    /// A TLS handshake error occurred.
    ///
    /// This is usually a certificate validation issue, or an attempt to connect
    /// to something that doesn't actually speak TLS.
    #[error("tls handshake error")]
    Tls(#[source] io::Error),
    /// An error occurred when starting the ngrok session.
    ///
    /// This might occur when there's a protocol mismatch interfering with the
    /// heartbeat routine.
    #[error("failed to start ngrok session")]
    Start(#[source] StartSessionError),
    /// An error occurred when attempting to authenticate.
    #[error("authentication failure")]
    Auth(#[source] RpcError),
    /// An error occurred when rebinding tunnels during a reconnect.
    #[error("error rebinding tunnel after reconnect")]
    Rebind(#[source] RpcError),
    /// An error arising from a failure to connect through a proxy.
    #[error("failed to connect through proxy")]
    ProxyConnect(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// The (re)connect function gave up.
    ///
    /// This will never be returned by the default connect function, and is
    /// instead used to cancel the reconnect loop.
    #[error("the connect function gave up")]
    Canceled,
}
375
376impl Error for ConnectError {
377    fn error_code(&self) -> Option<&str> {
378        match self {
379            ConnectError::Auth(resp) | ConnectError::Rebind(resp) => resp.error_code(),
380            _ => None,
381        }
382    }
383    fn msg(&self) -> String {
384        match self {
385            ConnectError::Auth(resp) | ConnectError::Rebind(resp) => resp.msg(),
386            _ => format!("{self}"),
387        }
388    }
389}
390
/// The builder specified an invalid heartbeat interval.
///
/// This is most likely caused by a [Duration] that's outside of the [i64::MAX]
/// nanosecond range. The wrapped value is the rejected duration in
/// nanoseconds.
#[derive(Copy, Clone, Debug, Error)]
#[error("invalid heartbeat interval: {0}")]
pub struct InvalidHeartbeatInterval(u128);
/// The builder specified an invalid heartbeat tolerance.
///
/// This is most likely caused by a [Duration] that's outside of the [i64::MAX]
/// nanosecond range. The wrapped value is the rejected duration in
/// nanoseconds.
#[derive(Copy, Clone, Debug, Error)]
#[error("invalid heartbeat tolerance: {0}")]
pub struct InvalidHeartbeatTolerance(u128);
405
/// The builder provided an invalid server address.
///
/// The wrapped string is the address exactly as it was passed in.
#[derive(Error, Debug, Clone)]
#[error("invalid server address: {0}")]
pub struct InvalidServerAddr(String);
410
411impl Default for SessionBuilder {
412    fn default() -> Self {
413        SessionBuilder {
414            versions: [(CLIENT_TYPE.to_string(), VERSION.to_string(), None)]
415                .into_iter()
416                .collect(),
417            authtoken: None,
418            metadata: None,
419            heartbeat_interval: None,
420            heartbeat_tolerance: None,
421            heartbeat_handler: None,
422            server_host: "connect.ngrok-agent.com".into(),
423            server_port: 443,
424            ca_cert: None,
425            tls_config: None,
426            connector: Arc::new(default_connect),
427            handlers: Default::default(),
428            cookie: None,
429            id: None,
430        }
431    }
432}
433
434fn sanitize_ua_string(s: impl AsRef<str>) -> String {
435    static UA_BANNED: OnceCell<Regex> = OnceCell::new();
436    UA_BANNED
437        .get_or_init(|| Regex::new("[^/!#$%&'*+-.^_`|~0-9a-zA-Z]").unwrap())
438        .replace_all(s.as_ref(), "#")
439        .replace('/', "-")
440}
441
442impl SessionBuilder {
443    /// Configures the session to authenticate with the provided authtoken. You
444    /// can [find your existing authtoken] or [create a new one] in the ngrok
445    /// dashboard.
446    ///
447    /// See the [authtoken parameter in the ngrok docs] for additional details.
448    ///
449    /// [find your existing authtoken]: https://dashboard.ngrok.com/get-started/your-authtoken
450    /// [create a new one]: https://dashboard.ngrok.com/tunnels/authtokens
451    /// [authtoken parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#authtoken
452    pub fn authtoken(&mut self, authtoken: impl Into<String>) -> &mut Self {
453        self.authtoken = Some(authtoken.into().into());
454        self
455    }
456    /// Shortcut for calling [SessionBuilder::authtoken] with the value of the
457    /// NGROK_AUTHTOKEN environment variable.
458    pub fn authtoken_from_env(&mut self) -> &mut Self {
459        self.authtoken = env::var("NGROK_AUTHTOKEN").ok().map(From::from);
460        self
461    }
462
463    /// Configures how often the session will send heartbeat messages to the ngrok
464    /// service to check session liveness.
465    ///
466    /// See the [heartbeat_interval parameter in the ngrok docs] for additional
467    /// details.
468    ///
469    /// [heartbeat_interval parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_interval
470    pub fn heartbeat_interval(
471        &mut self,
472        heartbeat_interval: Duration,
473    ) -> Result<&mut Self, InvalidHeartbeatInterval> {
474        let nanos = heartbeat_interval.as_nanos();
475        let nanos = i64::try_from(nanos).map_err(|_| InvalidHeartbeatInterval(nanos))?;
476        self.heartbeat_interval = Some(nanos);
477        Ok(self)
478    }
479
480    /// Configures the duration to wait for a response to a heartbeat before
481    /// assuming the session connection is dead and attempting to reconnect.
482    ///
483    /// See the [heartbeat_tolerance parameter in the ngrok docs] for additional
484    /// details.
485    ///
486    /// [heartbeat_tolerance parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#heartbeat_tolerance
487    pub fn heartbeat_tolerance(
488        &mut self,
489        heartbeat_tolerance: Duration,
490    ) -> Result<&mut Self, InvalidHeartbeatTolerance> {
491        let nanos = heartbeat_tolerance.as_nanos();
492        let nanos = i64::try_from(nanos).map_err(|_| InvalidHeartbeatTolerance(nanos))?;
493        self.heartbeat_tolerance = Some(nanos);
494        Ok(self)
495    }
496
    /// Configures the opaque, machine-readable metadata string for this session.
    /// Metadata is made available to you in the ngrok dashboard and the Agents API
    /// resource. It is a useful way to allow you to uniquely identify sessions. We
    /// suggest encoding the value in a structured format like JSON.
    ///
    /// See the [metadata parameter in the ngrok docs] for additional details.
    ///
    /// [metadata parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#metadata
    pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
        self.metadata = Some(metadata.into());
        self
    }
509
510    /// Configures the network address to dial to connect to the ngrok service.
511    /// Use this option only if you are connecting to a custom agent ingress.
512    ///
513    /// See the [server_addr parameter in the ngrok docs] for additional details.
514    ///
515    /// [server_addr parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#server_addr
516    pub fn server_addr(&mut self, addr: impl Into<String>) -> Result<&mut Self, InvalidServerAddr> {
517        let addr = addr.into();
518        let server_uri: Url = format!("http://{addr}")
519            .parse()
520            .map_err(|_| InvalidServerAddr(addr.clone()))?;
521
522        self.server_host = server_uri
523            .host_str()
524            .map(String::from)
525            .ok_or_else(|| InvalidServerAddr(addr.clone()))?;
526
527        pki_types::ServerName::try_from(self.server_host.as_str())
528            .map_err(|_| InvalidServerAddr(addr.clone()))?;
529
530        self.server_port = server_uri.port().unwrap_or(443);
531
532        Ok(self)
533    }
534
535    /// Sets the file path to a default certificate in PEM format to validate ngrok Session TLS connections.
536    /// Setting to "trusted" is the default, using the ngrok CA certificate.
537    /// Setting to "host" will verify using the certificates on the host operating system.
538    /// A client config set via tls_config after calling root_cas will override this value.
539    ///
540    /// Corresponds to the [root_cas parameter in the ngrok docs]
541    ///
542    /// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas
543    pub fn root_cas(&mut self, root_cas: impl Into<String>) -> Result<&mut Self, io::Error> {
544        match root_cas.into().clone().as_str() {
545            "trusted" => self.ca_cert = None,
546            "host" => self.tls_config = Some(host_certs_tls_config().map_err(|e| e.kind())?),
547            v => {
548                std::fs::read(v).map(|root_cas| self.ca_cert = Some(Bytes::from(root_cas)))?;
549            }
550        }
551        Ok(self)
552    }
553
    /// Sets the default certificate in PEM format to validate ngrok Session TLS connections.
    /// A client config set via tls_config will override this value.
    ///
    /// The bytes are parsed when the TLS configuration is built; entries that
    /// fail to parse or are not X.509 certificates are skipped with a warning.
    ///
    /// Roughly corresponds to the "path to a certificate PEM file" option in the
    /// [root_cas parameter in the ngrok docs]
    ///
    /// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas
    pub fn ca_cert(&mut self, ca_cert: Bytes) -> &mut Self {
        self.ca_cert = Some(ca_cert);
        self
    }
565
    /// Configures the TLS client used to connect to the ngrok service while
    /// establishing the session. Use this option only if you are connecting through
    /// a man-in-the-middle or deep packet inspection proxy. Passed to the
    /// connector set with [SessionBuilder::connector].
    ///
    /// When set, this config is used as-is and any CA certificate configured
    /// on the builder is ignored.
    ///
    /// Roughly corresponds to the [root_cas parameter in the ngrok docs], but allows
    /// for deeper TLS configuration.
    ///
    /// [root_cas parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#root_cas
    pub fn tls_config(&mut self, config: rustls::ClientConfig) -> &mut Self {
        self.tls_config = Some(config);
        self
    }
579
    /// Configures a function which is called to establish the connection to the
    /// ngrok service. Use this option if you need to connect through an outbound
    /// proxy. In the event of network disruptions, it will be called each time
    /// the session reconnects.
    ///
    /// This replaces any connector previously installed on the builder,
    /// including one set up by [SessionBuilder::proxy_url].
    pub fn connector(&mut self, connect: impl Connector) -> &mut Self {
        self.connector = Arc::new(connect);
        self
    }
588
    /// Configures the session to connect to ngrok through an outbound
    /// HTTP or SOCKS5 proxy. This parameter is ignored if you override the connector
    /// with [SessionBuilder::connector].
    ///
    /// Both methods write the same connector slot, so whichever is called last
    /// takes effect.
    ///
    /// Returns [ProxyUnsupportedError] if the URL scheme is not `http`,
    /// `https`, or `socks5`; the existing connector is kept in that case.
    ///
    /// See the [proxy url parameter in the ngrok docs] for additional details.
    ///
    /// [proxy url parameter in the ngrok docs]: https://ngrok.com/docs/ngrok-agent/config#proxy_url
    pub fn proxy_url(&mut self, url: Url) -> Result<&mut Self, ProxyUnsupportedError> {
        self.connector = connect_proxy(url)?;
        Ok(self)
    }
600
    /// Configures a function which is called when the ngrok service requests that
    /// this [Session] stops. Your application may choose to interpret this callback
    /// as a request to terminate the [Session] or the entire process.
    ///
    /// Errors returned by this function will be visible to the ngrok dashboard or
    /// API as the response to the Stop operation.
    ///
    /// Do not block inside this callback. It will cause the Dashboard or API
    /// stop operation to time out. Do not call [std::process::exit] inside this
    /// callback, it will also cause the operation to time out.
    ///
    /// Without a handler, the session reports the Stop operation to the
    /// service as not implemented.
    pub fn handle_stop_command(&mut self, handler: impl CommandHandler<Stop>) -> &mut Self {
        self.handlers.on_stop = Some(Arc::new(handler));
        self
    }
615
    /// Configures a function which is called when the ngrok service requests
    /// that this [Session] restarts. Your application may choose to interpret
    /// this callback as a request to restart the [Session] or the entire
    /// process.
    ///
    /// Errors returned by this function will be visible to the ngrok dashboard or
    /// API as the response to the Restart operation.
    ///
    /// Do not block inside this callback. It will cause the Dashboard or API
    /// operation to time out. Do not call [std::process::exit] inside this
    /// callback, it will also cause the operation to time out.
    ///
    /// Without a handler, the session reports the Restart operation to the
    /// service as not implemented.
    pub fn handle_restart_command(&mut self, handler: impl CommandHandler<Restart>) -> &mut Self {
        self.handlers.on_restart = Some(Arc::new(handler));
        self
    }
631
    /// Configures a function which is called when the ngrok service requests
    /// that this [Session] updates. Your application may choose to interpret
    /// this callback as a request to update its configuration, itself, or to
    /// invoke some other application-specific behavior.
    ///
    /// Errors returned by this function will be visible to the ngrok dashboard or
    /// API as the response to the Update operation.
    ///
    /// Do not block inside this callback. It will cause the Dashboard or API
    /// operation to time out. Do not call [std::process::exit] inside this
    /// callback, it will also cause the operation to time out.
    ///
    /// Without a handler, the session reports the Update operation to the
    /// service as not implemented.
    pub fn handle_update_command(&mut self, handler: impl CommandHandler<Update>) -> &mut Self {
        self.handlers.on_update = Some(Arc::new(handler));
        self
    }
647
    /// Call the provided handler whenever a heartbeat response is received.
    ///
    /// If the handler returns an error, the heartbeat task will exit, resulting
    /// in the session eventually dying as well.
    ///
    /// Only one handler is kept; a later call replaces the earlier handler.
    pub fn handle_heartbeat(&mut self, callback: impl HeartbeatHandler) -> &mut Self {
        self.heartbeat_handler = Some(Arc::new(callback));
        self
    }
656
657    /// Add client type and version information for a client application.
658    ///
659    /// This is a way for applications and library consumers of this crate
660    /// identify themselves.
661    ///
662    /// This will add a new entry to the `User-Agent` field in the "most significant"
663    /// (first) position. Comments must follow [RFC 7230] or a connection error may occur.
664    ///
665    /// [RFC 7230]: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
666    pub fn client_info(
667        &mut self,
668        client_type: impl Into<String>,
669        version: impl Into<String>,
670        comments: Option<impl Into<String>>,
671    ) -> &mut Self {
672        self.versions.push_front((
673            client_type.into(),
674            version.into(),
675            comments.map(|c| c.into()),
676        ));
677        self
678    }
679
    /// Begins a new ngrok [Session] by connecting to the ngrok service.
    /// `connect` blocks until the session is successfully established or fails with
    /// an error.
    ///
    /// On success, a background task is spawned on the current Tokio runtime
    /// to handle incoming streams. The task ends when the accept loop finishes
    /// or when the drop guard created here is released, whichever happens
    /// first.
    pub async fn connect(&self) -> Result<Session, ConnectError> {
        // `dropref` is stored in the session; `dropped` resolves once every
        // clone of it is gone.
        let (dropref, dropped) = awaitdrop::awaitdrop();
        let (inner, mut incoming) = self.connect_inner(None).await?;

        // Grab the runtime handle before `inner` is moved into the ArcSwap.
        let rt = inner.runtime.clone();

        let inner = Arc::new(ArcSwap::new(inner.into()));

        let session = Session {
            _dropref: dropref,
            inner: inner.clone(),
        };

        // store the session for use with StopTunnel
        incoming.session = Some(session.clone());

        // Run the accept loop until it completes or the session is dropped.
        rt.spawn(future::select(
            accept_incoming(incoming, inner).boxed(),
            dropped.wait(),
        ));

        Ok(session)
    }
706
707    pub(crate) fn get_or_create_tls_config(&self) -> rustls::ClientConfig {
708        // if the user has provided a custom TLS config, use that
709        if let Some(tls_config) = &self.tls_config {
710            return tls_config.clone();
711        }
712        // generate a default TLS config
713        let mut root_store = rustls::RootCertStore::empty();
714        let cert_pem = self.ca_cert.as_ref().map_or(CERT_BYTES, |it| it.as_ref());
715        let certs = rustls_pemfile::read_all(&mut io::Cursor::new(cert_pem))
716            .filter_map(|it| match it {
717                Ok(Item::X509Certificate(bs)) => Some(bs),
718                Err(e) => {
719                    warn!(error = ?e, "skipping certificate which failed to parse");
720                    None
721                }
722                Ok(_) => {
723                    warn!("skipping non-x509 certificate");
724                    None
725                }
726            })
727            .collect::<Vec<_>>();
728        root_store.add_parsable_certificates(certs);
729
730        rustls::ClientConfig::builder()
731            .with_root_certificates(root_store)
732            .with_no_client_auth()
733    }
734
    /// Establishes one connection to the ngrok service and authenticates it.
    ///
    /// Steps: dial via the configured connector, start the raw session with
    /// the heartbeat settings, send the auth request, then split the raw
    /// session into its RPC client and incoming-stream halves.
    ///
    /// `err` is forwarded to the connector; it is `None` for the initial
    /// connect made by [SessionBuilder::connect].
    ///
    /// Returns the new [SessionInner] and the incoming streams. Fails with the
    /// connector's error, [ConnectError::Start], or [ConnectError::Auth].
    async fn connect_inner(
        &self,
        err: impl Into<Option<AcceptError>>,
    ) -> Result<(SessionInner, IncomingStreams), ConnectError> {
        let conn = self
            .connector
            .connect(
                self.server_host.clone(),
                self.server_port,
                Arc::new(self.get_or_create_tls_config()),
                err.into(),
            )
            .await?;

        // Start from the heartbeat defaults and apply any overrides. The
        // stored values come from `i64::try_from` on a nanosecond count, so
        // they are non-negative and the `as u64` casts are lossless.
        let mut heartbeat_config = HeartbeatConfig::default();
        if let Some(interval) = self.heartbeat_interval {
            heartbeat_config.interval = Duration::from_nanos(interval as u64);
        }
        if let Some(tolerance) = self.heartbeat_tolerance {
            heartbeat_config.tolerance = Duration::from_nanos(tolerance as u64);
        }
        heartbeat_config.handler = self.heartbeat_handler.clone();

        // convert these while we have ownership
        let heartbeat_interval = heartbeat_config.interval.as_nanos() as i64;
        let heartbeat_tolerance = heartbeat_config.tolerance.as_nanos() as i64;

        let mut raw = RawSession::start(conn, heartbeat_config, self.handlers.clone())
            .await
            .map_err(ConnectError::Start)?;

        // list of possibilities: https://doc.rust-lang.org/std/env/consts/constant.OS.html
        // "macos" is reported to the service as "darwin".
        let os = match env::consts::OS {
            "macos" => "darwin",
            _ => env::consts::OS,
        };

        // One `name/version (comment)` product per entry, most significant
        // first. Name and version are sanitized; the comment is not.
        let user_agent = self
            .versions
            .iter()
            .map(|(name, version, comments)| {
                format!(
                    "{}/{}{}",
                    sanitize_ua_string(name),
                    sanitize_ua_string(version),
                    comments
                        .as_ref()
                        .map_or(String::new(), |f| format!(" ({f})"))
                )
            })
            .collect::<Vec<_>>()
            .join(" ");

        // Indexing cannot fail for builders created via `Default`: it seeds
        // one entry and `client_info` only ever adds more.
        let client_type = self.versions[0].0.clone();
        let version = self.versions[0].1.clone();

        let resp = raw
            .auth(
                self.id.as_deref().unwrap_or_default(),
                AuthExtra {
                    version,
                    client_type,
                    user_agent,
                    auth_token: self.authtoken.clone().unwrap_or_default(),
                    metadata: self.metadata.clone().unwrap_or_default(),
                    os: os.into(),
                    arch: std::env::consts::ARCH.into(),
                    heartbeat_interval,
                    heartbeat_tolerance,
                    // For each command: `NOT_IMPLEMENTED` when no handler is
                    // registered, an empty string otherwise.
                    restart_unsupported_error: self
                        .handlers
                        .on_restart
                        .is_none()
                        .then_some(NOT_IMPLEMENTED.into())
                        .or(Some("".into())),
                    stop_unsupported_error: self
                        .handlers
                        .on_stop
                        .is_none()
                        .then_some(NOT_IMPLEMENTED.into())
                        .or(Some("".into())),
                    update_unsupported_error: self
                        .handlers
                        .on_update
                        .is_none()
                        .then_some(NOT_IMPLEMENTED.into())
                        .or(Some("".into())),
                    cookie: self.cookie.clone().unwrap_or_default(),
                    ..Default::default()
                },
            )
            .await
            .map_err(ConnectError::Auth)?;

        let (client, incoming) = raw.split();

        // Keep a copy of this builder carrying the cookie and client ID from
        // the auth response, so a later `connect_inner` call on the stored
        // builder sends them back to the server.
        let builder = SessionBuilder {
            cookie: resp.extra.cookie,
            id: resp.client_id.into(),
            ..self.clone()
        };

        Ok((
            SessionInner {
                runtime: Handle::current(),
                client: client.into(),
                tunnels: Default::default(),
                closed: Default::default(),
                builder,
            },
            incoming,
        ))
    }
848}
849
850impl Session {
851    /// Create a new [SessionBuilder] to configure a new ngrok session.
852    pub fn builder() -> SessionBuilder {
853        SessionBuilder::default()
854    }
855
856    /// Start building a tunnel for an HTTP endpoint.
857    ///
858    /// https://ngrok.com/docs/http/
859    pub fn http_endpoint(&self) -> HttpTunnelBuilder {
860        self.clone().into()
861    }
862
863    /// Start building a tunnel for a TCP endpoint.
864    ///
865    /// https://ngrok.com/docs/tcp/
866    pub fn tcp_endpoint(&self) -> TcpTunnelBuilder {
867        self.clone().into()
868    }
869
870    /// Start building a tunnel for a TLS endpoint.
871    ///
872    /// https://ngrok.com/docs/tls/
873    pub fn tls_endpoint(&self) -> TlsTunnelBuilder {
874        self.clone().into()
875    }
876
877    /// Start building a labeled tunnel.
878    ///
879    /// https://ngrok.com/docs/network-edge/edges/#tunnel-group
880    pub fn labeled_tunnel(&self) -> LabeledTunnelBuilder {
881        self.clone().into()
882    }
883
884    /// Get the unique ID of this session.
885    pub fn id(&self) -> String {
886        self.inner
887            .load()
888            .builder
889            .id
890            .as_ref()
891            .expect("Session ID not set")
892            .clone()
893    }
894
895    /// Start a new tunnel in this session.
896    pub(crate) async fn start_tunnel<C>(&self, tunnel_cfg: C) -> Result<TunnelInner, RpcError>
897    where
898        C: TunnelConfig,
899    {
900        let inner = self.inner.load();
901        let mut client = inner.client.lock().await;
902
903        // let tunnelCfg: dyn TunnelConfig = TunnelConfig(opts);
904        let (tx, rx) = channel(64);
905
906        let proto = tunnel_cfg.proto();
907        let opts = tunnel_cfg.opts();
908        let mut extra = tunnel_cfg.extra();
909        let labels = tunnel_cfg.labels();
910        let forwards_to = tunnel_cfg.forwards_to();
911        let forwards_proto = tunnel_cfg.forwards_proto();
912        let verify_upstream_tls = tunnel_cfg.verify_upstream_tls();
913
914        // non-labeled tunnel
915        let (tunnel, bound) = if tunnel_cfg.proto() != "" {
916            let resp = client
917                .listen(
918                    &proto,
919                    opts.clone().unwrap(), // this is crate-defined, and must exist if proto is non-empty
920                    extra.clone(),
921                    "",
922                    &forwards_to,
923                    &forwards_proto,
924                )
925                .await?;
926
927            extra.token = resp.extra.token;
928            let info = TunnelInnerInfo {
929                id: resp.client_id,
930                proto: resp.proto.clone(),
931                url: resp.url,
932                labels: HashMap::new(),
933                forwards_to: tunnel_cfg.forwards_to(),
934                metadata: extra.metadata.clone(),
935            };
936
937            (
938                TunnelInner {
939                    info,
940                    session: self.clone(),
941                    incoming: rx.into(),
942                },
943                BoundTunnel {
944                    proto: resp.proto,
945                    opts: resp.bind_opts.into(),
946                    extra,
947                    labels,
948                    forwards_to,
949                    forwards_proto,
950                    verify_upstream_tls,
951                    tx,
952                },
953            )
954        } else {
955            // labeled tunnel
956            let resp = client
957                .listen_label(
958                    labels.clone(),
959                    &extra.metadata,
960                    &forwards_to,
961                    &forwards_proto,
962                )
963                .await?;
964
965            let info = TunnelInnerInfo {
966                id: resp.id,
967                proto: Default::default(),
968                url: Default::default(),
969                labels: tunnel_cfg.labels(),
970                forwards_to: tunnel_cfg.forwards_to(),
971                metadata: extra.metadata.clone(),
972            };
973
974            (
975                TunnelInner {
976                    info,
977                    session: self.clone(),
978                    incoming: rx.into(),
979                },
980                BoundTunnel {
981                    extra,
982                    proto: Default::default(),
983                    opts: Default::default(),
984                    forwards_to,
985                    forwards_proto,
986                    verify_upstream_tls,
987                    labels,
988                    tx,
989                },
990            )
991        };
992
993        let mut tunnels = inner.tunnels.write().await;
994        tunnels.insert(tunnel.info.id.clone(), bound);
995
996        Ok(tunnel)
997    }
998
999    /// Close a tunnel with an error from the remote.
1000    /// Skips the call to unlisten, since the remote has already rejected it.
1001    pub(crate) async fn close_tunnel_with_error(&self, id: impl AsRef<str>, err: AcceptError) {
1002        let id = id.as_ref();
1003        let inner = self.inner.load();
1004        if let Some(tun) = inner.tunnels.write().await.remove(id) {
1005            let _ = tun.tx.send(Err(err)).await;
1006        };
1007    }
1008
1009    /// Close a tunnel with the given ID.
1010    pub async fn close_tunnel(&self, id: impl AsRef<str>) -> Result<(), RpcError> {
1011        let id = id.as_ref();
1012        let inner = self.inner.load();
1013        inner.client.lock().await.unlisten(id).await?;
1014        inner.tunnels.write().await.remove(id);
1015        Ok(())
1016    }
1017
1018    pub(crate) fn runtime(&self) -> Handle {
1019        self.inner.load().runtime.clone()
1020    }
1021
1022    /// Close the ngrok session.
1023    pub async fn close(&mut self) -> Result<(), RpcError> {
1024        let inner = self.inner.load();
1025        let res = inner.client.lock().await.close().await;
1026        inner.closed.store(true, Ordering::SeqCst);
1027        res
1028    }
1029}
1030
1031pub(crate) fn host_certs_tls_config() -> Result<rustls::ClientConfig, &'static io::Error> {
1032    // The root certificate store, lazily loaded once.
1033    static ROOT_STORE: Lazy<Result<RootCertStore, io::Error>> = Lazy::new(|| {
1034        let der_certs = rustls_native_certs::load_native_certs()?
1035            .into_iter()
1036            .collect::<Vec<_>>();
1037        let mut root_store = RootCertStore::empty();
1038        root_store.add_parsable_certificates(der_certs);
1039        Ok(root_store)
1040    });
1041
1042    let root_store = ROOT_STORE.as_ref()?;
1043    Ok(rustls::ClientConfig::builder()
1044        .with_root_certificates(root_store.clone())
1045        .with_no_client_auth())
1046}
1047
1048async fn accept_one(
1049    incoming: &mut IncomingStreams,
1050    inner: &ArcSwap<SessionInner>,
1051) -> Result<(), AcceptError> {
1052    let conn = match incoming.accept().await {
1053        Ok(conn) => conn,
1054        // Assume if we got a muxado error, the session is borked. Break and
1055        // propagate the error to all of the tunnels out in the wild.
1056        Err(RawAcceptError::Transport(error)) => return Err(error.into()),
1057        // The other errors are either a bad header or an unrecognized
1058        // stream type. They're non-fatal, but could signal a protocol
1059        // mismatch.
1060        Err(error) => {
1061            warn!(?error, "protocol error when accepting tunnel connection");
1062            return Ok(());
1063        }
1064    };
1065    let id = conn.header.id.clone();
1066    let remote_addr = conn.header.client_addr.parse().unwrap_or_else(|error| {
1067        warn!(
1068            client_addr = conn.header.client_addr,
1069            %error,
1070            "invalid remote addr for tunnel connection",
1071        );
1072        "0.0.0.0:0".parse().unwrap()
1073    });
1074    let inner = inner.load();
1075    let guard = inner.tunnels.read().await;
1076    let res = if let Some(tun) = guard.get(&id) {
1077        let mut header = conn.header;
1078        let app_protocol = Some(tun.forwards_proto.to_string()).filter(|s| !s.is_empty());
1079        let verify_upstream_tls = tun.verify_upstream_tls;
1080        // Note: this is a bit of a hack. Normally, passthrough_tls is only
1081        // a thing on edge connections, but we're making sure it's set for
1082        // endpoint connections as well. In their case, we have to look at the
1083        // options used to bind the endpoint.
1084        if let Some(BindOpts::Tls(opts)) = &tun.opts {
1085            header.passthrough_tls = opts.tls_termination.is_none();
1086        }
1087        let proxy_proto = if let Some(
1088            BindOpts::Tls(TlsEndpoint { proxy_proto, .. })
1089            | BindOpts::Http(HttpEndpoint { proxy_proto, .. })
1090            | BindOpts::Tcp(TcpEndpoint { proxy_proto, .. }),
1091        ) = tun.opts
1092        {
1093            proxy_proto
1094        } else {
1095            ProxyProto::None
1096        };
1097        tun.tx
1098            .send(Ok(ConnInner {
1099                info: crate::conn::Info {
1100                    app_protocol,
1101                    verify_upstream_tls,
1102                    remote_addr,
1103                    header,
1104                    proxy_proto,
1105                },
1106                stream: conn.stream,
1107            }))
1108            .await
1109    } else {
1110        Ok(())
1111    };
1112    drop(guard);
1113    if res.is_err() {
1114        RwLock::write(&inner.tunnels).await.remove(&id);
1115    }
1116    Ok(())
1117}
1118
1119async fn try_reconnect(
1120    inner: Arc<ArcSwap<SessionInner>>,
1121    err: impl Into<Option<AcceptError>>,
1122) -> Result<IncomingStreams, ConnectError> {
1123    let old_inner = inner.load();
1124    if old_inner.closed.load(Ordering::SeqCst) {
1125        return Err(ConnectError::Canceled);
1126    }
1127    let (new_inner, new_incoming) = old_inner.builder.connect_inner(err).await?;
1128    let mut client = new_inner.client.lock().await;
1129    let mut new_tunnels = new_inner.tunnels.write().await;
1130    let old_tunnels = old_inner.tunnels.read().await;
1131
1132    for (id, tun) in old_tunnels.iter() {
1133        if !tun.proto.is_empty() {
1134            let resp = client
1135                .listen(
1136                    &tun.proto,
1137                    tun.opts.clone().unwrap(),
1138                    tun.extra.clone(),
1139                    id,
1140                    &tun.forwards_to,
1141                    &tun.forwards_proto,
1142                )
1143                .await
1144                .map_err(ConnectError::Rebind)?;
1145            debug!(?resp, %id, %tun.proto, ?tun.opts, ?tun.extra, %tun.forwards_to, "rebound tunnel");
1146            new_tunnels.insert(id.clone(), tun.clone());
1147        } else {
1148            let resp = client
1149                .listen_label(
1150                    tun.labels.clone(),
1151                    &tun.extra.metadata,
1152                    &tun.forwards_to,
1153                    &tun.forwards_proto,
1154                )
1155                .await
1156                .map_err(ConnectError::Rebind)?;
1157
1158            if !resp.id.is_empty() {
1159                new_tunnels.insert(resp.id, tun.clone());
1160            } else {
1161                new_tunnels.insert(id.clone(), tun.clone());
1162            }
1163        }
1164    }
1165
1166    drop(old_tunnels);
1167    drop(client);
1168    drop(new_tunnels);
1169    inner.store(new_inner.into());
1170
1171    Ok(new_incoming)
1172}
1173
1174async fn accept_incoming(mut incoming: IncomingStreams, inner: Arc<ArcSwap<SessionInner>>) {
1175    let error: AcceptError = loop {
1176        if let Err(error) = accept_one(&mut incoming, &inner).await {
1177            debug!(%error, "failed to accept stream, attempting reconnect");
1178            // This is gross, but should perform fine. Couple of notes:
1179            // * Mutex so that both the action and condition can share access to
1180            //   `error`. Realistically, the lock calls should be non-concurrent,
1181            //   but Rust can't prove that.
1182            // * Not setting the error in the action because then a a reference
1183            //   to a FnMut closure would escape via the returned Future, which is
1184            //   a no-no.
1185            let error = parking_lot::Mutex::new(Some(error));
1186            let reconnect = RetryIf::spawn(
1187                ExponentialBackoff::from_millis(50),
1188                || try_reconnect(inner.clone(), error.lock().clone()).map_err(Arc::new),
1189                |err: &Arc<ConnectError>| {
1190                    if let ConnectError::Canceled = **err {
1191                        false
1192                    } else {
1193                        *error.lock() = Some(AcceptError::Reconnect(err.clone()));
1194                        true
1195                    }
1196                },
1197            );
1198            incoming = match reconnect.await {
1199                Ok(incoming) => incoming,
1200                Err(error) => {
1201                    debug!(%error, "reconnect failed, giving up");
1202                    break AcceptError::Reconnect(error);
1203                }
1204            };
1205        }
1206    };
1207    for (_id, tun) in inner.load().tunnels.write().await.drain() {
1208        let _ = tun.tx.send(Err(error.clone())).await;
1209    }
1210}
1211
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `sanitize_ua_string` against fixed (input, expected) pairs.
    #[test]
    fn test_sanitize_ua() {
        let cases = [
            ("library/official/rust", "library-official-rust"),
            ("something@really☺weird", "something#really#weird"),
        ];

        for (input, expected) in cases {
            assert_eq!(sanitize_ua_string(input), expected);
        }
    }
}