1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 convert::From,
5 str::FromStr,
6};
7
8use bytes::Bytes;
9use thiserror::Error;
10use url::Url;
11
12use super::{
13 common::ProxyProto,
14 Policy,
15};
16#[allow(unused_imports)]
18use crate::config::{
19 ForwarderBuilder,
20 TunnelBuilder,
21};
22use crate::{
23 config::{
24 common::{
25 default_forwards_to,
26 CommonOpts,
27 TunnelConfig,
28 },
29 headers::Headers,
30 oauth::OauthOptions,
31 oidc::OidcOptions,
32 webhook_verification::WebhookVerification,
33 },
34 internals::proto::{
35 BasicAuth,
36 BasicAuthCredential,
37 BindExtra,
38 BindOpts,
39 CircuitBreaker,
40 Compression,
41 HttpEndpoint,
42 UserAgentFilter,
43 WebsocketTcpConverter,
44 },
45 tunnel::HttpTunnel,
46 Session,
47};
48
49#[derive(Debug, Clone, Error)]
51#[error("invalid scheme string: {}", .0)]
52pub struct InvalidSchemeString(String);
53
/// The URL scheme an HTTP endpoint is served on.
///
/// Can be parsed case-insensitively from `"http"` / `"https"` via `FromStr`, and defaults to
/// [`Scheme::HTTPS`]. `Debug`, `Copy` and `Hash` are derived so the value can be printed,
/// compared in `assert_eq!`, passed around by value and used as a map/set key.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Scheme {
    /// Plain-text HTTP.
    HTTP,
    /// HTTP over TLS. This is the default scheme.
    #[default]
    HTTPS,
}
65
66impl FromStr for Scheme {
67 type Err = InvalidSchemeString;
68 fn from_str(s: &str) -> Result<Self, Self::Err> {
69 use Scheme::*;
70 Ok(match s.to_uppercase().as_str() {
71 "HTTP" => HTTP,
72 "HTTPS" => HTTPS,
73 _ => return Err(InvalidSchemeString(s.into())),
74 })
75 }
76}
77
/// User-agent rules collected by the tunnel builder before being converted into the protobuf
/// `UserAgentFilter`.
#[derive(Clone, Default)]
pub(crate) struct UaFilter {
    /// Patterns recorded through `allow`, in insertion order.
    pub(crate) allow: Vec<String>,
    /// Patterns recorded through `deny`, in insertion order.
    pub(crate) deny: Vec<String>,
}

impl UaFilter {
    /// Appends one pattern to the allow list.
    pub(crate) fn allow(&mut self, allow: impl Into<String>) {
        let pattern: String = allow.into();
        self.allow.push(pattern);
    }

    /// Appends one pattern to the deny list.
    pub(crate) fn deny(&mut self, deny: impl Into<String>) {
        let pattern: String = deny.into();
        self.deny.push(pattern);
    }
}
96
97impl From<UaFilter> for UserAgentFilter {
98 fn from(ua: UaFilter) -> Self {
99 UserAgentFilter {
100 allow: ua.allow,
101 deny: ua.deny,
102 }
103 }
104}
105
106#[derive(Default, Clone)]
108struct HttpOptions {
109 pub(crate) common_opts: CommonOpts,
110 pub(crate) scheme: Scheme,
111 pub(crate) domain: Option<String>,
112 pub(crate) mutual_tlsca: Vec<bytes::Bytes>,
113 pub(crate) compression: bool,
114 pub(crate) websocket_tcp_conversion: bool,
115 pub(crate) circuit_breaker: f64,
116 pub(crate) request_headers: Headers,
117 pub(crate) response_headers: Headers,
118 pub(crate) rewrite_host: bool,
119 pub(crate) basic_auth: Vec<(String, String)>,
120 pub(crate) oauth: Option<OauthOptions>,
121 pub(crate) oidc: Option<OidcOptions>,
122 pub(crate) webhook_verification: Option<WebhookVerification>,
123 pub(crate) user_agent_filter: UaFilter,
125 pub(crate) bindings: Vec<String>,
126}
127
128impl HttpOptions {
129 fn user_agent_filter(&self) -> Option<UserAgentFilter> {
130 (!self.user_agent_filter.allow.is_empty() || !self.user_agent_filter.deny.is_empty())
131 .then_some(self.user_agent_filter.clone().into())
132 }
133}
134
135impl TunnelConfig for HttpOptions {
136 fn forwards_to(&self) -> String {
137 self.common_opts
138 .forwards_to
139 .clone()
140 .unwrap_or(default_forwards_to().into())
141 }
142
143 fn forwards_proto(&self) -> String {
144 self.common_opts.forwards_proto.clone().unwrap_or_default()
145 }
146
147 fn verify_upstream_tls(&self) -> bool {
148 self.common_opts.verify_upstream_tls()
149 }
150
151 fn extra(&self) -> BindExtra {
152 BindExtra {
153 token: Default::default(),
154 ip_policy_ref: Default::default(),
155 metadata: self.common_opts.metadata.clone().unwrap_or_default(),
156 bindings: self.bindings.clone(),
157 pooling_enabled: self.common_opts.pooling_enabled.unwrap_or(false),
158 }
159 }
160 fn proto(&self) -> String {
161 if self.scheme == Scheme::HTTP {
162 return "http".into();
163 }
164 "https".into()
165 }
166 fn opts(&self) -> Option<BindOpts> {
167 let http_endpoint = HttpEndpoint {
168 proxy_proto: self.common_opts.proxy_proto,
169 hostname: self.domain.clone().unwrap_or_default(),
170 compression: self.compression.then_some(Compression {}),
171 circuit_breaker: (self.circuit_breaker != 0f64).then_some(CircuitBreaker {
172 error_threshold: self.circuit_breaker,
173 }),
174 ip_restriction: self.common_opts.ip_restriction(),
175 basic_auth: (!self.basic_auth.is_empty()).then_some(self.basic_auth.as_slice().into()),
176 oauth: self.oauth.clone().map(From::from),
177 oidc: self.oidc.clone().map(From::from),
178 webhook_verification: self.webhook_verification.clone().map(From::from),
179 mutual_tls_ca: (!self.mutual_tlsca.is_empty())
180 .then_some(self.mutual_tlsca.as_slice().into()),
181 request_headers: self
182 .request_headers
183 .has_entries()
184 .then_some(self.request_headers.clone().into()),
185 response_headers: self
186 .response_headers
187 .has_entries()
188 .then_some(self.response_headers.clone().into()),
189 websocket_tcp_converter: self
190 .websocket_tcp_conversion
191 .then_some(WebsocketTcpConverter {}),
192 user_agent_filter: self.user_agent_filter(),
193 traffic_policy: if self.common_opts.traffic_policy.is_some() {
194 self.common_opts.traffic_policy.clone().map(From::from)
195 } else if self.common_opts.policy.is_some() {
196 self.common_opts.policy.clone().map(From::from)
197 } else {
198 None
199 },
200 ..Default::default()
201 };
202
203 Some(BindOpts::Http(http_endpoint))
204 }
205 fn labels(&self) -> HashMap<String, String> {
206 HashMap::new()
207 }
208}
209
210impl From<&[(String, String)]> for BasicAuth {
212 fn from(v: &[(String, String)]) -> Self {
213 BasicAuth {
214 credentials: v.iter().cloned().map(From::from).collect(),
215 }
216 }
217}
218
219impl From<(String, String)> for BasicAuthCredential {
221 fn from(b: (String, String)) -> Self {
222 BasicAuthCredential {
223 username: b.0,
224 cleartext_password: b.1,
225 hashed_password: vec![], }
227 }
228}
229
// Generates `HttpTunnelBuilder` over `HttpOptions`; presumably it builds an `HttpTunnel`
// (TODO confirm against the macro). The generated builder has `session` and `options` fields
// (the tests construct it with both). The methods in the `impl` below mutate `self.options`.
// NOTE(review): the meaning of the trailing `endpoint` argument is defined by
// `impl_builder!` itself and is not visible here; confirm against the macro definition.
impl_builder! {
    HttpTunnelBuilder, HttpOptions, HttpTunnel, endpoint
}
236
237impl HttpTunnelBuilder {
238 pub fn allow_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
242 self.options.common_opts.cidr_restrictions.allow(cidr);
243 self
244 }
245 pub fn deny_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
249 self.options.common_opts.cidr_restrictions.deny(cidr);
250 self
251 }
252 pub fn proxy_proto(&mut self, proxy_proto: ProxyProto) -> &mut Self {
254 self.options.common_opts.proxy_proto = proxy_proto;
255 self
256 }
257 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
261 self.options.common_opts.metadata = Some(metadata.into());
262 self
263 }
264 pub fn binding(&mut self, binding: impl Into<String>) -> &mut Self {
266 self.options.bindings.push(binding.into());
267 self
268 }
269 pub fn forwards_to(&mut self, forwards_to: impl Into<String>) -> &mut Self {
278 self.options.common_opts.forwards_to = Some(forwards_to.into());
279 self
280 }
281
282 pub fn app_protocol(&mut self, app_protocol: impl Into<String>) -> &mut Self {
284 self.options.common_opts.forwards_proto = Some(app_protocol.into());
285 self
286 }
287
288 pub fn verify_upstream_tls(&mut self, verify_upstream_tls: bool) -> &mut Self {
290 self.options
291 .common_opts
292 .set_verify_upstream_tls(verify_upstream_tls);
293 self
294 }
295
296 pub fn scheme(&mut self, scheme: Scheme) -> &mut Self {
298 self.options.scheme = scheme;
299 self
300 }
301
302 pub fn domain(&mut self, domain: impl Into<String>) -> &mut Self {
306 self.options.domain = Some(domain.into());
307 self
308 }
309 pub fn mutual_tlsca(&mut self, mutual_tlsca: Bytes) -> &mut Self {
316 self.options.mutual_tlsca.push(mutual_tlsca);
317 self
318 }
319 pub fn compression(&mut self) -> &mut Self {
323 self.options.compression = true;
324 self
325 }
326 pub fn websocket_tcp_conversion(&mut self) -> &mut Self {
330 self.options.websocket_tcp_conversion = true;
331 self
332 }
333 pub fn circuit_breaker(&mut self, circuit_breaker: f64) -> &mut Self {
338 self.options.circuit_breaker = circuit_breaker;
339 self
340 }
341
342 pub fn host_header_rewrite(&mut self, rewrite: bool) -> &mut Self {
349 self.options.rewrite_host = rewrite;
350 self
351 }
352
353 pub fn request_header(
357 &mut self,
358 name: impl Into<String>,
359 value: impl Into<String>,
360 ) -> &mut Self {
361 self.options.request_headers.add(name, value);
362 self
363 }
364 pub fn response_header(
368 &mut self,
369 name: impl Into<String>,
370 value: impl Into<String>,
371 ) -> &mut Self {
372 self.options.response_headers.add(name, value);
373 self
374 }
375 pub fn remove_request_header(&mut self, name: impl Into<String>) -> &mut Self {
379 self.options.request_headers.remove(name);
380 self
381 }
382 pub fn remove_response_header(&mut self, name: impl Into<String>) -> &mut Self {
386 self.options.response_headers.remove(name);
387 self
388 }
389
390 pub fn basic_auth(
395 &mut self,
396 username: impl Into<String>,
397 password: impl Into<String>,
398 ) -> &mut Self {
399 self.options
400 .basic_auth
401 .push((username.into(), password.into()));
402 self
403 }
404
405 pub fn oauth(&mut self, oauth: impl Borrow<OauthOptions>) -> &mut Self {
409 self.options.oauth = Some(oauth.borrow().to_owned());
410 self
411 }
412
413 pub fn oidc(&mut self, oidc: impl Borrow<OidcOptions>) -> &mut Self {
417 self.options.oidc = Some(oidc.borrow().to_owned());
418 self
419 }
420
421 pub fn webhook_verification(
425 &mut self,
426 provider: impl Into<String>,
427 secret: impl Into<String>,
428 ) -> &mut Self {
429 self.options.webhook_verification = Some(WebhookVerification {
430 provider: provider.into(),
431 secret: secret.into().into(),
432 });
433 self
434 }
435
436 pub fn allow_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
440 self.options.user_agent_filter.allow(regex);
441 self
442 }
443 pub fn deny_user_agent(&mut self, regex: impl Into<String>) -> &mut Self {
447 self.options.user_agent_filter.deny(regex);
448 self
449 }
450
451 pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
453 where
454 S: TryInto<Policy>,
455 {
456 self.options.common_opts.policy = Some(s.try_into()?);
457 Ok(self)
458 }
459
460 pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
462 self.options.common_opts.traffic_policy = Some(policy_str.into());
463 self
464 }
465
466 pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
467 self.options.common_opts.for_forwarding_to(to_url);
468 if let Some(host) = to_url.host_str().filter(|_| self.options.rewrite_host) {
469 self.request_header("host", host);
470 }
471 self
472 }
473
474 pub fn pooling_enabled(&mut self, pooling_enabled: impl Into<bool>) -> &mut Self {
476 self.options.common_opts.pooling_enabled = Some(pooling_enabled.into());
477 self
478 }
479}
480
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::policies::test::POLICY_JSON;
    const BINDING: &str = "public";
    const METADATA: &str = "testmeta";
    const TEST_FORWARD: &str = "testforward";
    const TEST_FORWARD_PROTO: &str = "http2";
    const ALLOW_CIDR: &str = "0.0.0.0/0";
    const DENY_CIDR: &str = "10.1.1.1/32";
    const CA_CERT: &[u8] = "test ca cert".as_bytes();
    const CA_CERT2: &[u8] = "test ca cert2".as_bytes();
    const DOMAIN: &str = "test domain";
    const ALLOW_AGENT: &str = r"bar/(\d)+";
    const DENY_AGENT: &str = r"foo/(\d)+";

    /// A builder with no session and default options.
    fn empty_builder() -> HttpTunnelBuilder {
        HttpTunnelBuilder {
            session: None,
            options: Default::default(),
        }
    }

    #[test]
    fn test_interface_to_proto() {
        tunnel_test(
            &empty_builder()
                .allow_user_agent(ALLOW_AGENT)
                .deny_user_agent(DENY_AGENT)
                .allow_cidr(ALLOW_CIDR)
                .deny_cidr(DENY_CIDR)
                .proxy_proto(ProxyProto::V2)
                .metadata(METADATA)
                .binding(BINDING)
                .scheme(Scheme::from_str("hTtPs").unwrap())
                .domain(DOMAIN)
                .mutual_tlsca(CA_CERT.into())
                .mutual_tlsca(CA_CERT2.into())
                .compression()
                .websocket_tcp_conversion()
                .circuit_breaker(0.5)
                .request_header("X-Req-Yup", "true")
                .response_header("X-Res-Yup", "true")
                .remove_request_header("X-Req-Nope")
                .remove_response_header("X-Res-Nope")
                .oauth(OauthOptions::new("google"))
                .oauth(
                    OauthOptions::new("google")
                        .allow_email("<user>@<domain>")
                        .allow_domain("<domain>")
                        .scope("<scope>"),
                )
                .oidc(OidcOptions::new("<url>", "<id>", "<secret>"))
                .oidc(
                    OidcOptions::new("<url>", "<id>", "<secret>")
                        .allow_email("<user>@<domain>")
                        .allow_domain("<domain>")
                        .scope("<scope>"),
                )
                .webhook_verification("twilio", "asdf")
                .basic_auth("ngrok", "online1line")
                .forwards_to(TEST_FORWARD)
                .app_protocol(TEST_FORWARD_PROTO)
                .policy(POLICY_JSON)
                .unwrap()
                .options,
        );
    }

    /// Scheme parsing accepts both schemes in any ASCII case and rejects everything else.
    #[test]
    fn test_scheme_from_str() {
        assert!(Scheme::from_str("http").unwrap() == Scheme::HTTP);
        assert!(Scheme::from_str("HTTP").unwrap() == Scheme::HTTP);
        assert!(Scheme::from_str("HTTPS").unwrap() == Scheme::HTTPS);
        assert!(Scheme::from_str("ftp").is_err());
        assert!(Scheme::from_str("").is_err());
    }

    /// The plain-HTTP scheme is reported as "http".
    #[test]
    fn test_http_scheme_proto() {
        let mut builder = empty_builder();
        builder.scheme(Scheme::HTTP);
        assert_eq!("http", builder.options.proto());
    }

    /// With nothing configured, every optional endpoint module is omitted.
    #[test]
    fn test_defaults_omit_optional_sections() {
        let builder = empty_builder();
        assert_eq!("https", builder.options.proto());
        match builder.options.opts() {
            Some(BindOpts::Http(endpoint)) => {
                assert!(endpoint.compression.is_none());
                assert!(endpoint.circuit_breaker.is_none());
                assert!(endpoint.basic_auth.is_none());
                assert!(endpoint.oauth.is_none());
                assert!(endpoint.oidc.is_none());
                assert!(endpoint.webhook_verification.is_none());
                assert!(endpoint.mutual_tls_ca.is_none());
                assert!(endpoint.request_headers.is_none());
                assert!(endpoint.response_headers.is_none());
                assert!(endpoint.websocket_tcp_converter.is_none());
                assert!(endpoint.user_agent_filter.is_none());
                assert!(endpoint.traffic_policy.is_none());
            }
            _ => panic!("expected HTTP bind options"),
        }
    }

    fn tunnel_test<C>(tunnel_cfg: C)
    where
        C: TunnelConfig,
    {
        assert_eq!(TEST_FORWARD, tunnel_cfg.forwards_to());
        assert_eq!(TEST_FORWARD_PROTO, tunnel_cfg.forwards_proto());
        let extra = tunnel_cfg.extra();
        assert_eq!(String::default(), *extra.token);
        assert_eq!(METADATA, extra.metadata);
        assert_eq!(Vec::from([BINDING]), extra.bindings);
        assert_eq!(String::default(), extra.ip_policy_ref);

        assert_eq!("https", tunnel_cfg.proto());

        let endpoint = match tunnel_cfg.opts() {
            Some(BindOpts::Http(endpoint)) => endpoint,
            _ => panic!("expected HTTP bind options"),
        };

        assert_eq!(DOMAIN, endpoint.hostname);
        assert_eq!(String::default(), endpoint.subdomain);
        assert!(matches!(endpoint.proxy_proto, ProxyProto::V2));

        let ip_restriction = endpoint.ip_restriction.unwrap();
        assert_eq!(Vec::from([ALLOW_CIDR]), ip_restriction.allow_cidrs);
        assert_eq!(Vec::from([DENY_CIDR]), ip_restriction.deny_cidrs);

        let mutual_tls = endpoint.mutual_tls_ca.unwrap();
        let mut agg = CA_CERT.to_vec();
        agg.extend(CA_CERT2.to_vec());
        assert_eq!(agg, mutual_tls.mutual_tls_ca);

        assert!(endpoint.compression.is_some());
        assert!(endpoint.websocket_tcp_converter.is_some());
        assert_eq!(0.5f64, endpoint.circuit_breaker.unwrap().error_threshold);

        let request_headers = endpoint.request_headers.unwrap();
        assert_eq!(["x-req-yup:true"].to_vec(), request_headers.add);
        assert_eq!(["x-req-nope"].to_vec(), request_headers.remove);

        let response_headers = endpoint.response_headers.unwrap();
        assert_eq!(["x-res-yup:true"].to_vec(), response_headers.add);
        assert_eq!(["x-res-nope"].to_vec(), response_headers.remove);

        let webhook = endpoint.webhook_verification.unwrap();
        assert_eq!("twilio", webhook.provider);
        assert_eq!("asdf", *webhook.secret);
        assert!(webhook.sealed_secret.is_empty());

        let creds = endpoint.basic_auth.unwrap().credentials;
        assert_eq!(1, creds.len());
        assert_eq!("ngrok", creds[0].username);
        assert_eq!("online1line", creds[0].cleartext_password);
        assert!(creds[0].hashed_password.is_empty());

        let oauth = endpoint.oauth.unwrap();
        assert_eq!("google", oauth.provider);
        assert_eq!(["<user>@<domain>"].to_vec(), oauth.allow_emails);
        assert_eq!(["<domain>"].to_vec(), oauth.allow_domains);
        assert_eq!(["<scope>"].to_vec(), oauth.scopes);
        assert_eq!(String::default(), oauth.client_id);
        assert_eq!(String::default(), *oauth.client_secret);
        assert!(oauth.sealed_client_secret.is_empty());

        let oidc = endpoint.oidc.unwrap();
        assert_eq!("<url>", oidc.issuer_url);
        assert_eq!(["<user>@<domain>"].to_vec(), oidc.allow_emails);
        assert_eq!(["<domain>"].to_vec(), oidc.allow_domains);
        assert_eq!(["<scope>"].to_vec(), oidc.scopes);
        assert_eq!("<id>", oidc.client_id);
        assert_eq!("<secret>", *oidc.client_secret);
        assert!(oidc.sealed_client_secret.is_empty());

        let user_agent_filter = endpoint.user_agent_filter.unwrap();
        assert_eq!(Vec::from([ALLOW_AGENT]), user_agent_filter.allow);
        assert_eq!(Vec::from([DENY_AGENT]), user_agent_filter.deny);

        // `.policy(POLICY_JSON)` must reach the endpoint's traffic policy.
        assert!(endpoint.traffic_policy.is_some());

        assert_eq!(HashMap::new(), tunnel_cfg.labels());
    }
}