1use std::collections::HashMap;
2
3use bytes::Bytes;
4use url::Url;
5
6use super::{
7 common::ProxyProto,
8 Policy,
9};
10#[allow(unused_imports)]
12use crate::config::{
13 ForwarderBuilder,
14 TunnelBuilder,
15};
16use crate::{
17 config::common::{
18 default_forwards_to,
19 CommonOpts,
20 TunnelConfig,
21 },
22 internals::proto::{
23 self,
24 BindExtra,
25 BindOpts,
26 TlsTermination,
27 },
28 tunnel::TlsTunnel,
29 Session,
30};
31
32#[derive(Default, Clone)]
34struct TlsOptions {
35 pub(crate) common_opts: CommonOpts,
36 pub(crate) domain: Option<String>,
37 pub(crate) mutual_tlsca: Vec<bytes::Bytes>,
38 pub(crate) key_pem: Option<bytes::Bytes>,
39 pub(crate) cert_pem: Option<bytes::Bytes>,
40 pub(crate) bindings: Vec<String>,
41}
42
43impl TunnelConfig for TlsOptions {
44 fn forwards_to(&self) -> String {
45 self.common_opts
46 .forwards_to
47 .clone()
48 .unwrap_or(default_forwards_to().into())
49 }
50
51 fn forwards_proto(&self) -> String {
52 String::new()
54 }
55
56 fn verify_upstream_tls(&self) -> bool {
57 self.common_opts.verify_upstream_tls()
58 }
59
60 fn extra(&self) -> BindExtra {
61 BindExtra {
62 token: Default::default(),
63 ip_policy_ref: Default::default(),
64 metadata: self.common_opts.metadata.clone().unwrap_or_default(),
65 bindings: self.bindings.clone(),
66 pooling_enabled: self.common_opts.pooling_enabled.unwrap_or(false),
67 }
68 }
69 fn proto(&self) -> String {
70 "tls".into()
71 }
72
73 fn opts(&self) -> Option<BindOpts> {
74 let mut tls_endpoint = proto::TlsEndpoint::default();
76
77 if let Some(domain) = self.domain.as_ref() {
78 tls_endpoint.hostname = domain.clone();
80 }
81 tls_endpoint.proxy_proto = self.common_opts.proxy_proto;
82
83 let tls_termination = self
85 .cert_pem
86 .as_ref()
87 .zip(self.key_pem.as_ref())
88 .map(|(c, k)| TlsTermination {
89 cert: c.to_vec(),
90 key: k.to_vec().into(),
91 sealed_key: Vec::new(),
92 });
93
94 tls_endpoint.ip_restriction = self.common_opts.ip_restriction();
95 tls_endpoint.mutual_tls_at_edge =
96 (!self.mutual_tlsca.is_empty()).then_some(self.mutual_tlsca.as_slice().into());
97 tls_endpoint.tls_termination = tls_termination;
98 tls_endpoint.traffic_policy = if self.common_opts.traffic_policy.is_some() {
99 self.common_opts.traffic_policy.clone().map(From::from)
100 } else if self.common_opts.policy.is_some() {
101 self.common_opts.policy.clone().map(From::from)
102 } else {
103 None
104 };
105 Some(BindOpts::Tls(tls_endpoint))
106 }
107 fn labels(&self) -> HashMap<String, String> {
108 HashMap::new()
109 }
110}
111
// Generates `TlsTunnelBuilder`, the builder that collects `TlsOptions` and yields a
// `TlsTunnel`. The test module below constructs it with `session` and `options` fields,
// so those are presumably what the macro defines.
// NOTE(review): the meaning of the trailing `endpoint` argument is defined by
// `impl_builder!` itself; confirm there.
impl_builder! {
    TlsTunnelBuilder, TlsOptions, TlsTunnel, endpoint
}
118
119impl TlsTunnelBuilder {
120 pub fn allow_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
124 self.options.common_opts.cidr_restrictions.allow(cidr);
125 self
126 }
127 pub fn deny_cidr(&mut self, cidr: impl Into<String>) -> &mut Self {
131 self.options.common_opts.cidr_restrictions.deny(cidr);
132 self
133 }
134 pub fn proxy_proto(&mut self, proxy_proto: ProxyProto) -> &mut Self {
136 self.options.common_opts.proxy_proto = proxy_proto;
137 self
138 }
139 pub fn metadata(&mut self, metadata: impl Into<String>) -> &mut Self {
143 self.options.common_opts.metadata = Some(metadata.into());
144 self
145 }
146 pub fn binding(&mut self, binding: impl Into<String>) -> &mut Self {
148 self.options.bindings.push(binding.into());
149 self
150 }
151 pub fn forwards_to(&mut self, forwards_to: impl Into<String>) -> &mut Self {
160 self.options.common_opts.forwards_to = Some(forwards_to.into());
161 self
162 }
163
164 pub fn verify_upstream_tls(&mut self, verify_upstream_tls: bool) -> &mut Self {
166 self.options
167 .common_opts
168 .set_verify_upstream_tls(verify_upstream_tls);
169 self
170 }
171
172 pub fn domain(&mut self, domain: impl Into<String>) -> &mut Self {
176 self.options.domain = Some(domain.into());
177 self
178 }
179
180 pub fn mutual_tlsca(&mut self, mutual_tlsca: Bytes) -> &mut Self {
187 self.options.mutual_tlsca.push(mutual_tlsca);
188 self
189 }
190
191 pub fn termination(&mut self, cert_pem: Bytes, key_pem: Bytes) -> &mut Self {
196 self.options.key_pem = Some(key_pem);
197 self.options.cert_pem = Some(cert_pem);
198 self
199 }
200
201 pub fn policy<S>(&mut self, s: S) -> Result<&mut Self, S::Error>
203 where
204 S: TryInto<Policy>,
205 {
206 self.options.common_opts.policy = Some(s.try_into()?);
207 Ok(self)
208 }
209
210 pub fn traffic_policy(&mut self, policy_str: impl Into<String>) -> &mut Self {
212 self.options.common_opts.traffic_policy = Some(policy_str.into());
213 self
214 }
215
216 pub(crate) async fn for_forwarding_to(&mut self, to_url: &Url) -> &mut Self {
217 self.options.common_opts.for_forwarding_to(to_url);
218 self
219 }
220
221 pub fn pooling_enabled(&mut self, pooling_enabled: impl Into<bool>) -> &mut Self {
223 self.options.common_opts.pooling_enabled = Some(pooling_enabled.into());
224 self
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::policies::test::POLICY_JSON;

    const BINDING: &str = "public";
    const METADATA: &str = "testmeta";
    const TEST_FORWARD: &str = "testforward";
    const ALLOW_CIDR: &str = "0.0.0.0/0";
    const DENY_CIDR: &str = "10.1.1.1/32";
    const CA_CERT: &[u8] = "test ca cert".as_bytes();
    const CA_CERT2: &[u8] = "test ca cert2".as_bytes();
    // KEY must differ from CERT. With identical bytes, a cert/key mix-up in `termination` or
    // `opts` would still satisfy the termination assertions below.
    const KEY: &[u8] = "test key".as_bytes();
    const CERT: &[u8] = "test cert".as_bytes();
    const DOMAIN: &str = "test domain";

    /// Sets every TLS builder option and checks each one is reflected in the bind request.
    #[test]
    fn test_interface_to_proto() {
        tunnel_test(
            &TlsTunnelBuilder {
                session: None,
                options: Default::default(),
            }
            .allow_cidr(ALLOW_CIDR)
            .deny_cidr(DENY_CIDR)
            .proxy_proto(ProxyProto::V2)
            .metadata(METADATA)
            .binding(BINDING)
            .domain(DOMAIN)
            .mutual_tlsca(CA_CERT.into())
            .mutual_tlsca(CA_CERT2.into())
            .termination(CERT.into(), KEY.into())
            .forwards_to(TEST_FORWARD)
            .policy(POLICY_JSON)
            .unwrap()
            .options,
        );
    }

    /// Asserts that `tunnel_cfg` reflects the options set in `test_interface_to_proto`.
    fn tunnel_test<C>(tunnel_cfg: C)
    where
        C: TunnelConfig,
    {
        assert_eq!(TEST_FORWARD, tunnel_cfg.forwards_to());
        assert_eq!("", tunnel_cfg.forwards_proto());

        let extra = tunnel_cfg.extra();
        assert_eq!(String::default(), *extra.token);
        assert_eq!(METADATA, extra.metadata);
        assert_eq!(Vec::from([BINDING]), extra.bindings);
        assert_eq!(String::default(), extra.ip_policy_ref);
        // Pooling was never enabled, so it must be reported as off.
        assert!(!extra.pooling_enabled);

        assert_eq!("tls", tunnel_cfg.proto());

        let opts = tunnel_cfg.opts().unwrap();
        assert!(matches!(opts, BindOpts::Tls { .. }));
        if let BindOpts::Tls(endpoint) = opts {
            assert_eq!(DOMAIN, endpoint.hostname);
            assert_eq!(String::default(), endpoint.subdomain);
            assert!(matches!(endpoint.proxy_proto, ProxyProto::V2));
            assert!(!endpoint.mutual_tls_at_agent);

            let ip_restriction = endpoint.ip_restriction.unwrap();
            assert_eq!(Vec::from([ALLOW_CIDR]), ip_restriction.allow_cidrs);
            assert_eq!(Vec::from([DENY_CIDR]), ip_restriction.deny_cidrs);

            let tls_termination = endpoint.tls_termination.unwrap();
            assert_eq!(CERT, tls_termination.cert);
            assert_eq!(KEY, *tls_termination.key);
            assert!(tls_termination.sealed_key.is_empty());

            // Multiple CAs are concatenated in the order they were added.
            let mutual_tls = endpoint.mutual_tls_at_edge.unwrap();
            let mut agg = CA_CERT.to_vec();
            agg.extend(CA_CERT2.to_vec());
            assert_eq!(agg, mutual_tls.mutual_tls_ca);

            // `policy` was set (and `traffic_policy` was not), so a policy must be sent.
            assert!(endpoint.traffic_policy.is_some());
        }

        assert_eq!(HashMap::new(), tunnel_cfg.labels());
    }
}