1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 convert::From,
5 str::FromStr,
6};
7
8use bytes::Bytes;
9use thiserror::Error;
10use url::Url;
11
12use super::{
13 common::ProxyProto,
14 Policy,
15};
16#[allow(unused_imports)]
18use crate::config::{
19 ForwarderBuilder,
20 TunnelBuilder,
21};
22use crate::{
23 config::{
24 common::{
25 default_forwards_to,
26 Binding,
27 CommonOpts,
28 TunnelConfig,
29 },
30 headers::Headers,
31 oauth::OauthOptions,
32 oidc::OidcOptions,
33 webhook_verification::WebhookVerification,
34 },
35 internals::proto::{
36 BasicAuth,
37 BasicAuthCredential,
38 BindExtra,
39 BindOpts,
40 CircuitBreaker,
41 Compression,
42 HttpEndpoint,
43 UserAgentFilter,
44 WebsocketTcpConverter,
45 },
46 tunnel::HttpTunnel,
47 Session,
48};
49
50#[derive(Debug, Clone, Error)]
52#[error("invalid scheme string: {}", .0)]
53pub struct InvalidSchemeString(String);
54
/// The scheme an HTTP endpoint is served on.
///
/// Parsed case-insensitively through `FromStr`; defaults to `Scheme::HTTPS`.
// The all-caps variant names are public API, so renaming them to satisfy
// clippy would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub enum Scheme {
    /// The `http` scheme.
    HTTP,
    /// The `https` scheme. This is the default.
    #[default]
    HTTPS,
}
66
67impl FromStr for Scheme {
68 type Err = InvalidSchemeString;
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 use Scheme::*;
71 Ok(match s.to_uppercase().as_str() {
72 "HTTP" => HTTP,
73 "HTTPS" => HTTPS,
74 _ => return Err(InvalidSchemeString(s.into())),
75 })
76 }
77}
78
/// User-agent patterns gathered by the builder. An empty filter (no allow and
/// no deny entries) is not sent at all.
#[derive(Default, Clone)]
pub(crate) struct UaFilter {
    /// Patterns whose matching user agents are allowed.
    pub(crate) allow: Vec<String>,
    /// Patterns whose matching user agents are denied.
    pub(crate) deny: Vec<String>,
}
88
89impl UaFilter {
90 pub(crate) fn allow(&mut self, allow: impl Into<String>) {
91 self.allow.push(allow.into());
92 }
93 pub(crate) fn deny(&mut self, deny: impl Into<String>) {
94 self.deny.push(deny.into());
95 }
96}
97
98impl From<UaFilter> for UserAgentFilter {
99 fn from(ua: UaFilter) -> Self {
100 UserAgentFilter {
101 allow: ua.allow,
102 deny: ua.deny,
103 }
104 }
105}
106
107#[derive(Default, Clone)]
109struct HttpOptions {
110 pub(crate) common_opts: CommonOpts,
111 pub(crate) scheme: Scheme,
112 pub(crate) domain: Option<String>,
113 pub(crate) mutual_tlsca: Vec<bytes::Bytes>,
114 pub(crate) compression: bool,
115 pub(crate) websocket_tcp_conversion: bool,
116 pub(crate) circuit_breaker: f64,
117 pub(crate) request_headers: Headers,
118 pub(crate) response_headers: Headers,
119 pub(crate) rewrite_host: bool,
120 pub(crate) basic_auth: Vec<(String, String)>,
121 pub(crate) oauth: Option<OauthOptions>,
122 pub(crate) oidc: Option<OidcOptions>,
123 pub(crate) webhook_verification: Option<WebhookVerification>,
124 pub(crate) user_agent_filter: UaFilter,
126 pub(crate) bindings: Vec<String>,
127}
128
129impl HttpOptions {
130 fn user_agent_filter(&self) -> Option<UserAgentFilter> {
131 (!self.user_agent_filter.allow.is_empty() || !self.user_agent_filter.deny.is_empty())
132 .then_some(self.user_agent_filter.clone().into())
133 }
134}
135
136impl TunnelConfig for HttpOptions {
137 fn forwards_to(&self) -> String {
138 self.common_opts
139 .forwards_to
140 .clone()
141 .unwrap_or(default_forwards_to().into())
142 }
143
144 fn forwards_proto(&self) -> String {
145 self.common_opts.forwards_proto.clone().unwrap_or_default()
146 }
147
148 fn verify_upstream_tls(&self) -> bool {
149 self.common_opts.verify_upstream_tls()
150 }
151
152 fn extra(&self) -> BindExtra {
153 BindExtra {
154 token: Default::default(),
155 ip_policy_ref: Default::default(),
156 metadata: self.common_opts.metadata.clone().unwrap_or_default(),
157 bindings: self.bindings.clone(),
158 pooling_enabled: self.common_opts.pooling_enabled.unwrap_or(false),
159 }
160 }
161 fn proto(&self) -> String {
162 if self.scheme == Scheme::HTTP {
163 return "http".into();
164 }
165 "https".into()
166 }
167 fn opts(&self) -> Option<BindOpts> {
168 let http_endpoint = HttpEndpoint {
169 proxy_proto: self.common_opts.proxy_proto,
170 domain: self.domain.clone().unwrap_or_default(),
171 hostname: String::new(),
172 compression: self.compression.then_some(Compression {}),
173 circuit_breaker: (self.circuit_breaker != 0f64).then_some(CircuitBreaker {
174 error_threshold: self.circuit_breaker,
175 }),
176 ip_restriction: self.common_opts.ip_restriction(),
177 basic_auth: (!self.basic_auth.is_empty()).then_some(self.basic_auth.as_slice().into()),
178 oauth: self.oauth.clone().map(From::from),
179 oidc: self.oidc.clone().map(From::from),
180 webhook_verification: self.webhook_verification.clone().map(From::from),
181 mutual_tls_ca: (!self.mutual_tlsca.is_empty())
182 .then_some(self.mutual_tlsca.as_slice().into()),
183 request_headers: self
184 .request_headers
185 .has_entries()
186 .then_some(self.request_headers.clone().into()),
187 response_headers: self
188 .response_headers
189 .has_entries()
190 .then_some(self.response_headers.clone().into()),
191 websocket_tcp_converter: self
192 .websocket_tcp_conversion
193 .then_some(WebsocketTcpConverter {}),
194 user_agent_filter: self.user_agent_filter(),
195 traffic_policy: if self.common_opts.traffic_policy.is_some() {
196 self.common_opts.traffic_policy.clone().map(From::from)
197 } else if self.common_opts.policy.is_some() {
198 self.common_opts.policy.clone().map(From::from)
199 } else {
200 None
201 },
202 ..Default::default()
203 };
204
205 Some(BindOpts::Http(http_endpoint))
206 }
207 fn labels(&self) -> HashMap<String, String> {
208 HashMap::new()
209 }
210}
211
212impl From<&[(String, String)]> for BasicAuth {
214 fn from(v: &[(String, String)]) -> Self {
215 BasicAuth {
216 credentials: v.iter().cloned().map(From::from).collect(),
217 }
218 }
219}
220
221impl From<(String, String)> for BasicAuthCredential {
223 fn from(b: (String, String)) -> Self {
224 BasicAuthCredential {
225 username: b.0,
226 cleartext_password: b.1,
227 hashed_password: vec![], }
229 }
230}
231
232impl_builder! {
233 HttpTunnelBuilder, HttpOptions, HttpTunnel, endpoint
237}
238
239impl HttpTunnelBuilder {
240 pub fn allow_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
244 self.options.common_opts.cidr_restrictions.allow(cidr);
245 self
246 }
247 pub fn deny_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
251 self.options.common_opts.cidr_restrictions.deny(cidr);
252 self
253 }
254 pub fn proxy_proto(&mut self, proxy_proto: ProxyProto) -> &mut Self {
256 self.options.common_opts.proxy_proto = proxy_proto;
257 self
258 }
259 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
263 self.options.common_opts.metadata = Some(metadata.into());
264 self
265 }
266
267 pub fn binding(&mut self, binding: impl Into<String>) -> &mut Self {
298 if !self.options.bindings.is_empty() {
299 panic!("binding() can only be called once");
300 }
301 let binding_str = binding.into();
302 if let Err(e) = Binding::validate(&binding_str) {
303 panic!("{}", e);
304 }
305 self.options.bindings.push(binding_str);
306 self
307 }
308 pub fn forwards_to(&mut self, forwards_to: impl Into<String>) -> &mut Self {
317 self.options.common_opts.forwards_to = Some(forwards_to.into());
318 self
319 }
320
321 pub fn app_protocol(&mut self, app_protocol: impl Into<String>) -> &mut Self {
323 self.options.common_opts.forwards_proto = Some(app_protocol.into());
324 self
325 }
326
327 pub fn verify_upstream_tls(&mut self, verify_upstream_tls: bool) -> &mut Self {
329 self.options
330 .common_opts
331 .set_verify_upstream_tls(verify_upstream_tls);
332 self
333 }
334
335 pub fn scheme(&mut self, scheme: Scheme) -> &mut Self {
337 self.options.scheme = scheme;
338 self
339 }
340
341 pub fn domain(&mut self, domain: impl Into<String>) -> &mut Self {
345 self.options.domain = Some(domain.into());
346 self
347 }
348 pub fn mutual_tlsca(&mut self, mutual_tlsca: Bytes) -> &mut Self {
355 self.options.mutual_tlsca.push(mutual_tlsca);
356 self
357 }
358 pub fn compression(&mut self) -> &mut Self {
362 self.options.compression = true;
363 self
364 }
365 pub fn websocket_tcp_conversion(&mut self) -> &mut Self {
369 self.options.websocket_tcp_conversion = true;
370 self
371 }
372 pub fn circuit_breaker(&mut self, circuit_breaker: f64) -> &mut Self {
377 self.options.circuit_breaker = circuit_breaker;
378 self
379 }
380
381 pub fn host_header_rewrite(&mut self, rewrite: bool) -> &mut Self {
388 self.options.rewrite_host = rewrite;
389 self
390 }
391
392 pub fn request_header(
396 &mut self,
397 name: impl Into<String>,
398 value: impl Into<String>,
399 ) -> &mut Self {
400 self.options.request_headers.add(name, value);
401 self
402 }
403 pub fn response_header(
407 &mut self,
408 name: impl Into<String>,
409 value: impl Into<String>,
410 ) -> &mut Self {
411 self.options.response_headers.add(name, value);
412 self
413 }
414 pub fn remove_request_header(&mut self, name: impl Into<String>) -> &mut Self {
418 self.options.request_headers.remove(name);
419 self
420 }
421 pub fn remove_response_header(&mut self, name: impl Into<String>) -> &mut Self {
425 self.options.response_headers.remove(name);
426 self
427 }
428
429 pub fn basic_auth(
434 &mut self,
435 username: impl Into<String>,
436 password: impl Into<String>,
437 ) -> &mut Self {
438 self.options
439 .basic_auth
440 .push((username.into(), password.into()));
441 self
442 }
443
444 pub fn oauth(&mut self, oauth: impl Borrow<OauthOptions>) -> &mut Self {
448 self.options.oauth = Some(oauth.borrow().to_owned());
449 self
450 }
451
452 pub fn oidc(&mut self, oidc: impl Borrow<OidcOptions>) -> &mut Self {
456 self.options.oidc = Some(oidc.borrow().to_owned());
457 self
458 }
459
460 pub fn webhook_verification(
464 &mut self,
465 provider: impl Into<String>,
466 secret: impl Into<String>,
467 ) -> &mut Self {
468 self.options.webhook_verification = Some(WebhookVerification {
469 provider: provider.into(),
470 secret: secret.into().into(),
471 });
472 self
473 }
474
475 pub fn allow_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
479 self.options.user_agent_filter.allow(regex);
480 self
481 }
482 pub fn deny_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
486 self.options.user_agent_filter.deny(regex);
487 self
488 }
489
490 pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
492 where
493 S: TryInto<Policy>,
494 {
495 self.options.common_opts.policy = Some(s.try_into()?);
496 Ok(self)
497 }
498
499 pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
501 self.options.common_opts.traffic_policy = Some(policy_str.into());
502 self
503 }
504
505 pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
506 self.options.common_opts.for_forwarding_to(to_url);
507 if let Some(host) = to_url.host_str().filter(|_| self.options.rewrite_host) {
508 self.request_header("host", host);
509 }
510 self
511 }
512
513 pub fn pooling_enabled(&mut self, pooling_enabled: impl Into<bool>) -> &mut Self {
515 self.options.common_opts.pooling_enabled = Some(pooling_enabled.into());
516 self
517 }
518}
519
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::policies::test::POLICY_JSON;
    const METADATA: &str = "testmeta";
    const TEST_FORWARD: &str = "testforward";
    const TEST_FORWARD_PROTO: &str = "http2";
    const ALLOW_CIDR: &str = "0.0.0.0/0";
    const DENY_CIDR: &str = "10.1.1.1/32";
    const CA_CERT: &[u8] = "test ca cert".as_bytes();
    const CA_CERT2: &[u8] = "test ca cert2".as_bytes();
    const DOMAIN: &str = "test domain";
    const ALLOW_AGENT: &str = r"bar/(\d)+";
    const DENY_AGENT: &str = r"foo/(\d)+";

    /// A builder with no session and default options.
    fn new_builder() -> HttpTunnelBuilder {
        HttpTunnelBuilder {
            session: None,
            options: Default::default(),
        }
    }

    #[test]
    fn test_interface_to_proto() {
        tunnel_test(
            &new_builder()
                .allow_user_agent(ALLOW_AGENT)
                .deny_user_agent(DENY_AGENT)
                .allow_cidr(ALLOW_CIDR)
                .deny_cidr(DENY_CIDR)
                .proxy_proto(ProxyProto::V2)
                .metadata(METADATA)
                .scheme(Scheme::from_str("hTtPs").unwrap())
                .domain(DOMAIN)
                .mutual_tlsca(CA_CERT.into())
                .mutual_tlsca(CA_CERT2.into())
                .compression()
                .websocket_tcp_conversion()
                .circuit_breaker(0.5)
                .request_header("X-Req-Yup", "true")
                .response_header("X-Res-Yup", "true")
                .remove_request_header("X-Req-Nope")
                .remove_response_header("X-Res-Nope")
                // Set twice on purpose: the second call must replace the first.
                .oauth(OauthOptions::new("google"))
                .oauth(
                    OauthOptions::new("google")
                        .allow_email("<user>@<domain>")
                        .allow_domain("<domain>")
                        .scope("<scope>"),
                )
                .oidc(OidcOptions::new("<url>", "<id>", "<secret>"))
                .oidc(
                    OidcOptions::new("<url>", "<id>", "<secret>")
                        .allow_email("<user>@<domain>")
                        .allow_domain("<domain>")
                        .scope("<scope>"),
                )
                .webhook_verification("twilio", "asdf")
                .basic_auth("ngrok", "online1line")
                .forwards_to(TEST_FORWARD)
                // Use the constant the assertion checks against, not a duplicate literal.
                .app_protocol(TEST_FORWARD_PROTO)
                .policy(POLICY_JSON)
                .unwrap()
                .options,
        );
    }

    fn tunnel_test<C>(tunnel_cfg: C)
    where
        C: TunnelConfig,
    {
        assert_eq!(TEST_FORWARD, tunnel_cfg.forwards_to());
        assert_eq!(TEST_FORWARD_PROTO, tunnel_cfg.forwards_proto());
        let extra = tunnel_cfg.extra();
        assert_eq!(String::default(), *extra.token);
        assert_eq!(METADATA, extra.metadata);
        assert_eq!(Vec::<String>::new(), extra.bindings);
        assert_eq!(String::default(), extra.ip_policy_ref);

        assert_eq!("https", tunnel_cfg.proto());

        let opts = tunnel_cfg.opts().unwrap();
        assert!(matches!(opts, BindOpts::Http { .. }));
        if let BindOpts::Http(endpoint) = opts {
            assert_eq!(DOMAIN, endpoint.domain);
            assert_eq!(String::default(), endpoint.subdomain);
            assert!(matches!(endpoint.proxy_proto, ProxyProto::V2));

            let ip_restriction = endpoint.ip_restriction.unwrap();
            assert_eq!(Vec::from([ALLOW_CIDR]), ip_restriction.allow_cidrs);
            assert_eq!(Vec::from([DENY_CIDR]), ip_restriction.deny_cidrs);

            let mutual_tls = endpoint.mutual_tls_ca.unwrap();
            let mut agg = CA_CERT.to_vec();
            agg.extend(CA_CERT2.to_vec());
            assert_eq!(agg, mutual_tls.mutual_tls_ca);

            assert!(endpoint.compression.is_some());
            assert!(endpoint.websocket_tcp_converter.is_some());
            assert_eq!(0.5f64, endpoint.circuit_breaker.unwrap().error_threshold);

            let request_headers = endpoint.request_headers.unwrap();
            assert_eq!(["x-req-yup:true"].to_vec(), request_headers.add);
            assert_eq!(["x-req-nope"].to_vec(), request_headers.remove);

            let response_headers = endpoint.response_headers.unwrap();
            assert_eq!(["x-res-yup:true"].to_vec(), response_headers.add);
            assert_eq!(["x-res-nope"].to_vec(), response_headers.remove);

            let webhook = endpoint.webhook_verification.unwrap();
            assert_eq!("twilio", webhook.provider);
            assert_eq!("asdf", *webhook.secret);
            assert!(webhook.sealed_secret.is_empty());

            let creds = endpoint.basic_auth.unwrap().credentials;
            assert_eq!(1, creds.len());
            assert_eq!("ngrok", creds[0].username);
            assert_eq!("online1line", creds[0].cleartext_password);
            assert!(creds[0].hashed_password.is_empty());

            let oauth = endpoint.oauth.unwrap();
            assert_eq!("google", oauth.provider);
            assert_eq!(["<user>@<domain>"].to_vec(), oauth.allow_emails);
            assert_eq!(["<domain>"].to_vec(), oauth.allow_domains);
            assert_eq!(["<scope>"].to_vec(), oauth.scopes);
            assert_eq!(String::default(), oauth.client_id);
            assert_eq!(String::default(), *oauth.client_secret);
            assert!(oauth.sealed_client_secret.is_empty());

            let oidc = endpoint.oidc.unwrap();
            assert_eq!("<url>", oidc.issuer_url);
            assert_eq!(["<user>@<domain>"].to_vec(), oidc.allow_emails);
            assert_eq!(["<domain>"].to_vec(), oidc.allow_domains);
            assert_eq!(["<scope>"].to_vec(), oidc.scopes);
            assert_eq!("<id>", oidc.client_id);
            assert_eq!("<secret>", *oidc.client_secret);
            assert!(oidc.sealed_client_secret.is_empty());

            let user_agent_filter = endpoint.user_agent_filter.unwrap();
            assert_eq!(Vec::from([ALLOW_AGENT]), user_agent_filter.allow);
            assert_eq!(Vec::from([DENY_AGENT]), user_agent_filter.deny);
        }

        assert_eq!(HashMap::new(), tunnel_cfg.labels());
    }

    #[test]
    fn test_scheme_from_str() {
        assert!(matches!(Scheme::from_str("http"), Ok(Scheme::HTTP)));
        assert!(matches!(Scheme::from_str("HTTPS"), Ok(Scheme::HTTPS)));
        assert!(Scheme::from_str("ftp").is_err());
        assert!(Scheme::from_str("").is_err());
    }

    #[test]
    fn test_http_scheme_proto() {
        let mut builder = new_builder();
        builder.scheme(Scheme::HTTP);
        assert_eq!("http", builder.options.proto());
    }

    #[test]
    fn test_default_opts_omit_optional_modules() {
        let options = HttpOptions::default();
        assert_eq!("https", options.proto());
        if let Some(BindOpts::Http(endpoint)) = options.opts() {
            assert_eq!(String::default(), endpoint.domain);
            assert!(endpoint.compression.is_none());
            assert!(endpoint.circuit_breaker.is_none());
            assert!(endpoint.basic_auth.is_none());
            assert!(endpoint.mutual_tls_ca.is_none());
            assert!(endpoint.websocket_tcp_converter.is_none());
            assert!(endpoint.user_agent_filter.is_none());
            assert!(endpoint.oauth.is_none());
            assert!(endpoint.oidc.is_none());
            assert!(endpoint.webhook_verification.is_none());
        } else {
            panic!("Expected Http endpoint");
        }
    }

    #[test]
    fn test_binding_valid_values() {
        for value in ["public", "internal", "kubernetes"] {
            let mut builder = new_builder();
            builder.binding(value);
            assert_eq!(vec![value], builder.options.bindings);
        }

        let mut builder = new_builder();
        builder.binding(Binding::Internal);
        assert_eq!(vec!["internal"], builder.options.bindings);
    }

    #[test]
    #[should_panic(expected = "Invalid binding value")]
    fn test_binding_invalid_value() {
        let mut builder = new_builder();
        builder.binding("invalid");
    }

    #[test]
    #[should_panic(expected = "binding() can only be called once")]
    fn test_binding_called_twice() {
        let mut builder = new_builder();
        builder.binding("public");
        builder.binding("internal");
    }

    #[test]
    fn test_binding_with_domain() {
        let mut builder = new_builder();
        builder.binding("internal").domain("foo.internal");

        assert_eq!(vec!["internal"], builder.options.bindings);
        assert_eq!(Some("foo.internal".to_string()), builder.options.domain);

        let extra = builder.options.extra();
        assert_eq!(vec!["internal"], extra.bindings);

        let opts = builder.options.opts().unwrap();
        if let BindOpts::Http(endpoint) = opts {
            assert_eq!("foo.internal", endpoint.domain);
        } else {
            panic!("Expected Http endpoint");
        }
    }
}