1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 convert::From,
5 str::FromStr,
6};
7
8use bytes::Bytes;
9use thiserror::Error;
10use url::Url;
11
12use super::{
13 common::ProxyProto,
14 Policy,
15};
16#[allow(unused_imports)]
18use crate::config::{
19 ForwarderBuilder,
20 TunnelBuilder,
21};
22use crate::{
23 config::{
24 common::{
25 default_forwards_to,
26 Binding,
27 CommonOpts,
28 TunnelConfig,
29 },
30 headers::Headers,
31 oauth::OauthOptions,
32 oidc::OidcOptions,
33 webhook_verification::WebhookVerification,
34 },
35 internals::proto::{
36 BasicAuth,
37 BasicAuthCredential,
38 BindExtra,
39 BindOpts,
40 CircuitBreaker,
41 Compression,
42 HttpEndpoint,
43 UserAgentFilter,
44 WebsocketTcpConverter,
45 },
46 tunnel::HttpTunnel,
47 Session,
48};
49
50#[derive(Debug, Clone, Error)]
52#[error("invalid scheme string: {}", .0)]
53pub struct InvalidSchemeString(String);
54
/// The URL scheme of an HTTP endpoint.
///
/// [`Scheme::HTTPS`] is the default. Values can also be obtained from strings
/// through the [`FromStr`] implementation, which ignores letter case.
///
/// `Debug` and `Copy` are derived so the type can be used in assertions,
/// formatted in diagnostics, and passed around by value like any other
/// fieldless enum.
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq)]
pub enum Scheme {
    /// The `http` URL scheme.
    HTTP,
    /// The `https` URL scheme (the default).
    #[default]
    HTTPS,
}
66
67impl FromStr for Scheme {
68 type Err = InvalidSchemeString;
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 use Scheme::*;
71 Ok(match s.to_uppercase().as_str() {
72 "HTTP" => HTTP,
73 "HTTPS" => HTTPS,
74 _ => return Err(InvalidSchemeString(s.into())),
75 })
76 }
77}
78
/// User-agent allow/deny entries collected while configuring an HTTP endpoint.
///
/// Convertible into the protocol-level [`UserAgentFilter`].
#[derive(Default, Clone)]
pub(crate) struct UaFilter {
    /// Entries added through [`UaFilter::allow`], in insertion order.
    pub(crate) allow: Vec<String>,
    /// Entries added through [`UaFilter::deny`], in insertion order.
    pub(crate) deny: Vec<String>,
}

impl UaFilter {
    /// Appends `allow` to the list of allowed user-agent entries.
    pub(crate) fn allow(&mut self, allow: impl Into<String>) {
        let entry: String = allow.into();
        self.allow.push(entry);
    }

    /// Appends `deny` to the list of denied user-agent entries.
    pub(crate) fn deny(&mut self, deny: impl Into<String>) {
        let entry: String = deny.into();
        self.deny.push(entry);
    }
}
97
98impl From<UaFilter> for UserAgentFilter {
99 fn from(ua: UaFilter) -> Self {
100 UserAgentFilter {
101 allow: ua.allow,
102 deny: ua.deny,
103 }
104 }
105}
106
107#[derive(Default, Clone)]
109struct HttpOptions {
110 pub(crate) common_opts: CommonOpts,
111 pub(crate) scheme: Scheme,
112 pub(crate) domain: Option<String>,
113 pub(crate) mutual_tlsca: Vec<bytes::Bytes>,
114 pub(crate) compression: bool,
115 pub(crate) websocket_tcp_conversion: bool,
116 pub(crate) circuit_breaker: f64,
117 pub(crate) request_headers: Headers,
118 pub(crate) response_headers: Headers,
119 pub(crate) rewrite_host: bool,
120 pub(crate) basic_auth: Vec<(String, String)>,
121 pub(crate) oauth: Option<OauthOptions>,
122 pub(crate) oidc: Option<OidcOptions>,
123 pub(crate) webhook_verification: Option<WebhookVerification>,
124 pub(crate) user_agent_filter: UaFilter,
126 pub(crate) bindings: Vec<String>,
127}
128
129impl HttpOptions {
130 fn user_agent_filter(&self) -> Option<UserAgentFilter> {
131 (!self.user_agent_filter.allow.is_empty() || !self.user_agent_filter.deny.is_empty())
132 .then_some(self.user_agent_filter.clone().into())
133 }
134}
135
136impl TunnelConfig for HttpOptions {
137 fn forwards_to(&self) -> String {
138 self.common_opts
139 .forwards_to
140 .clone()
141 .unwrap_or(default_forwards_to().into())
142 }
143
144 fn forwards_proto(&self) -> String {
145 self.common_opts.forwards_proto.clone().unwrap_or_default()
146 }
147
148 fn verify_upstream_tls(&self) -> bool {
149 self.common_opts.verify_upstream_tls()
150 }
151
152 fn extra(&self) -> BindExtra {
153 BindExtra {
154 token: Default::default(),
155 ip_policy_ref: Default::default(),
156 metadata: self.common_opts.metadata.clone().unwrap_or_default(),
157 bindings: self.bindings.clone(),
158 pooling_enabled: self.common_opts.pooling_enabled.unwrap_or(false),
159 }
160 }
161 fn proto(&self) -> String {
162 if self.scheme == Scheme::HTTP {
163 return "http".into();
164 }
165 "https".into()
166 }
167 fn opts(&self) -> Option<BindOpts> {
168 let http_endpoint = HttpEndpoint {
169 proxy_proto: self.common_opts.proxy_proto,
170 hostname: self.domain.clone().unwrap_or_default(),
171 compression: self.compression.then_some(Compression {}),
172 circuit_breaker: (self.circuit_breaker != 0f64).then_some(CircuitBreaker {
173 error_threshold: self.circuit_breaker,
174 }),
175 ip_restriction: self.common_opts.ip_restriction(),
176 basic_auth: (!self.basic_auth.is_empty()).then_some(self.basic_auth.as_slice().into()),
177 oauth: self.oauth.clone().map(From::from),
178 oidc: self.oidc.clone().map(From::from),
179 webhook_verification: self.webhook_verification.clone().map(From::from),
180 mutual_tls_ca: (!self.mutual_tlsca.is_empty())
181 .then_some(self.mutual_tlsca.as_slice().into()),
182 request_headers: self
183 .request_headers
184 .has_entries()
185 .then_some(self.request_headers.clone().into()),
186 response_headers: self
187 .response_headers
188 .has_entries()
189 .then_some(self.response_headers.clone().into()),
190 websocket_tcp_converter: self
191 .websocket_tcp_conversion
192 .then_some(WebsocketTcpConverter {}),
193 user_agent_filter: self.user_agent_filter(),
194 traffic_policy: if self.common_opts.traffic_policy.is_some() {
195 self.common_opts.traffic_policy.clone().map(From::from)
196 } else if self.common_opts.policy.is_some() {
197 self.common_opts.policy.clone().map(From::from)
198 } else {
199 None
200 },
201 ..Default::default()
202 };
203
204 Some(BindOpts::Http(http_endpoint))
205 }
206 fn labels(&self) -> HashMap<String, String> {
207 HashMap::new()
208 }
209}
210
211impl From<&[(String, String)]> for BasicAuth {
213 fn from(v: &[(String, String)]) -> Self {
214 BasicAuth {
215 credentials: v.iter().cloned().map(From::from).collect(),
216 }
217 }
218}
219
220impl From<(String, String)> for BasicAuthCredential {
222 fn from(b: (String, String)) -> Self {
223 BasicAuthCredential {
224 username: b.0,
225 cleartext_password: b.1,
226 hashed_password: vec![], }
228 }
229}
230
231impl_builder! {
232 HttpTunnelBuilder, HttpOptions, HttpTunnel, endpoint
236}
237
238impl HttpTunnelBuilder {
239 pub fn allow_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
243 self.options.common_opts.cidr_restrictions.allow(cidr);
244 self
245 }
246 pub fn deny_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
250 self.options.common_opts.cidr_restrictions.deny(cidr);
251 self
252 }
253 pub fn proxy_proto(&mut self, proxy_proto: ProxyProto) -> &mut Self {
255 self.options.common_opts.proxy_proto = proxy_proto;
256 self
257 }
258 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
262 self.options.common_opts.metadata = Some(metadata.into());
263 self
264 }
265
266 pub fn binding(&mut self, binding: impl Into<String>) -> &mut Self {
296 if !self.options.bindings.is_empty() {
297 panic!("binding() can only be called once");
298 }
299 let binding_str = binding.into();
300 if let Err(e) = Binding::validate(&binding_str) {
301 panic!("{}", e);
302 }
303 self.options.bindings.push(binding_str);
304 self
305 }
306 pub fn forwards_to(&mut self, forwards_to: impl Into<String>) -> &mut Self {
315 self.options.common_opts.forwards_to = Some(forwards_to.into());
316 self
317 }
318
319 pub fn app_protocol(&mut self, app_protocol: impl Into<String>) -> &mut Self {
321 self.options.common_opts.forwards_proto = Some(app_protocol.into());
322 self
323 }
324
325 pub fn verify_upstream_tls(&mut self, verify_upstream_tls: bool) -> &mut Self {
327 self.options
328 .common_opts
329 .set_verify_upstream_tls(verify_upstream_tls);
330 self
331 }
332
333 pub fn scheme(&mut self, scheme: Scheme) -> &mut Self {
335 self.options.scheme = scheme;
336 self
337 }
338
339 pub fn domain(&mut self, domain: impl Into<String>) -> &mut Self {
343 self.options.domain = Some(domain.into());
344 self
345 }
346 pub fn mutual_tlsca(&mut self, mutual_tlsca: Bytes) -> &mut Self {
353 self.options.mutual_tlsca.push(mutual_tlsca);
354 self
355 }
356 pub fn compression(&mut self) -> &mut Self {
360 self.options.compression = true;
361 self
362 }
363 pub fn websocket_tcp_conversion(&mut self) -> &mut Self {
367 self.options.websocket_tcp_conversion = true;
368 self
369 }
370 pub fn circuit_breaker(&mut self, circuit_breaker: f64) -> &mut Self {
375 self.options.circuit_breaker = circuit_breaker;
376 self
377 }
378
379 pub fn host_header_rewrite(&mut self, rewrite: bool) -> &mut Self {
386 self.options.rewrite_host = rewrite;
387 self
388 }
389
390 pub fn request_header(
394 &mut self,
395 name: impl Into<String>,
396 value: impl Into<String>,
397 ) -> &mut Self {
398 self.options.request_headers.add(name, value);
399 self
400 }
401 pub fn response_header(
405 &mut self,
406 name: impl Into<String>,
407 value: impl Into<String>,
408 ) -> &mut Self {
409 self.options.response_headers.add(name, value);
410 self
411 }
412 pub fn remove_request_header(&mut self, name: impl Into<String>) -> &mut Self {
416 self.options.request_headers.remove(name);
417 self
418 }
419 pub fn remove_response_header(&mut self, name: impl Into<String>) -> &mut Self {
423 self.options.response_headers.remove(name);
424 self
425 }
426
427 pub fn basic_auth(
432 &mut self,
433 username: impl Into<String>,
434 password: impl Into<String>,
435 ) -> &mut Self {
436 self.options
437 .basic_auth
438 .push((username.into(), password.into()));
439 self
440 }
441
442 pub fn oauth(&mut self, oauth: impl Borrow<OauthOptions>) -> &mut Self {
446 self.options.oauth = Some(oauth.borrow().to_owned());
447 self
448 }
449
450 pub fn oidc(&mut self, oidc: impl Borrow<OidcOptions>) -> &mut Self {
454 self.options.oidc = Some(oidc.borrow().to_owned());
455 self
456 }
457
458 pub fn webhook_verification(
462 &mut self,
463 provider: impl Into<String>,
464 secret: impl Into<String>,
465 ) -> &mut Self {
466 self.options.webhook_verification = Some(WebhookVerification {
467 provider: provider.into(),
468 secret: secret.into().into(),
469 });
470 self
471 }
472
473 pub fn allow_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
477 self.options.user_agent_filter.allow(regex);
478 self
479 }
480 pub fn deny_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
484 self.options.user_agent_filter.deny(regex);
485 self
486 }
487
488 pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
490 where
491 S: TryInto<Policy>,
492 {
493 self.options.common_opts.policy = Some(s.try_into()?);
494 Ok(self)
495 }
496
497 pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
499 self.options.common_opts.traffic_policy = Some(policy_str.into());
500 self
501 }
502
503 pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
504 self.options.common_opts.for_forwarding_to(to_url);
505 if let Some(host) = to_url.host_str().filter(|_| self.options.rewrite_host) {
506 self.request_header("host", host);
507 }
508 self
509 }
510
511 pub fn pooling_enabled(&mut self, pooling_enabled: impl Into<bool>) -> &mut Self {
513 self.options.common_opts.pooling_enabled = Some(pooling_enabled.into());
514 self
515 }
516}
517
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::policies::test::POLICY_JSON;

    const METADATA: &str = "testmeta";
    const TEST_FORWARD: &str = "testforward";
    const TEST_FORWARD_PROTO: &str = "http2";
    const ALLOW_CIDR: &str = "0.0.0.0/0";
    const DENY_CIDR: &str = "10.1.1.1/32";
    const CA_CERT: &[u8] = b"test ca cert";
    const CA_CERT2: &[u8] = b"test ca cert2";
    const DOMAIN: &str = "test domain";
    const ALLOW_AGENT: &str = r"bar/(\d)+";
    const DENY_AGENT: &str = r"foo/(\d)+";

    /// Returns a builder with no session and default options.
    fn new_builder() -> HttpTunnelBuilder {
        HttpTunnelBuilder {
            session: None,
            options: Default::default(),
        }
    }

    /// Configures every builder option and checks the resulting protocol
    /// values through `tunnel_test`.
    #[test]
    fn test_interface_to_proto() {
        let mut builder = new_builder();
        builder
            .allow_user_agent(ALLOW_AGENT)
            .deny_user_agent(DENY_AGENT)
            .allow_cidr(ALLOW_CIDR)
            .deny_cidr(DENY_CIDR)
            .proxy_proto(ProxyProto::V2)
            .metadata(METADATA)
            .scheme(Scheme::from_str("hTtPs").unwrap())
            .domain(DOMAIN)
            .mutual_tlsca(CA_CERT.into())
            .mutual_tlsca(CA_CERT2.into())
            .compression()
            .websocket_tcp_conversion()
            .circuit_breaker(0.5)
            .request_header("X-Req-Yup", "true")
            .response_header("X-Res-Yup", "true")
            .remove_request_header("X-Req-Nope")
            .remove_response_header("X-Res-Nope")
            .oauth(OauthOptions::new("google"))
            .oauth(
                OauthOptions::new("google")
                    .allow_email("<user>@<domain>")
                    .allow_domain("<domain>")
                    .scope("<scope>"),
            )
            .oidc(OidcOptions::new("<url>", "<id>", "<secret>"))
            .oidc(
                OidcOptions::new("<url>", "<id>", "<secret>")
                    .allow_email("<user>@<domain>")
                    .allow_domain("<domain>")
                    .scope("<scope>"),
            )
            .webhook_verification("twilio", "asdf")
            .basic_auth("ngrok", "online1line")
            .forwards_to(TEST_FORWARD)
            .app_protocol("http2")
            .policy(POLICY_JSON)
            .unwrap();

        tunnel_test(&builder.options);
    }

    /// Asserts that `tunnel_cfg` exposes the values set up by
    /// `test_interface_to_proto`.
    fn tunnel_test<C: TunnelConfig>(tunnel_cfg: C) {
        assert_eq!(TEST_FORWARD, tunnel_cfg.forwards_to());
        assert_eq!(TEST_FORWARD_PROTO, tunnel_cfg.forwards_proto());

        let extra = tunnel_cfg.extra();
        assert_eq!(String::default(), *extra.token);
        assert_eq!(METADATA, extra.metadata);
        assert_eq!(Vec::<String>::new(), extra.bindings);
        assert_eq!(String::default(), extra.ip_policy_ref);

        assert_eq!("https", tunnel_cfg.proto());

        let opts = tunnel_cfg.opts().unwrap();
        assert!(matches!(opts, BindOpts::Http { .. }));
        if let BindOpts::Http(endpoint) = opts {
            assert_eq!(DOMAIN, endpoint.hostname);
            assert_eq!(String::default(), endpoint.subdomain);
            assert!(matches!(endpoint.proxy_proto, ProxyProto::V2));

            let ip_restriction = endpoint.ip_restriction.unwrap();
            assert_eq!(vec![ALLOW_CIDR], ip_restriction.allow_cidrs);
            assert_eq!(vec![DENY_CIDR], ip_restriction.deny_cidrs);

            // Both CA certificates are expected back-to-back in one buffer.
            let mutual_tls = endpoint.mutual_tls_ca.unwrap();
            let expected_ca: Vec<u8> = [CA_CERT, CA_CERT2].concat();
            assert_eq!(expected_ca, mutual_tls.mutual_tls_ca);

            assert!(endpoint.compression.is_some());
            assert!(endpoint.websocket_tcp_converter.is_some());
            assert_eq!(0.5f64, endpoint.circuit_breaker.unwrap().error_threshold);

            let request_headers = endpoint.request_headers.unwrap();
            assert_eq!(vec!["x-req-yup:true"], request_headers.add);
            assert_eq!(vec!["x-req-nope"], request_headers.remove);

            let response_headers = endpoint.response_headers.unwrap();
            assert_eq!(vec!["x-res-yup:true"], response_headers.add);
            assert_eq!(vec!["x-res-nope"], response_headers.remove);

            let webhook = endpoint.webhook_verification.unwrap();
            assert_eq!("twilio", webhook.provider);
            assert_eq!("asdf", *webhook.secret);
            assert!(webhook.sealed_secret.is_empty());

            let creds = endpoint.basic_auth.unwrap().credentials;
            assert_eq!(1, creds.len());
            assert_eq!("ngrok", creds[0].username);
            assert_eq!("online1line", creds[0].cleartext_password);
            assert!(creds[0].hashed_password.is_empty());

            // The second `oauth(..)` call replaced the first one.
            let oauth = endpoint.oauth.unwrap();
            assert_eq!("google", oauth.provider);
            assert_eq!(vec!["<user>@<domain>"], oauth.allow_emails);
            assert_eq!(vec!["<domain>"], oauth.allow_domains);
            assert_eq!(vec!["<scope>"], oauth.scopes);
            assert_eq!(String::default(), oauth.client_id);
            assert_eq!(String::default(), *oauth.client_secret);
            assert!(oauth.sealed_client_secret.is_empty());

            // Likewise for the second `oidc(..)` call.
            let oidc = endpoint.oidc.unwrap();
            assert_eq!("<url>", oidc.issuer_url);
            assert_eq!(vec!["<user>@<domain>"], oidc.allow_emails);
            assert_eq!(vec!["<domain>"], oidc.allow_domains);
            assert_eq!(vec!["<scope>"], oidc.scopes);
            assert_eq!("<id>", oidc.client_id);
            assert_eq!("<secret>", *oidc.client_secret);
            assert!(oidc.sealed_client_secret.is_empty());

            let user_agent_filter = endpoint.user_agent_filter.unwrap();
            assert_eq!(vec![ALLOW_AGENT], user_agent_filter.allow);
            assert_eq!(vec![DENY_AGENT], user_agent_filter.deny);
        }

        assert_eq!(HashMap::new(), tunnel_cfg.labels());
    }

    /// Each accepted binding value is stored verbatim, whether given as a
    /// string or as a `Binding` variant.
    #[test]
    fn test_binding_valid_values() {
        for value in ["public", "internal", "kubernetes"] {
            let mut builder = new_builder();
            builder.binding(value);
            assert_eq!(vec![value], builder.options.bindings);
        }

        let mut builder = new_builder();
        builder.binding(Binding::Internal);
        assert_eq!(vec!["internal"], builder.options.bindings);
    }

    #[test]
    #[should_panic(expected = "Invalid binding value")]
    fn test_binding_invalid_value() {
        let mut builder = new_builder();
        builder.binding("invalid");
    }

    #[test]
    #[should_panic(expected = "binding() can only be called once")]
    fn test_binding_called_twice() {
        let mut builder = new_builder();
        builder.binding("public");
        builder.binding("internal");
    }
}