1use std::{
2 collections::HashMap,
3 error,
4 fmt,
5 io,
6 ops::{
7 Deref,
8 DerefMut,
9 },
10 str::FromStr,
11 string::FromUtf8Error,
12 sync::Arc,
13};
14
15use muxado::typed::StreamType;
16use serde::{
17 de::{
18 DeserializeOwned,
19 Visitor,
20 },
21 Deserialize,
22 Serialize,
23 Serializer,
24};
25use thiserror::Error;
26use tokio::io::{
27 AsyncRead,
28 AsyncReadExt,
29};
30use tracing::debug;
31
// Typed-stream identifiers, one per RPC; each request type is tied to its identifier by the
// `rpc_req!` invocations further down. NOTE(review): value 8 is not assigned in this file —
// presumably it is reserved or used elsewhere; confirm before reusing it.

/// Stream type for [`Auth`] requests.
pub const AUTH_REQ: StreamType = StreamType::clamp(0);
/// Stream type for [`Bind`] requests.
pub const BIND_REQ: StreamType = StreamType::clamp(1);
/// Stream type for [`Unbind`] requests.
pub const UNBIND_REQ: StreamType = StreamType::clamp(2);
/// Stream type for proxied connections (presumably the streams that start with a
/// [`ProxyHeader`] — confirm).
pub const PROXY_REQ: StreamType = StreamType::clamp(3);
/// Stream type for [`Restart`] commands.
pub const RESTART_REQ: StreamType = StreamType::clamp(4);
/// Stream type for [`Stop`] commands.
pub const STOP_REQ: StreamType = StreamType::clamp(5);
/// Stream type for [`Update`] commands.
pub const UPDATE_REQ: StreamType = StreamType::clamp(6);
/// Stream type for [`StartTunnelWithLabel`] requests.
pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7);
/// Stream type presumably carrying [`StopTunnel`] messages (no `rpc_req!` pairs it here).
pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9);
41
/// Protocol versions this client advertises in [`Auth::version`].
pub const VERSION: &[&str] = &["3", "2"];

/// Extension of [`std::error::Error`] that can expose an optional error code
/// (for example an `ERR_NGROK_*` code) alongside a human-readable message.
pub trait Error: error::Error {
    /// Error code attached to this error, if any. The default implementation has none.
    fn error_code(&self) -> Option<&str> {
        None
    }

    /// Human-readable message. The default implementation is the error's `Display` output.
    fn msg(&self) -> String {
        self.to_string()
    }
}
58
59impl<E> Error for Box<E>
60where
61 E: Error,
62{
63 fn error_code(&self) -> Option<&str> {
64 <E as Error>::error_code(self)
65 }
66 fn msg(&self) -> String {
67 <E as Error>::msg(self)
68 }
69}
70
71impl<E> Error for Arc<E>
72where
73 E: Error,
74{
75 fn error_code(&self) -> Option<&str> {
76 <E as Error>::error_code(self)
77 }
78 fn msg(&self) -> String {
79 <E as Error>::msg(self)
80 }
81}
82
83impl<E> Error for &E
84where
85 E: Error,
86{
87 fn error_code(&self) -> Option<&str> {
88 <E as Error>::error_code(self)
89 }
90 fn msg(&self) -> String {
91 <E as Error>::msg(self)
92 }
93}
94
/// An error response split into a human-readable message and an optional error code.
///
/// Built from raw text by the `From<&str>` impl below. There is no `rename_all`, so the
/// serialized keys are `msg` and `error_code`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ErrResp {
    /// Message text, without the error code line(s).
    pub msg: String,
    /// Error code such as `ERR_NGROK_123`, if one was present.
    pub error_code: Option<String>,
}
100
101impl<'a> From<&'a str> for ErrResp {
102 fn from(value: &'a str) -> Self {
103 let mut error_code = None;
104 let mut msg_lines = vec![];
105 for line in value.lines().filter(|l| !l.is_empty()) {
106 if line.starts_with("ERR_NGROK_") {
107 error_code = Some(line.trim().into());
108 } else {
109 msg_lines.push(line);
110 }
111 }
112 ErrResp {
113 error_code,
114 msg: msg_lines.join("\n"),
115 }
116 }
117}
118
119impl error::Error for ErrResp {}
120
121const ERR_URL: &str = "https://ngrok.com/docs/errors";
122
123impl fmt::Display for ErrResp {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 self.msg.fmt(f)?;
126 if let Some(code) = self.error_code.as_ref().map(|s| s.to_lowercase()) {
127 write!(f, "\n\n{ERR_URL}/{code}")?;
128 }
129 Ok(())
130 }
131}
132
133impl Error for ErrResp {
134 fn error_code(&self) -> Option<&str> {
135 self.error_code.as_deref()
136 }
137 fn msg(&self) -> String {
138 self.msg.clone()
139 }
140}
141
/// Authentication request, paired with [`AuthResp`] on [`AUTH_REQ`] by `rpc_req!` below.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Auth {
    /// Protocol versions offered by the client (cf. [`VERSION`]).
    pub version: Vec<String>,
    pub client_id: String,
    /// Additional client details; see [`AuthExtra`].
    pub extra: AuthExtra,
}
149
/// Byte buffer for secret material.
///
/// Its `Display` and `Debug` output is always `********`. Serde handles it transparently
/// as a base64 string (via `base64bytes`).
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct SecretBytes(#[serde(with = "base64bytes")] Vec<u8>);
153
154impl Deref for SecretBytes {
155 type Target = Vec<u8>;
156 fn deref(&self) -> &Self::Target {
157 &self.0
158 }
159}
160
161impl DerefMut for SecretBytes {
162 fn deref_mut(&mut self) -> &mut Self::Target {
163 &mut self.0
164 }
165}
166
167impl<'a> From<&'a [u8]> for SecretBytes {
168 fn from(other: &'a [u8]) -> Self {
169 SecretBytes(other.into())
170 }
171}
172
173impl From<Vec<u8>> for SecretBytes {
174 fn from(other: Vec<u8>) -> Self {
175 SecretBytes(other)
176 }
177}
178
179impl fmt::Display for SecretBytes {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 write!(f, "********")
182 }
183}
184
185impl fmt::Debug for SecretBytes {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 write!(f, "********")
188 }
189}
190
/// String holding a secret such as a token.
///
/// Its `Display` and `Debug` output is always `********`. Serde treats it transparently
/// as a plain string.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct SecretString(String);
194
195impl Deref for SecretString {
196 type Target = String;
197 fn deref(&self) -> &Self::Target {
198 &self.0
199 }
200}
201
202impl DerefMut for SecretString {
203 fn deref_mut(&mut self) -> &mut Self::Target {
204 &mut self.0
205 }
206}
207
208impl<'a> From<&'a str> for SecretString {
209 fn from(other: &'a str) -> Self {
210 SecretString(other.into())
211 }
212}
213
214impl From<String> for SecretString {
215 fn from(other: String) -> Self {
216 SecretString(other)
217 }
218}
219
220impl fmt::Display for SecretString {
221 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222 write!(f, "********")
223 }
224}
225
226impl fmt::Debug for SecretString {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 write!(f, "********")
229 }
230}
231
/// Additional client details carried in [`Auth::extra`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthExtra {
    /// Serialized as `OS` rather than the `Os` that `PascalCase` would produce.
    #[serde(rename = "OS")]
    pub os: String,
    pub arch: String,
    /// Redacted in `Debug`/`Display` output (see [`SecretString`]).
    pub auth_token: SecretString,
    pub version: String,
    pub hostname: String,
    pub user_agent: String,
    pub metadata: String,
    /// Redacted in `Debug`/`Display` output (see [`SecretString`]).
    pub cookie: SecretString,
    // NOTE(review): the time unit of the two heartbeat fields is not visible here — confirm.
    pub heartbeat_interval: i64,
    pub heartbeat_tolerance: i64,

    // Presumably `Some(reason)` when the client cannot honour the corresponding
    // Update/Stop/Restart command — confirm against the command handlers.
    pub update_unsupported_error: Option<String>,
    pub stop_unsupported_error: Option<String>,
    pub restart_unsupported_error: Option<String>,

    pub proxy_type: String,
    /// Serialized as `MutualTLS`.
    #[serde(rename = "MutualTLS")]
    pub mutual_tls: bool,
    pub service_run: bool,
    pub config_version: String,
    pub custom_interface: bool,
    /// Serialized as `CustomCAs`.
    #[serde(rename = "CustomCAs")]
    pub custom_cas: bool,

    pub client_type: String,
}
273
/// Response to an [`Auth`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct AuthResp {
    pub version: String,
    pub client_id: String,
    /// Falls back to `AuthRespExtra::default()` when `Extra` is absent.
    #[serde(default)]
    pub extra: AuthRespExtra,
}

// Ties the `Auth` request to its `AuthResp` response on the `AUTH_REQ` stream type.
rpc_req!(Auth, AuthResp, AUTH_REQ);
284
/// Optional details returned in [`AuthResp::extra`]; any of them may be missing.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthRespExtra {
    pub version: Option<String>,
    pub region: Option<String>,
    /// Redacted in `Debug` output (see [`SecretString`]).
    pub cookie: Option<SecretString>,
    pub account_name: Option<String>,
    // NOTE(review): unit of the session duration is not visible here — confirm.
    pub session_duration: Option<i64>,
    pub plan_name: Option<String>,
    pub banner: Option<String>,
}
296
/// Bind request, paired with [`BindResp`] on [`BIND_REQ`].
///
/// `T` holds the protocol-specific options (cf. the payloads of [`BindOpts`]).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Bind<T> {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    pub proto: String,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub opts: T,
    pub extra: BindExtra,
}
308
/// Protocol-specific endpoint options for a bind.
#[derive(Debug, Clone)]
// `HttpEndpoint` carries many optional sub-configurations, so the `Http` variant is far
// larger than the others. The lint is silenced instead of boxing it (presumably to keep
// construction and matching simple — confirm).
#[allow(clippy::large_enum_variant)]
pub enum BindOpts {
    Http(HttpEndpoint),
    Tcp(TcpEndpoint),
    Tls(TlsEndpoint),
}
317
/// Extra bind parameters carried in [`Bind::extra`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct BindExtra {
    /// Redacted in `Debug` output (see [`SecretString`]).
    pub token: SecretString,
    /// Serialized as `IPPolicyRef`.
    #[serde(rename = "IPPolicyRef")]
    pub ip_policy_ref: String,
    pub metadata: String,
    pub bindings: Vec<String>,
    #[serde(rename = "PoolingEnabled")]
    pub pooling_enabled: bool,
}
329
/// Response to a [`Bind`] request; `T` mirrors the request's option type.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindResp<T> {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    /// Serialized as `URL`.
    #[serde(rename = "URL")]
    pub url: String,
    pub proto: String,
    /// Serialized as `Opts`.
    #[serde(rename = "Opts")]
    pub bind_opts: T,
    pub extra: BindRespExtra,
}
342
/// Extra fields carried in [`BindResp::extra`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindRespExtra {
    /// Redacted in `Debug` output (see [`SecretString`]).
    pub token: SecretString,
}

// Ties `Bind<T>` to `BindResp<T>` on `BIND_REQ` for any option type meeting these bounds.
rpc_req!(Bind<T>, BindResp<T>, BIND_REQ; T: std::fmt::Debug + Serialize + DeserializeOwned + Clone);
350
/// Request to start a tunnel identified by labels, sent on [`BIND_LABELED_REQ`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabel {
    pub labels: HashMap<String, String>,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub metadata: String,
}

/// Response to [`StartTunnelWithLabel`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabelResp {
    pub id: String,
}

// Ties the labeled-tunnel request to its response on `BIND_LABELED_REQ`.
rpc_req!(
    StartTunnelWithLabel,
    StartTunnelWithLabelResp,
    BIND_LABELED_REQ
);
371
/// Unbind request, sent on [`UNBIND_REQ`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Unbind {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
}

/// Empty response to [`Unbind`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct UnbindResp {}

// Ties `Unbind` to `UnbindResp` on `UNBIND_REQ`.
rpc_req!(Unbind, UnbindResp, UNBIND_REQ);
387
/// JSON header at the start of a proxied stream; read by [`ProxyHeader::read_from_stream`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct ProxyHeader {
    pub id: String,
    /// Presumably the remote client's address as reported by the server — confirm.
    pub client_addr: String,
    pub proto: String,
    /// Encoded as a string code `"0"`–`"3"` (see the `EdgeType` serde impls).
    pub edge_type: EdgeType,
    /// Serialized as `PassthroughTLS`.
    #[serde(rename = "PassthroughTLS")]
    pub passthrough_tls: bool,
}
398
/// Errors returned by [`ProxyHeader::read_from_stream`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ReadHeaderError {
    /// Reading the length prefix or the header bytes failed.
    #[error("error reading proxy header")]
    Io(#[from] io::Error),
    /// The header bytes were not valid UTF-8.
    #[error("invalid utf-8 in proxy header")]
    InvalidUtf8(#[from] FromUtf8Error),
    /// The header text was not valid `ProxyHeader` JSON.
    #[error("invalid proxy header json")]
    InvalidHeader(#[from] serde_json::Error),
}
409
410impl ProxyHeader {
411 pub async fn read_from_stream(
412 mut stream: impl AsyncRead + Unpin,
413 ) -> Result<Self, ReadHeaderError> {
414 let size = stream.read_i64_le().await?;
415 let mut buf = vec![0u8; size as usize];
416
417 stream.read_exact(&mut buf).await?;
418
419 let header = String::from_utf8(buf)?;
420
421 debug!(?header, "read header");
422
423 Ok(serde_json::from_str(&header)?)
424 }
425}
426
/// Kind of edge a proxied connection arrived through.
///
/// On the wire it is a string code: `"0"` = `Undefined`, `"1"` = `Tcp`, `"2"` = `Tls`,
/// `"3"` = `Https` (see `as_str`/`FromStr`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum EdgeType {
    /// Fallback for any unrecognized code.
    Undefined,
    Tcp,
    Tls,
    Https,
}
439
440impl FromStr for EdgeType {
441 type Err = ();
442 fn from_str(s: &str) -> Result<Self, Self::Err> {
443 Ok(match s {
444 "1" => EdgeType::Tcp,
445 "2" => EdgeType::Tls,
446 "3" => EdgeType::Https,
447 _ => EdgeType::Undefined,
448 })
449 }
450}
451
452impl EdgeType {
453 pub(crate) fn as_str(self) -> &'static str {
454 match self {
455 EdgeType::Undefined => "0",
456 EdgeType::Tcp => "1",
457 EdgeType::Tls => "2",
458 EdgeType::Https => "3",
459 }
460 }
461}
462
463impl Serialize for EdgeType {
464 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
465 where
466 S: serde::Serializer,
467 {
468 serializer.serialize_str(self.as_str())
469 }
470}
471
472struct EdgeTypeVisitor;
473
474impl<'de> Visitor<'de> for EdgeTypeVisitor {
475 type Value = EdgeType;
476 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
477 formatter.write_str(r#""0", "1", "2", or "3""#)
478 }
479
480 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
481 where
482 E: serde::de::Error,
483 {
484 Ok(EdgeType::from_str(v).unwrap())
485 }
486}
487
488impl<'de> Deserialize<'de> for EdgeType {
489 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
490 where
491 D: serde::Deserializer<'de>,
492 {
493 deserializer.deserialize_str(EdgeTypeVisitor)
494 }
495}
496
/// Stop command with no parameters, sent on [`STOP_REQ`].
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Stop {}

/// Shared response for the Stop/Restart/Update commands.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct CommandResp {
    /// Failure description; omitted when `None` and treated as `None` when missing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

pub type StopResp = CommandResp;

// Ties `Stop` to its response on `STOP_REQ`.
rpc_req!(Stop, StopResp, STOP_REQ);

/// Restart command with no parameters, sent on [`RESTART_REQ`].
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Restart {}

pub type RestartResp = CommandResp;
// Ties `Restart` to its response on `RESTART_REQ`.
rpc_req!(Restart, RestartResp, RESTART_REQ);
523
/// Update command, sent on [`UPDATE_REQ`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Update {
    pub version: String,
    pub permit_major_version: bool,
}

/// Stop-tunnel message (no `rpc_req!` pairs it in this file; cf. [`STOP_TUNNEL_REQ`]).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct StopTunnel {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    pub message: String,
    pub error_code: Option<String>,
}

pub type UpdateResp = CommandResp;
// Ties `Update` to its response on `UPDATE_REQ`.
rpc_req!(Update, UpdateResp, UPDATE_REQ);
549
/// PROXY-protocol version for an endpoint (presumably the HAProxy PROXY protocol — confirm).
///
/// Encoded on the wire as the integer 0, 1 or 2 (see the `i64` conversions and serde impls).
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub enum ProxyProto {
    /// No PROXY header; the default, encoded as 0.
    #[default]
    None,
    /// Encoded as 1.
    V1,
    /// Encoded as 2.
    V2,
}
564
565impl From<ProxyProto> for i64 {
566 fn from(other: ProxyProto) -> Self {
567 use ProxyProto::*;
568 match other {
569 None => 0,
570 V1 => 1,
571 V2 => 2,
572 }
573 }
574}
575
576impl From<i64> for ProxyProto {
577 fn from(other: i64) -> Self {
578 use ProxyProto::*;
579 match other {
580 1 => V1,
581 2 => V2,
582 _ => None,
583 }
584 }
585}
586
/// Error returned by `ProxyProto::from_str` for unrecognized input; holds the rejected string.
#[derive(Debug, Clone, Error)]
#[error("invalid proxyproto string: {}", .0)]
pub struct InvalidProxyProtoString(String);
590
591impl FromStr for ProxyProto {
592 type Err = InvalidProxyProtoString;
593 fn from_str(s: &str) -> Result<Self, Self::Err> {
594 use ProxyProto::*;
595 Ok(match s {
596 "" => None,
597 "1" => V1,
598 "2" => V2,
599 _ => return Err(InvalidProxyProtoString(s.into())),
600 })
601 }
602}
603
604impl Serialize for ProxyProto {
605 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
606 where
607 S: serde::Serializer,
608 {
609 serializer.serialize_i64(i64::from(*self))
610 }
611}
612
613struct ProxyProtoVisitor;
614
615impl<'de> Visitor<'de> for ProxyProtoVisitor {
616 type Value = ProxyProto;
617 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
618 formatter.write_str("0, 1, or 2")
619 }
620
621 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
622 where
623 E: serde::de::Error,
624 {
625 Ok(ProxyProto::from(v))
626 }
627
628 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
629 where
630 E: serde::de::Error,
631 {
632 Ok(ProxyProto::from(v as i64))
633 }
634}
635
636impl<'de> Deserialize<'de> for ProxyProto {
637 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
638 where
639 D: serde::Deserializer<'de>,
640 {
641 deserializer.deserialize_i64(ProxyProtoVisitor)
642 }
643}
644
/// Traffic policy given either as a structured [`Policy`] or as a raw string.
///
/// `untagged`: deserialization tries `Policy` first and falls back to `String`. The
/// `Policy` variant serializes as a JSON *string* holding the policy's JSON encoding
/// (via `serialize_policy`), not as a nested object.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PolicyWrapper {
    #[serde(serialize_with = "serialize_policy")]
    Policy(Policy),
    String(String),
}
652
653impl From<String> for PolicyWrapper {
654 fn from(value: String) -> Self {
655 PolicyWrapper::String(value)
656 }
657}
658
/// Options for an HTTP(S) endpoint bind.
///
/// Keys are `PascalCase` unless renamed. Every `Option` field may be absent, and `Domain`
/// may be missing thanks to `#[serde(default)]`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct HttpEndpoint {
    #[serde(default)]
    pub domain: String,
    pub hostname: String,
    pub auth: String,
    pub subdomain: String,
    pub host_header_rewrite: bool,
    pub local_url_scheme: Option<String>,
    pub proxy_proto: ProxyProto,

    // Optional per-endpoint modules; `None` leaves a module unset.
    pub compression: Option<Compression>,
    pub circuit_breaker: Option<CircuitBreaker>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    pub basic_auth: Option<BasicAuth>,
    #[serde(rename = "OAuth")]
    pub oauth: Option<Oauth>,
    #[serde(rename = "OIDC")]
    pub oidc: Option<Oidc>,
    pub webhook_verification: Option<WebhookVerification>,
    #[serde(rename = "MutualTLSCA")]
    pub mutual_tls_ca: Option<MutualTls>,
    #[serde(default)]
    pub request_headers: Option<Headers>,
    #[serde(default)]
    pub response_headers: Option<Headers>,
    #[serde(rename = "WebsocketTCPConverter")]
    pub websocket_tcp_converter: Option<WebsocketTcpConverter>,
    #[serde(rename = "UserAgentFilter")]
    pub user_agent_filter: Option<UserAgentFilter>,
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
694
/// Field-less marker; presumably its presence (`Some`) enables compression — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Compression {}
697
/// Returns `true` when `v` equals `T::default()`.
///
/// Used with `skip_serializing_if` so default-valued fields are left out of the output.
fn is_default<T>(v: &T) -> bool
where
    T: PartialEq<T> + Default,
{
    let baseline = T::default();
    baseline.eq(v)
}
704
// The module structs below have no `rename_all`, so their keys use the Rust field names.
// Fields marked `default, skip_serializing_if = "is_default"` may be missing on input and
// are left out of the output when they hold their default value.

/// Circuit-breaker settings for an HTTP endpoint.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct CircuitBreaker {
    // NOTE(review): scale of the threshold (ratio vs. percentage) is not visible here.
    #[serde(default, skip_serializing_if = "is_default")]
    pub error_threshold: f64,
}

/// HTTP basic-auth settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicAuth {
    #[serde(default, skip_serializing_if = "is_default")]
    pub credentials: Vec<BasicAuthCredential>,
}

/// One basic-auth credential, with either a cleartext or a hashed password.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct BasicAuthCredential {
    pub username: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cleartext_password: String,
    /// Encoded as a base64 string.
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub hashed_password: Vec<u8>,
}

/// CIDR allow/deny lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpRestriction {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_cidrs: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny_cidrs: Vec<String>,
}
734
/// OAuth settings for an HTTP endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oauth {
    pub provider: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    /// Redacted in `Debug` output (see [`SecretString`]).
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    /// Encoded as a base64 string; presumably an encrypted form of the secret — confirm.
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}

/// OpenID Connect settings; same shape as [`Oauth`] but keyed by issuer URL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oidc {
    pub issuer_url: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    /// Redacted in `Debug` output (see [`SecretString`]).
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    /// Encoded as a base64 string.
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}

/// Webhook signature-verification settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookVerification {
    pub provider: String,
    /// Redacted in `Debug` output (see [`SecretString`]).
    #[serde(default, skip_serializing_if = "is_default")]
    pub secret: SecretString,
    /// Encoded as a base64 string.
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_secret: Vec<u8>,
}

/// Mutual-TLS CA material.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MutualTls {
    /// Encoded as a base64 string; presumably PEM or DER CA certificate(s) — confirm.
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub mutual_tls_ca: Vec<u8>,
}
788
/// Header add/remove rules. Keys are camelCase, e.g. `add_parsed` → `addParsed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Headers {
    #[serde(default, skip_serializing_if = "is_default")]
    pub add: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub remove: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub add_parsed: HashMap<String, String>,
}

/// Field-less marker; presumably its presence enables the websocket-to-TCP converter.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct WebsocketTcpConverter {}

/// User-agent allow/deny lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserAgentFilter {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny: Vec<String>,
}
810
/// Options for a TCP endpoint bind.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TcpEndpoint {
    pub addr: String,
    pub proxy_proto: ProxyProto,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}

/// Options for a TLS endpoint bind.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TlsEndpoint {
    #[serde(default)]
    pub domain: String,
    pub hostname: String,
    pub subdomain: String,
    pub proxy_proto: ProxyProto,
    #[serde(rename = "MutualTLSAtAgent")]
    pub mutual_tls_at_agent: bool,

    #[serde(rename = "MutualTLSAtEdge")]
    pub mutual_tls_at_edge: Option<MutualTls>,
    #[serde(rename = "TLSTermination")]
    pub tls_termination: Option<TlsTermination>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
842
/// TLS termination material. There is no `rename_all`, so keys are `cert`, `key` and
/// `sealed_key`; each is omitted when empty.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TlsTermination {
    /// Encoded as a base64 string.
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub cert: Vec<u8>,
    /// Encoded as base64 through `SecretBytes`' own serde impl; redacted in `Debug`.
    #[serde(skip_serializing_if = "is_default", default)]
    pub key: SecretBytes,
    /// Encoded as a base64 string.
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub sealed_key: Vec<u8>,
}

/// Structured traffic policy. Container-level `default` lets any key be missing.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Policy {
    pub inbound: Vec<Rule>,
    pub outbound: Vec<Rule>,
}

/// A named rule: expressions plus the actions attached to them.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Rule {
    pub name: String,
    pub expressions: Vec<String>,
    pub actions: Vec<Action>,
}

/// A rule action with a type name and an optional free-form JSON config.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Action {
    /// Serialized as `Type` (`type` is a Rust keyword, hence the trailing underscore).
    #[serde(rename = "Type")]
    pub type_: String,
    /// Raw JSON bytes. Emitted as an embedded JSON value (not a string) via `vec_to_json`,
    /// and omitted when empty.
    #[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")]
    pub config: Vec<u8>,
}
876
877fn serialize_policy<S: Serializer>(v: &Policy, s: S) -> Result<S::Ok, S::Error> {
880 let abc = match serde_json::to_string(v) {
881 Ok(t) => t,
882 Err(_) => {
883 return Err(serde::ser::Error::custom(
884 "policy could not be converted to valid json",
885 ))
886 }
887 };
888 s.serialize_str(&abc)
889}
890
891mod vec_to_json {
894 use serde::{
895 Deserialize,
896 Deserializer,
897 Serialize,
898 Serializer,
899 };
900
901 pub fn serialize<S: Serializer>(v: &[u8], s: S) -> Result<S::Ok, S::Error> {
902 let u: serde_json::Value = match serde_json::from_slice(v) {
903 Ok(k) => k,
904 Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")),
905 };
906
907 u.serialize(s)
908 }
909
910 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
911 let s = serde_json::Map::deserialize(d)?;
912 let v = serde_json::to_vec(&s).unwrap();
913 Ok(v)
914 }
915}
916
917mod base64bytes {
920 use base64::prelude::*;
921 use serde::{
922 Deserialize,
923 Deserializer,
924 Serialize,
925 Serializer,
926 };
927
928 pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
929 BASE64_STANDARD.encode(v).serialize(s)
930 }
931
932 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
933 let s = String::deserialize(d)?;
934 BASE64_STANDARD
935 .decode(s.as_bytes())
936 .map_err(serde::de::Error::custom)
937 }
938}
939
#[cfg(test)]
mod test {

    use super::*;

    /// `ProxyProto` travels as a bare JSON integer and must round-trip unchanged.
    #[test]
    fn test_proxy_proto_serde() {
        let decoded: ProxyProto = serde_json::from_str("2").unwrap();
        assert_eq!(decoded, ProxyProto::V2);

        let encoded = serde_json::to_string(&decoded).unwrap();
        assert_eq!(encoded, "2");
    }

    /// A policy with one inbound and one outbound rule, written in exactly the form that
    /// `Policy`'s `Serialize` impl produces (the round-trip assertion below relies on it).
    pub(crate) const POLICY_JSON: &str = r###"{"Inbound":[{"Name":"test_in","Expressions":["req.Method == 'PUT'"],"Actions":[{"Type":"deny"}]}],"Outbound":[{"Name":"test_out","Expressions":["res.StatusCode == '200'"],"Actions":[{"Type":"custom-response","Config":{"status_code":201}}]}]}"###;

    #[test]
    fn test_policy_proto_serde() {
        let policy: Policy = serde_json::from_str(POLICY_JSON).unwrap();

        assert_eq!(policy.outbound.len(), 1);
        let actions = &policy.outbound[0].actions;
        assert_eq!(actions.len(), 1);
        // `Config` is kept as raw, compact JSON bytes after deserialization.
        assert_eq!(actions[0].config, r#"{"status_code":201}"#.as_bytes());

        // Re-serializing must reproduce the original document byte for byte.
        let reencoded = serde_json::to_string(&policy).unwrap();
        assert_eq!(reencoded, POLICY_JSON);
    }
}