1use std::collections::HashMap;
2
3use bytes::Bytes;
4use url::Url;
5
6use super::{
7 common::ProxyProto,
8 Policy,
9};
10#[allow(unused_imports)]
12use crate::config::{
13 ForwarderBuilder,
14 TunnelBuilder,
15};
16use crate::{
17 config::common::{
18 default_forwards_to,
19 Binding,
20 CommonOpts,
21 TunnelConfig,
22 },
23 internals::proto::{
24 self,
25 BindExtra,
26 BindOpts,
27 TlsTermination,
28 },
29 tunnel::TlsTunnel,
30 Session,
31};
32
33#[derive(Default, Clone)]
35struct TlsOptions {
36 pub(crate) common_opts: CommonOpts,
37 pub(crate) domain: Option<String>,
38 pub(crate) mutual_tlsca: Vec<bytes::Bytes>,
39 pub(crate) key_pem: Option<bytes::Bytes>,
40 pub(crate) cert_pem: Option<bytes::Bytes>,
41 pub(crate) bindings: Vec<String>,
42}
43
44impl TunnelConfig for TlsOptions {
45 fn forwards_to(&self) -> String {
46 self.common_opts
47 .forwards_to
48 .clone()
49 .unwrap_or(default_forwards_to().into())
50 }
51
52 fn forwards_proto(&self) -> String {
53 String::new()
55 }
56
57 fn verify_upstream_tls(&self) -> bool {
58 self.common_opts.verify_upstream_tls()
59 }
60
61 fn extra(&self) -> BindExtra {
62 BindExtra {
63 token: Default::default(),
64 ip_policy_ref: Default::default(),
65 metadata: self.common_opts.metadata.clone().unwrap_or_default(),
66 bindings: self.bindings.clone(),
67 pooling_enabled: self.common_opts.pooling_enabled.unwrap_or(false),
68 }
69 }
70 fn proto(&self) -> String {
71 "tls".into()
72 }
73
74 fn opts(&self) -> Option<BindOpts> {
75 let mut tls_endpoint = proto::TlsEndpoint::default();
77
78 if let Some(domain) = self.domain.as_ref() {
79 tls_endpoint.hostname = domain.clone();
81 }
82 tls_endpoint.proxy_proto = self.common_opts.proxy_proto;
83
84 let tls_termination = self
86 .cert_pem
87 .as_ref()
88 .zip(self.key_pem.as_ref())
89 .map(|(c, k)| TlsTermination {
90 cert: c.to_vec(),
91 key: k.to_vec().into(),
92 sealed_key: Vec::new(),
93 });
94
95 tls_endpoint.ip_restriction = self.common_opts.ip_restriction();
96 tls_endpoint.mutual_tls_at_edge =
97 (!self.mutual_tlsca.is_empty()).then_some(self.mutual_tlsca.as_slice().into());
98 tls_endpoint.tls_termination = tls_termination;
99 tls_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() {
100 self.common_opts.traffic_policy.clone().map(From::from)
101 } else if self.common_opts.policy.is_some() {
102 self.common_opts.policy.clone().map(From::from)
103 } else {
104 None
105 };
106 Some(BindOpts::Tls(tls_endpoint))
107 }
108 fn labels(&self) -> HashMap<String, String> {
109 HashMap::new()
110 }
111}
112
// Macro-generated builder. The tests below construct `TlsTunnelBuilder` with
// `session` and `options` fields, where `options` holds a `TlsOptions`.
// NOTE(review): the roles of the `TlsTunnel` and `endpoint` arguments are
// defined by `impl_builder!` elsewhere — presumably the tunnel type produced
// when listening and a kind/doc label; confirm against the macro definition.
impl_builder! {
    TlsTunnelBuilder, TlsOptions, TlsTunnel, endpoint
}
119
120impl TlsTunnelBuilder {
121 pub fn allow_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
125 self.options.common_opts.cidr_restrictions.allow(cidr);
126 self
127 }
128 pub fn deny_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
132 self.options.common_opts.cidr_restrictions.deny(cidr);
133 self
134 }
135 pub fn proxy_proto(&mut self, proxy_proto: ProxyProto) -> &mut Self {
137 self.options.common_opts.proxy_proto = proxy_proto;
138 self
139 }
140 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
144 self.options.common_opts.metadata = Some(metadata.into());
145 self
146 }
147
148 pub fn binding(&mut self, binding: impl Into<String>) -> &mut Self {
178 if !self.options.bindings.is_empty() {
179 panic!("binding() can only be called once");
180 }
181 let binding_str = binding.into();
182 if let Err(e) = Binding::validate(&binding_str) {
183 panic!("{}", e);
184 }
185 self.options.bindings.push(binding_str);
186 self
187 }
188 pub fn forwards_to(&mut self, forwards_to: impl Into<String>) -> &mut Self {
197 self.options.common_opts.forwards_to = Some(forwards_to.into());
198 self
199 }
200
201 pub fn verify_upstream_tls(&mut self, verify_upstream_tls: bool) -> &mut Self {
203 self.options
204 .common_opts
205 .set_verify_upstream_tls(verify_upstream_tls);
206 self
207 }
208
209 pub fn domain(&mut self, domain: impl Into<String>) -> &mut Self {
213 self.options.domain = Some(domain.into());
214 self
215 }
216
217 pub fn mutual_tlsca(&mut self, mutual_tlsca: Bytes) -> &mut Self {
224 self.options.mutual_tlsca.push(mutual_tlsca);
225 self
226 }
227
228 pub fn termination(&mut self, cert_pem: Bytes, key_pem: Bytes) -> &mut Self {
233 self.options.key_pem = Some(key_pem);
234 self.options.cert_pem = Some(cert_pem);
235 self
236 }
237
238 pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
240 where
241 S: TryInto<Policy>,
242 {
243 self.options.common_opts.policy = Some(s.try_into()?);
244 Ok(self)
245 }
246
247 pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
249 self.options.common_opts.traffic_policy = Some(policy_str.into());
250 self
251 }
252
253 pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
254 self.options.common_opts.for_forwarding_to(to_url);
255 self
256 }
257
258 pub fn pooling_enabled(&mut self, pooling_enabled: impl Into<bool>) -> &mut Self {
260 self.options.common_opts.pooling_enabled = Some(pooling_enabled.into());
261 self
262 }
263}
264
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::policies::test::POLICY_JSON;

    const METADATA: &str = "testmeta";
    const TEST_FORWARD: &str = "testforward";
    const ALLOW_CIDR: &str = "0.0.0.0/0";
    const DENY_CIDR: &str = "10.1.1.1/32";
    const CA_CERT: &[u8] = "test ca cert".as_bytes();
    const CA_CERT2: &[u8] = "test ca cert2".as_bytes();
    // Key and cert must differ, otherwise swapping them in `opts()` would go
    // unnoticed by the assertions below.
    const KEY: &[u8] = "test key".as_bytes();
    const CERT: &[u8] = "test cert".as_bytes();
    const DOMAIN: &str = "test domain";

    /// A fresh builder with no session and default options.
    fn new_builder() -> TlsTunnelBuilder {
        TlsTunnelBuilder {
            session: None,
            options: Default::default(),
        }
    }

    #[test]
    fn test_interface_to_proto() {
        tunnel_test(
            &new_builder()
                .allow_cidr(ALLOW_CIDR)
                .deny_cidr(DENY_CIDR)
                .proxy_proto(ProxyProto::V2)
                .metadata(METADATA)
                .domain(DOMAIN)
                .mutual_tlsca(CA_CERT.into())
                .mutual_tlsca(CA_CERT2.into())
                .termination(CERT.into(), KEY.into())
                .forwards_to(TEST_FORWARD)
                .policy(POLICY_JSON)
                .unwrap()
                .options,
        );
    }

    fn tunnel_test<C>(tunnel_cfg: C)
    where
        C: TunnelConfig,
    {
        assert_eq!(TEST_FORWARD, tunnel_cfg.forwards_to());

        let extra = tunnel_cfg.extra();
        assert_eq!(String::default(), *extra.token);
        assert_eq!(METADATA, extra.metadata);
        assert_eq!(Vec::<String>::new(), extra.bindings);
        assert_eq!(String::default(), extra.ip_policy_ref);
        assert!(!extra.pooling_enabled);

        assert_eq!("tls", tunnel_cfg.proto());

        let opts = tunnel_cfg.opts().unwrap();
        assert!(matches!(opts, BindOpts::Tls { .. }));
        if let BindOpts::Tls(endpoint) = opts {
            assert_eq!(DOMAIN, endpoint.hostname);
            assert_eq!(String::default(), endpoint.subdomain);
            assert!(matches!(endpoint.proxy_proto, ProxyProto::V2));
            assert!(!endpoint.mutual_tls_at_agent);
            // `policy` was set, so a traffic policy must be emitted.
            assert!(endpoint.traffic_policy.is_some());

            let ip_restriction = endpoint.ip_restriction.unwrap();
            assert_eq!(Vec::from([ALLOW_CIDR]), ip_restriction.allow_cidrs);
            assert_eq!(Vec::from([DENY_CIDR]), ip_restriction.deny_cidrs);

            let tls_termination = endpoint.tls_termination.unwrap();
            assert_eq!(CERT, tls_termination.cert);
            assert_eq!(KEY, *tls_termination.key);
            assert!(tls_termination.sealed_key.is_empty());

            let mutual_tls = endpoint.mutual_tls_at_edge.unwrap();
            let mut agg = CA_CERT.to_vec();
            agg.extend(CA_CERT2.to_vec());
            assert_eq!(agg, mutual_tls.mutual_tls_ca);
        }

        assert_eq!(HashMap::new(), tunnel_cfg.labels());
    }

    #[test]
    fn test_binding_valid_values() {
        for &value in ["public", "internal", "kubernetes"].iter() {
            let mut builder = new_builder();
            builder.binding(value);
            assert_eq!(vec![value], builder.options.bindings);
        }

        let mut builder = new_builder();
        builder.binding(Binding::Kubernetes);
        assert_eq!(vec!["kubernetes"], builder.options.bindings);
    }

    #[test]
    #[should_panic(expected = "Invalid binding value")]
    fn test_binding_invalid_value() {
        let mut builder = new_builder();
        builder.binding("invalid");
    }

    #[test]
    #[should_panic(expected = "binding() can only be called once")]
    fn test_binding_called_twice() {
        let mut builder = new_builder();
        builder.binding("public");
        builder.binding("internal");
    }
}