1use std::{
2 collections::{
3 HashMap,
4 VecDeque,
5 },
6 env,
7 io,
8 sync::{
9 atomic::{
10 AtomicBool,
11 Ordering,
12 },
13 Arc,
14 },
15 time::Duration,
16};
17
18use arc_swap::ArcSwap;
19use async_trait::async_trait;
20use bytes::Bytes;
21use futures::{
22 prelude::*,
23 FutureExt,
24};
25use futures_rustls::rustls::{
26 self,
27 pki_types,
28 RootCertStore,
29};
30use hyper_http_proxy::{
31 Intercept,
32 Proxy,
33 ProxyConnector,
34};
35use hyper_util::client::legacy::connect::HttpConnector;
36use muxado::heartbeat::HeartbeatConfig;
37pub use muxado::heartbeat::HeartbeatHandler;
38use once_cell::sync::{
39 Lazy,
40 OnceCell,
41};
42use regex::Regex;
43use rustls_pemfile::Item;
44use thiserror::Error;
45use tokio::{
46 io::{
47 AsyncRead,
48 AsyncWrite,
49 },
50 runtime::Handle,
51 sync::{
52 mpsc::{
53 channel,
54 Sender,
55 },
56 Mutex,
57 RwLock,
58 },
59};
60use tokio_retry::{
61 strategy::ExponentialBackoff,
62 RetryIf,
63};
64use tokio_util::compat::{
65 FuturesAsyncReadCompatExt,
66 TokioAsyncReadCompatExt,
67};
68use tower_service::Service;
69use tracing::{
70 debug,
71 warn,
72};
73use url::Url;
74
75pub use crate::internals::{
76 proto::{
77 CommandResp,
78 Restart,
79 Stop,
80 StopTunnel,
81 Update,
82 },
83 raw_session::{
84 CommandHandler,
85 RpcError,
86 },
87};
88use crate::{
89 config::{
90 HttpTunnelBuilder,
91 LabeledTunnelBuilder,
92 ProxyProto,
93 TcpTunnelBuilder,
94 TlsTunnelBuilder,
95 TunnelConfig,
96 },
97 conn::ConnInner,
98 internals::{
99 proto::{
100 AuthExtra,
101 BindExtra,
102 BindOpts,
103 Error,
104 HttpEndpoint,
105 SecretString,
106 TcpEndpoint,
107 TlsEndpoint,
108 },
109 raw_session::{
110 AcceptError as RawAcceptError,
111 CommandHandlers,
112 IncomingStreams,
113 RawSession,
114 RpcClient,
115 StartSessionError,
116 NOT_IMPLEMENTED,
117 },
118 },
119 tunnel::{
120 AcceptError,
121 TunnelInner,
122 TunnelInnerInfo,
123 },
124};
125
/// PEM-encoded CA certificate(s) bundled with the crate. Used to verify the ngrok
/// server when neither a custom CA (`ca_cert`) nor a full TLS config is supplied.
pub(crate) const CERT_BYTES: &[u8] = include_bytes!("../assets/ngrok.ca.crt");
/// Default client type. It is reported in the auth request and the user agent unless
/// `SessionBuilder::client_info` prepends another entry.
const CLIENT_TYPE: &str = "ngrok-rust";
/// Version of this crate, paired with [`CLIENT_TYPE`] in the default client info.
const VERSION: &str = env!("CARGO_PKG_VERSION");
129
/// Book-keeping for a bound tunnel. It holds what is needed to route incoming streams
/// to the tunnel's receiver and to re-bind the tunnel after the session reconnects.
#[derive(Clone)]
struct BoundTunnel {
    /// Tunnel protocol; empty for labeled tunnels.
    proto: String,
    /// Bind options returned by the server when binding; `None` for labeled tunnels.
    opts: Option<BindOpts>,
    /// Extra bind parameters (including the server-issued token); resent on re-bind.
    extra: BindExtra,
    /// Labels of a labeled tunnel.
    labels: HashMap<String, String>,
    /// Description of the upstream the tunnel forwards to.
    forwards_to: String,
    /// Upstream application protocol; an empty string is reported as "none".
    forwards_proto: String,
    /// Whether the upstream's TLS certificate should be verified.
    verify_upstream_tls: bool,
    /// Sending half of the channel that delivers accepted streams (or a terminal
    /// error) to the tunnel's owner.
    tx: Sender<Result<ConnInner, AcceptError>>,
}

/// Bound tunnels, keyed by tunnel ID.
type TunnelConns = HashMap<String, BoundTunnel>;
143
/// A connected ngrok session.
///
/// Cloning is cheap, and all clones share the same state. The per-connection state sits
/// behind an `ArcSwap` so it can be swapped out wholesale when the session reconnects.
#[derive(Clone)]
pub struct Session {
    // Paired with the waiter raced against the background accept task in
    // `SessionBuilder::connect`. NOTE(review): presumably that waiter resolves once all
    // clones (and thus all refs) are dropped, ending the task — see the `awaitdrop` crate.
    _dropref: awaitdrop::Ref,
    // Current connection state; replaced by the reconnect logic on success.
    inner: Arc<ArcSwap<SessionInner>>,
}
155
/// State tied to one underlying connection to ngrok; a new one is built on reconnect.
struct SessionInner {
    /// Runtime handle captured when the connection was established; used to spawn
    /// the background accept task.
    runtime: Handle,
    /// RPC client; the async mutex serializes RPCs on this connection.
    client: Mutex<RpcClient>,
    /// Set by `Session::close`; once set, reconnection is no longer attempted.
    closed: AtomicBool,
    /// Tunnels bound on this connection, keyed by tunnel ID.
    tunnels: RwLock<TunnelConns>,
    /// Builder used to reconnect; carries the session ID and cookie from auth.
    builder: SessionBuilder,
}
163
164pub trait IoStream: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
170impl<T> IoStream for T where T: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
171
/// Opens the transport stream to the ngrok server.
///
/// Any `Fn(String, u16, Arc<rustls::ClientConfig>, Option<AcceptError>)` returning a
/// suitable future also implements this trait.
#[async_trait]
pub trait Connector: Sync + Send + 'static {
    /// Connects to `host:port` and returns a stream ready for the session protocol.
    ///
    /// The built-in connectors perform the TLS handshake using `tls_config`. `err` is
    /// `None` on the initial connect. When reconnecting, it carries the error that
    /// triggered the reconnect.
    async fn connect(
        &self,
        host: String,
        port: u16,
        tls_config: Arc<rustls::ClientConfig>,
        err: Option<AcceptError>,
    ) -> Result<Box<dyn IoStream>, ConnectError>;
}
193
194#[async_trait]
195impl<F, U> Connector for F
196where
197 F: Fn(String, u16, Arc<rustls::ClientConfig>, Option<AcceptError>) -> U + Send + Sync + 'static,
198 U: Future<Output = Result<Box<dyn IoStream>, ConnectError>> + Send,
199{
200 async fn connect(
201 &self,
202 host: String,
203 port: u16,
204 tls_config: Arc<rustls::ClientConfig>,
205 err: Option<AcceptError>,
206 ) -> Result<Box<dyn IoStream>, ConnectError> {
207 self(host, port, tls_config, err).await
208 }
209}
210
211pub async fn default_connect(
219 host: String,
220 port: u16,
221 tls_config: Arc<rustls::ClientConfig>,
222 _: Option<AcceptError>,
223) -> Result<Box<dyn IoStream>, ConnectError> {
224 let stream = tokio::net::TcpStream::connect(&(host.as_str(), port))
225 .await
226 .map_err(ConnectError::Tcp)?
227 .compat();
228
229 let domain = pki_types::ServerName::try_from(host)
230 .expect("host should have been validated by SessionBuilder::server_addr");
231
232 let tls_conn = futures_rustls::TlsConnector::from(tls_config)
233 .connect(domain, stream)
234 .await
235 .map_err(ConnectError::Tls)?;
236 Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
237}
238
/// Error returned when a proxy URL cannot be used, for example because its scheme
/// is not supported. Carries the offending URL.
#[derive(Debug, Clone, Error)]
#[error("unsupported proxy address: {0}")]
pub struct ProxyUnsupportedError(Url);
243
244fn connect_proxy(url: Url) -> Result<Arc<dyn Connector>, ProxyUnsupportedError> {
245 Ok(match url.scheme() {
246 "http" | "https" => Arc::new(connect_http_proxy(url)),
247 "socks5" => {
248 let host = url.host_str().unwrap_or_default();
249 let port = url.port().unwrap_or(1080);
250 Arc::new(connect_socks_proxy(format!("{host}:{port}")))
251 }
252 _ => return Err(ProxyUnsupportedError(url)),
253 })
254}
255
256fn connect_http_proxy(url: Url) -> impl Connector {
257 move |host: String, port, tls_config, _| {
258 let mut proxy = Proxy::new(
259 Intercept::All,
260 url.as_str().try_into().expect("urls should be valid uris"),
261 );
262 proxy.force_connect();
263 let mut connector = HttpConnector::new();
264 connector.enforce_http(false);
265 async move {
266 let mut connector = ProxyConnector::from_proxy(connector, proxy)
267 .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?;
268
269 let server_uri = format!("http://{host}:{port}")
270 .parse()
271 .expect("host should have been validated by SessionBuilder::server_addr");
272
273 let conn = connector
274 .call(server_uri)
275 .await
276 .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?;
277
278 let tls_conn = futures_rustls::TlsConnector::from(tls_config)
279 .connect(
280 pki_types::ServerName::try_from(host)
281 .expect("host should have been validated by SessionBuilder::server_addr"),
282 hyper_util::rt::TokioIo::new(conn).compat(),
283 )
284 .await
285 .map_err(ConnectError::Tls)?;
286
287 Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
288 }
289 }
290}
291
292fn connect_socks_proxy(proxy_addr: String) -> impl Connector {
293 move |server_host: String, server_port, tls_config, _| {
294 let proxy_addr = proxy_addr.clone();
295 async move {
296 let conn = tokio_socks::tcp::Socks5Stream::connect(
297 proxy_addr.as_str(),
298 format!("{server_host}:{server_port}"),
299 )
300 .await
301 .map_err(|e| ConnectError::ProxyConnect(Box::new(e)))?
302 .compat();
303
304 let tls_conn = futures_rustls::TlsConnector::from(tls_config)
305 .connect(
306 pki_types::ServerName::try_from(server_host)
307 .expect("host should have been validated by SessionBuilder::server_addr"),
308 conn,
309 )
310 .await
311 .map_err(ConnectError::Tls)?;
312
313 Ok(Box::new(tls_conn.compat()) as Box<dyn IoStream>)
314 }
315 }
316}
317
/// Builder for an ngrok [`Session`].
///
/// Obtain one with `Session::builder()`, configure it, then call `connect`.
#[derive(Clone)]
pub struct SessionBuilder {
    /// `(client type, version, comment)` entries used to build the user agent. The
    /// first entry also supplies the client type and version sent in the auth request.
    versions: VecDeque<(String, String, Option<String>)>,
    authtoken: Option<SecretString>,
    metadata: Option<String>,
    /// Heartbeat interval override, in nanoseconds.
    heartbeat_interval: Option<i64>,
    /// Heartbeat tolerance override, in nanoseconds.
    heartbeat_tolerance: Option<i64>,
    heartbeat_handler: Option<Arc<dyn HeartbeatHandler>>,
    server_host: String,
    server_port: u16,
    /// PEM-encoded CA certificates; the bundled ngrok CA is used when `None`.
    ca_cert: Option<bytes::Bytes>,
    /// Complete TLS client config; when set it takes precedence over `ca_cert`.
    tls_config: Option<rustls::ClientConfig>,
    /// Opens the transport to the ngrok server.
    connector: Arc<dyn Connector>,
    handlers: CommandHandlers,
    /// Cookie from the last successful auth, sent again when reconnecting.
    cookie: Option<SecretString>,
    /// Session ID from the last successful auth, sent again when reconnecting.
    id: Option<String>,
}
338
/// Errors that can occur while establishing or re-establishing an ngrok session.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ConnectError {
    /// The TCP connection to the server could not be established.
    #[error("failed to establish tcp connection")]
    Tcp(#[source] io::Error),
    /// The TLS handshake with the server failed.
    #[error("tls handshake error")]
    Tls(#[source] io::Error),
    /// The session protocol could not be started over the established stream.
    #[error("failed to start ngrok session")]
    Start(#[source] StartSessionError),
    /// The authentication RPC failed.
    #[error("authentication failure")]
    Auth(#[source] RpcError),
    /// A previously bound tunnel could not be bound again after a reconnect.
    #[error("error rebinding tunnel after reconnect")]
    Rebind(#[source] RpcError),
    /// Connecting through the configured proxy failed.
    #[error("failed to connect through proxy")]
    ProxyConnect(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// Connection attempts were abandoned. This is returned when reconnecting a closed
    /// session, and it is the one error that stops the automatic reconnect loop.
    #[error("the connect function gave up")]
    Canceled,
}
375
376impl Error for ConnectError {
377 fn error_code(&self) -> Option<&str> {
378 match self {
379 ConnectError::Auth(resp) | ConnectError::Rebind(resp) => resp.error_code(),
380 _ => None,
381 }
382 }
383 fn msg(&self) -> String {
384 match self {
385 ConnectError::Auth(resp) | ConnectError::Rebind(resp) => resp.msg(),
386 _ => format!("{self}"),
387 }
388 }
389}
390
/// A heartbeat interval too large to represent, in nanoseconds, as an `i64`.
/// Carries the rejected value in nanoseconds.
#[derive(Copy, Clone, Debug, Error)]
#[error("invalid heartbeat interval: {0}")]
pub struct InvalidHeartbeatInterval(u128);
/// A heartbeat tolerance too large to represent, in nanoseconds, as an `i64`.
/// Carries the rejected value in nanoseconds.
#[derive(Copy, Clone, Debug, Error)]
#[error("invalid heartbeat tolerance: {0}")]
pub struct InvalidHeartbeatTolerance(u128);

/// A server address that could not be parsed, has no host, or whose host is not a
/// valid TLS server name. Carries the address as given.
#[derive(Error, Debug, Clone)]
#[error("invalid server address: {0}")]
pub struct InvalidServerAddr(String);
410
411impl Default for SessionBuilder {
412 fn default() -> Self {
413 SessionBuilder {
414 versions: [(CLIENT_TYPE.to_string(), VERSION.to_string(), None)]
415 .into_iter()
416 .collect(),
417 authtoken: None,
418 metadata: None,
419 heartbeat_interval: None,
420 heartbeat_tolerance: None,
421 heartbeat_handler: None,
422 server_host: "connect.ngrok-agent.com".into(),
423 server_port: 443,
424 ca_cert: None,
425 tls_config: None,
426 connector: Arc::new(default_connect),
427 handlers: Default::default(),
428 cookie: None,
429 id: None,
430 }
431 }
432}
433
434fn sanitize_ua_string(s: impl AsRef<str>) -> String {
435 static UA_BANNED: OnceCell<Regex> = OnceCell::new();
436 UA_BANNED
437 .get_or_init(|| Regex::new("[^/!#$%&'*+-.^_`|~0-9a-zA-Z]").unwrap())
438 .replace_all(s.as_ref(), "#")
439 .replace('/', "-")
440}
441
442impl SessionBuilder {
443 pub fn authtoken(&mut self, authtoken: impl Into<String>) -> &mut Self {
453 self.authtoken = Some(authtoken.into().into());
454 self
455 }
456 pub fn authtoken_from_env(&mut self) -> &mut Self {
459 self.authtoken = env::var("NGROK_AUTHTOKEN").ok().map(From::from);
460 self
461 }
462
463 pub fn heartbeat_interval(
471 &mut self,
472 heartbeat_interval: Duration,
473 ) -> Result<&mut Self, InvalidHeartbeatInterval> {
474 let nanos = heartbeat_interval.as_nanos();
475 let nanos = i64::try_from(nanos).map_err(|_| InvalidHeartbeatInterval(nanos))?;
476 self.heartbeat_interval = Some(nanos);
477 Ok(self)
478 }
479
480 pub fn heartbeat_tolerance(
488 &mut self,
489 heartbeat_tolerance: Duration,
490 ) -> Result<&mut Self, InvalidHeartbeatTolerance> {
491 let nanos = heartbeat_tolerance.as_nanos();
492 let nanos = i64::try_from(nanos).map_err(|_| InvalidHeartbeatTolerance(nanos))?;
493 self.heartbeat_tolerance = Some(nanos);
494 Ok(self)
495 }
496
497 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
506 self.metadata = Some(metadata.into());
507 self
508 }
509
510 pub fn server_addr(&mut self, addr: impl Into<String>) -> Result<&mut Self, InvalidServerAddr> {
517 let addr = addr.into();
518 let server_uri: Url = format!("http://{addr}")
519 .parse()
520 .map_err(|_| InvalidServerAddr(addr.clone()))?;
521
522 self.server_host = server_uri
523 .host_str()
524 .map(String::from)
525 .ok_or_else(|| InvalidServerAddr(addr.clone()))?;
526
527 pki_types::ServerName::try_from(self.server_host.as_str())
528 .map_err(|_| InvalidServerAddr(addr.clone()))?;
529
530 self.server_port = server_uri.port().unwrap_or(443);
531
532 Ok(self)
533 }
534
535 pub fn root_cas(&mut self, root_cas: impl Into<String>) -> Result<&mut Self, io::Error> {
544 match root_cas.into().clone().as_str() {
545 "trusted" => self.ca_cert = None,
546 "host" => self.tls_config = Some(host_certs_tls_config().map_err(|e| e.kind())?),
547 v => {
548 std::fs::read(v).map(|root_cas| self.ca_cert = Some(Bytes::from(root_cas)))?;
549 }
550 }
551 Ok(self)
552 }
553
554 pub fn ca_cert(&mut self, ca_cert: Bytes) -> &mut Self {
562 self.ca_cert = Some(ca_cert);
563 self
564 }
565
566 pub fn tls_config(&mut self, config: rustls::ClientConfig) -> &mut Self {
576 self.tls_config = Some(config);
577 self
578 }
579
580 pub fn connector(&mut self, connect: impl Connector) -> &mut Self {
585 self.connector = Arc::new(connect);
586 self
587 }
588
589 pub fn proxy_url(&mut self, url: Url) -> Result<&mut Self, ProxyUnsupportedError> {
597 self.connector = connect_proxy(url)?;
598 Ok(self)
599 }
600
601 pub fn handle_stop_command(&mut self, handler: impl CommandHandler<Stop>) -> &mut Self {
612 self.handlers.on_stop = Some(Arc::new(handler));
613 self
614 }
615
616 pub fn handle_restart_command(&mut self, handler: impl CommandHandler<Restart>) -> &mut Self {
628 self.handlers.on_restart = Some(Arc::new(handler));
629 self
630 }
631
632 pub fn handle_update_command(&mut self, handler: impl CommandHandler<Update>) -> &mut Self {
644 self.handlers.on_update = Some(Arc::new(handler));
645 self
646 }
647
648 pub fn handle_heartbeat(&mut self, callback: impl HeartbeatHandler) -> &mut Self {
653 self.heartbeat_handler = Some(Arc::new(callback));
654 self
655 }
656
657 pub fn client_info(
667 &mut self,
668 client_type: impl Into<String>,
669 version: impl Into<String>,
670 comments: Option<impl Into<String>>,
671 ) -> &mut Self {
672 self.versions.push_front((
673 client_type.into(),
674 version.into(),
675 comments.map(|c| c.into()),
676 ));
677 self
678 }
679
680 pub async fn connect(&self) -> Result<Session, ConnectError> {
684 let (dropref, dropped) = awaitdrop::awaitdrop();
685 let (inner, mut incoming) = self.connect_inner(None).await?;
686
687 let rt = inner.runtime.clone();
688
689 let inner = Arc::new(ArcSwap::new(inner.into()));
690
691 let session = Session {
692 _dropref: dropref,
693 inner: inner.clone(),
694 };
695
696 incoming.session = Some(session.clone());
698
699 rt.spawn(future::select(
700 accept_incoming(incoming, inner).boxed(),
701 dropped.wait(),
702 ));
703
704 Ok(session)
705 }
706
707 pub(crate) fn get_or_create_tls_config(&self) -> rustls::ClientConfig {
708 if let Some(tls_config) = &self.tls_config {
710 return tls_config.clone();
711 }
712 let mut root_store = rustls::RootCertStore::empty();
714 let cert_pem = self.ca_cert.as_ref().map_or(CERT_BYTES, |it| it.as_ref());
715 let certs = rustls_pemfile::read_all(&mut io::Cursor::new(cert_pem))
716 .filter_map(|it| match it {
717 Ok(Item::X509Certificate(bs)) => Some(bs),
718 Err(e) => {
719 warn!(error = ?e, "skipping certificate which failed to parse");
720 None
721 }
722 Ok(_) => {
723 warn!("skipping non-x509 certificate");
724 None
725 }
726 })
727 .collect::<Vec<_>>();
728 root_store.add_parsable_certificates(certs);
729
730 rustls::ClientConfig::builder()
731 .with_root_certificates(root_store)
732 .with_no_client_auth()
733 }
734
735 async fn connect_inner(
736 &self,
737 err: impl Into<Option<AcceptError>>,
738 ) -> Result<(SessionInner, IncomingStreams), ConnectError> {
739 let conn = self
740 .connector
741 .connect(
742 self.server_host.clone(),
743 self.server_port,
744 Arc::new(self.get_or_create_tls_config()),
745 err.into(),
746 )
747 .await?;
748
749 let mut heartbeat_config = HeartbeatConfig::default();
750 if let Some(interval) = self.heartbeat_interval {
751 heartbeat_config.interval = Duration::from_nanos(interval as u64);
752 }
753 if let Some(tolerance) = self.heartbeat_tolerance {
754 heartbeat_config.tolerance = Duration::from_nanos(tolerance as u64);
755 }
756 heartbeat_config.handler = self.heartbeat_handler.clone();
757
758 let heartbeat_interval = heartbeat_config.interval.as_nanos() as i64;
760 let heartbeat_tolerance = heartbeat_config.tolerance.as_nanos() as i64;
761
762 let mut raw = RawSession::start(conn, heartbeat_config, self.handlers.clone())
763 .await
764 .map_err(ConnectError::Start)?;
765
766 let os = match env::consts::OS {
768 "macos" => "darwin",
769 _ => env::consts::OS,
770 };
771
772 let user_agent = self
773 .versions
774 .iter()
775 .map(|(name, version, comments)| {
776 format!(
777 "{}/{}{}",
778 sanitize_ua_string(name),
779 sanitize_ua_string(version),
780 comments
781 .as_ref()
782 .map_or(String::new(), |f| format!(" ({f})"))
783 )
784 })
785 .collect::<Vec<_>>()
786 .join(" ");
787
788 let client_type = self.versions[0].0.clone();
789 let version = self.versions[0].1.clone();
790
791 let resp = raw
792 .auth(
793 self.id.as_deref().unwrap_or_default(),
794 AuthExtra {
795 version,
796 client_type,
797 user_agent,
798 auth_token: self.authtoken.clone().unwrap_or_default(),
799 metadata: self.metadata.clone().unwrap_or_default(),
800 os: os.into(),
801 arch: std::env::consts::ARCH.into(),
802 heartbeat_interval,
803 heartbeat_tolerance,
804 restart_unsupported_error: self
805 .handlers
806 .on_restart
807 .is_none()
808 .then_some(NOT_IMPLEMENTED.into())
809 .or(Some("".into())),
810 stop_unsupported_error: self
811 .handlers
812 .on_stop
813 .is_none()
814 .then_some(NOT_IMPLEMENTED.into())
815 .or(Some("".into())),
816 update_unsupported_error: self
817 .handlers
818 .on_update
819 .is_none()
820 .then_some(NOT_IMPLEMENTED.into())
821 .or(Some("".into())),
822 cookie: self.cookie.clone().unwrap_or_default(),
823 ..Default::default()
824 },
825 )
826 .await
827 .map_err(ConnectError::Auth)?;
828
829 let (client, incoming) = raw.split();
830
831 let builder = SessionBuilder {
832 cookie: resp.extra.cookie,
833 id: resp.client_id.into(),
834 ..self.clone()
835 };
836
837 Ok((
838 SessionInner {
839 runtime: Handle::current(),
840 client: client.into(),
841 tunnels: Default::default(),
842 closed: Default::default(),
843 builder,
844 },
845 incoming,
846 ))
847 }
848}
849
850impl Session {
851 pub fn builder() -> SessionBuilder {
853 SessionBuilder::default()
854 }
855
856 pub fn http_endpoint(&self) -> HttpTunnelBuilder {
860 self.clone().into()
861 }
862
863 pub fn tcp_endpoint(&self) -> TcpTunnelBuilder {
867 self.clone().into()
868 }
869
870 pub fn tls_endpoint(&self) -> TlsTunnelBuilder {
874 self.clone().into()
875 }
876
877 pub fn labeled_tunnel(&self) -> LabeledTunnelBuilder {
881 self.clone().into()
882 }
883
884 pub fn id(&self) -> String {
886 self.inner
887 .load()
888 .builder
889 .id
890 .as_ref()
891 .expect("Session ID not set")
892 .clone()
893 }
894
895 pub(crate) async fn start_tunnel<C>(&self, tunnel_cfg: C) -> Result<TunnelInner, RpcError>
897 where
898 C: TunnelConfig,
899 {
900 let inner = self.inner.load();
901 let mut client = inner.client.lock().await;
902
903 let (tx, rx) = channel(64);
905
906 let proto = tunnel_cfg.proto();
907 let opts = tunnel_cfg.opts();
908 let mut extra = tunnel_cfg.extra();
909 let labels = tunnel_cfg.labels();
910 let forwards_to = tunnel_cfg.forwards_to();
911 let forwards_proto = tunnel_cfg.forwards_proto();
912 let verify_upstream_tls = tunnel_cfg.verify_upstream_tls();
913
914 let (tunnel, bound) = if tunnel_cfg.proto() != "" {
916 let resp = client
917 .listen(
918 &proto,
919 opts.clone().unwrap(), extra.clone(),
921 "",
922 &forwards_to,
923 &forwards_proto,
924 )
925 .await?;
926
927 extra.token = resp.extra.token;
928 let info = TunnelInnerInfo {
929 id: resp.client_id,
930 proto: resp.proto.clone(),
931 url: resp.url,
932 labels: HashMap::new(),
933 forwards_to: tunnel_cfg.forwards_to(),
934 metadata: extra.metadata.clone(),
935 };
936
937 (
938 TunnelInner {
939 info,
940 session: self.clone(),
941 incoming: rx.into(),
942 },
943 BoundTunnel {
944 proto: resp.proto,
945 opts: resp.bind_opts.into(),
946 extra,
947 labels,
948 forwards_to,
949 forwards_proto,
950 verify_upstream_tls,
951 tx,
952 },
953 )
954 } else {
955 let resp = client
957 .listen_label(
958 labels.clone(),
959 &extra.metadata,
960 &forwards_to,
961 &forwards_proto,
962 )
963 .await?;
964
965 let info = TunnelInnerInfo {
966 id: resp.id,
967 proto: Default::default(),
968 url: Default::default(),
969 labels: tunnel_cfg.labels(),
970 forwards_to: tunnel_cfg.forwards_to(),
971 metadata: extra.metadata.clone(),
972 };
973
974 (
975 TunnelInner {
976 info,
977 session: self.clone(),
978 incoming: rx.into(),
979 },
980 BoundTunnel {
981 extra,
982 proto: Default::default(),
983 opts: Default::default(),
984 forwards_to,
985 forwards_proto,
986 verify_upstream_tls,
987 labels,
988 tx,
989 },
990 )
991 };
992
993 let mut tunnels = inner.tunnels.write().await;
994 tunnels.insert(tunnel.info.id.clone(), bound);
995
996 Ok(tunnel)
997 }
998
999 pub(crate) async fn close_tunnel_with_error(&self, id: impl AsRef<str>, err: AcceptError) {
1002 let id = id.as_ref();
1003 let inner = self.inner.load();
1004 if let Some(tun) = inner.tunnels.write().await.remove(id) {
1005 let _ = tun.tx.send(Err(err)).await;
1006 };
1007 }
1008
1009 pub async fn close_tunnel(&self, id: impl AsRef<str>) -> Result<(), RpcError> {
1011 let id = id.as_ref();
1012 let inner = self.inner.load();
1013 inner.client.lock().await.unlisten(id).await?;
1014 inner.tunnels.write().await.remove(id);
1015 Ok(())
1016 }
1017
1018 pub(crate) fn runtime(&self) -> Handle {
1019 self.inner.load().runtime.clone()
1020 }
1021
1022 pub async fn close(&mut self) -> Result<(), RpcError> {
1024 let inner = self.inner.load();
1025 let res = inner.client.lock().await.close().await;
1026 inner.closed.store(true, Ordering::SeqCst);
1027 res
1028 }
1029}
1030
1031pub(crate) fn host_certs_tls_config() -> Result<rustls::ClientConfig, &'static io::Error> {
1032 static ROOT_STORE: Lazy<Result<RootCertStore, io::Error>> = Lazy::new(|| {
1034 let der_certs = rustls_native_certs::load_native_certs()?
1035 .into_iter()
1036 .collect::<Vec<_>>();
1037 let mut root_store = RootCertStore::empty();
1038 root_store.add_parsable_certificates(der_certs);
1039 Ok(root_store)
1040 });
1041
1042 let root_store = ROOT_STORE.as_ref()?;
1043 Ok(rustls::ClientConfig::builder()
1044 .with_root_certificates(root_store.clone())
1045 .with_no_client_auth())
1046}
1047
1048async fn accept_one(
1049 incoming: &mut IncomingStreams,
1050 inner: &ArcSwap<SessionInner>,
1051) -> Result<(), AcceptError> {
1052 let conn = match incoming.accept().await {
1053 Ok(conn) => conn,
1054 Err(RawAcceptError::Transport(error)) => return Err(error.into()),
1057 Err(error) => {
1061 warn!(?error, "protocol error when accepting tunnel connection");
1062 return Ok(());
1063 }
1064 };
1065 let id = conn.header.id.clone();
1066 let remote_addr = conn.header.client_addr.parse().unwrap_or_else(|error| {
1067 warn!(
1068 client_addr = conn.header.client_addr,
1069 %error,
1070 "invalid remote addr for tunnel connection",
1071 );
1072 "0.0.0.0:0".parse().unwrap()
1073 });
1074 let inner = inner.load();
1075 let guard = inner.tunnels.read().await;
1076 let res = if let Some(tun) = guard.get(&id) {
1077 let mut header = conn.header;
1078 let app_protocol = Some(tun.forwards_proto.to_string()).filter(|s| !s.is_empty());
1079 let verify_upstream_tls = tun.verify_upstream_tls;
1080 if let Some(BindOpts::Tls(opts)) = &tun.opts {
1085 header.passthrough_tls = opts.tls_termination.is_none();
1086 }
1087 let proxy_proto = if let Some(
1088 BindOpts::Tls(TlsEndpoint { proxy_proto, .. })
1089 | BindOpts::Http(HttpEndpoint { proxy_proto, .. })
1090 | BindOpts::Tcp(TcpEndpoint { proxy_proto, .. }),
1091 ) = tun.opts
1092 {
1093 proxy_proto
1094 } else {
1095 ProxyProto::None
1096 };
1097 tun.tx
1098 .send(Ok(ConnInner {
1099 info: crate::conn::Info {
1100 app_protocol,
1101 verify_upstream_tls,
1102 remote_addr,
1103 header,
1104 proxy_proto,
1105 },
1106 stream: conn.stream,
1107 }))
1108 .await
1109 } else {
1110 Ok(())
1111 };
1112 drop(guard);
1113 if res.is_err() {
1114 RwLock::write(&inner.tunnels).await.remove(&id);
1115 }
1116 Ok(())
1117}
1118
/// Establishes a new connection for the session and re-binds every tunnel from the
/// previous connection on it.
///
/// `err` is the error that triggered the reconnect; it is forwarded to the connector.
/// On success the new connection state is swapped into `inner`. Its incoming-stream
/// source is returned so the accept loop can resume.
///
/// # Errors
///
/// * [`ConnectError::Canceled`] if the session has been closed; no connection is attempted.
/// * Any error from connecting, starting or authenticating the new connection.
/// * [`ConnectError::Rebind`] if a tunnel fails to re-bind. `inner` is left unchanged
///   in that case, because the swap only happens at the very end.
async fn try_reconnect(
    inner: Arc<ArcSwap<SessionInner>>,
    err: impl Into<Option<AcceptError>>,
) -> Result<IncomingStreams, ConnectError> {
    let old_inner = inner.load();
    if old_inner.closed.load(Ordering::SeqCst) {
        return Err(ConnectError::Canceled);
    }
    // The old builder carries the session ID and cookie from the previous auth.
    let (new_inner, new_incoming) = old_inner.builder.connect_inner(err).await?;
    // Lock order: new client, new tunnels (write), old tunnels (read).
    let mut client = new_inner.client.lock().await;
    let mut new_tunnels = new_inner.tunnels.write().await;
    let old_tunnels = old_inner.tunnels.read().await;

    for (id, tun) in old_tunnels.iter() {
        if !tun.proto.is_empty() {
            // Endpoint tunnel: re-bind under its existing ID with the saved options/extra.
            let resp = client
                .listen(
                    &tun.proto,
                    tun.opts.clone().unwrap(),
                    tun.extra.clone(),
                    id,
                    &tun.forwards_to,
                    &tun.forwards_proto,
                )
                .await
                .map_err(ConnectError::Rebind)?;
            debug!(?resp, %id, %tun.proto, ?tun.opts, ?tun.extra, %tun.forwards_to, "rebound tunnel");
            new_tunnels.insert(id.clone(), tun.clone());
        } else {
            // Labeled tunnel: bind again by labels. The server may return a new ID, and
            // if so the tunnel is registered under it.
            // NOTE(review): the caller's tunnel handle presumably still holds the old ID,
            // so closing it by that ID would miss this entry — verify.
            let resp = client
                .listen_label(
                    tun.labels.clone(),
                    &tun.extra.metadata,
                    &tun.forwards_to,
                    &tun.forwards_proto,
                )
                .await
                .map_err(ConnectError::Rebind)?;

            if !resp.id.is_empty() {
                new_tunnels.insert(resp.id, tun.clone());
            } else {
                new_tunnels.insert(id.clone(), tun.clone());
            }
        }
    }

    // Release every guard before publishing the new state.
    drop(old_tunnels);
    drop(client);
    drop(new_tunnels);
    inner.store(new_inner.into());

    Ok(new_incoming)
}
1173
1174async fn accept_incoming(mut incoming: IncomingStreams, inner: Arc<ArcSwap<SessionInner>>) {
1175 let error: AcceptError = loop {
1176 if let Err(error) = accept_one(&mut incoming, &inner).await {
1177 debug!(%error, "failed to accept stream, attempting reconnect");
1178 let error = parking_lot::Mutex::new(Some(error));
1186 let reconnect = RetryIf::spawn(
1187 ExponentialBackoff::from_millis(50),
1188 || try_reconnect(inner.clone(), error.lock().clone()).map_err(Arc::new),
1189 |err: &Arc<ConnectError>| {
1190 if let ConnectError::Canceled = **err {
1191 false
1192 } else {
1193 *error.lock() = Some(AcceptError::Reconnect(err.clone()));
1194 true
1195 }
1196 },
1197 );
1198 incoming = match reconnect.await {
1199 Ok(incoming) => incoming,
1200 Err(error) => {
1201 debug!(%error, "reconnect failed, giving up");
1202 break AcceptError::Reconnect(error);
1203 }
1204 };
1205 }
1206 };
1207 for (_id, tun) in inner.load().tunnels.write().await.drain() {
1208 let _ = tun.tx.send(Err(error.clone())).await;
1209 }
1210}
1211
#[cfg(test)]
mod test {
    use super::*;

    /// `/` becomes `-`; every other non-token character becomes `#`.
    #[test]
    fn test_sanitize_ua() {
        let cases = [
            ("library/official/rust", "library-official-rust"),
            ("something@really☺weird", "something#really#weird"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_ua_string(input), expected, "input: {input:?}");
        }
    }
}