1use std::{
2 collections::HashMap,
3 error,
4 fmt,
5 io,
6 ops::{
7 Deref,
8 DerefMut,
9 },
10 str::FromStr,
11 string::FromUtf8Error,
12 sync::Arc,
13};
14
15use muxado::typed::StreamType;
16use serde::{
17 de::{
18 DeserializeOwned,
19 Visitor,
20 },
21 Deserialize,
22 Serialize,
23 Serializer,
24};
25use thiserror::Error;
26use tokio::io::{
27 AsyncRead,
28 AsyncReadExt,
29};
30use tracing::debug;
31
/// Stream type for [`Auth`] requests (paired with [`AuthResp`] via `rpc_req!`).
pub const AUTH_REQ: StreamType = StreamType::clamp(0);
/// Stream type for [`Bind`] requests (paired with [`BindResp`] via `rpc_req!`).
pub const BIND_REQ: StreamType = StreamType::clamp(1);
/// Stream type for [`Unbind`] requests (paired with [`UnbindResp`] via `rpc_req!`).
pub const UNBIND_REQ: StreamType = StreamType::clamp(2);
/// Stream type `3`. NOTE(review): not paired with a request type by any
/// `rpc_req!` in this section; presumably used for proxied connections — confirm.
pub const PROXY_REQ: StreamType = StreamType::clamp(3);
/// Stream type for [`Restart`] requests (paired with [`RestartResp`] via `rpc_req!`).
pub const RESTART_REQ: StreamType = StreamType::clamp(4);
/// Stream type for [`Stop`] requests (paired with [`StopResp`] via `rpc_req!`).
pub const STOP_REQ: StreamType = StreamType::clamp(5);
/// Stream type for [`Update`] requests (paired with [`UpdateResp`] via `rpc_req!`).
pub const UPDATE_REQ: StreamType = StreamType::clamp(6);
/// Stream type for [`StartTunnelWithLabel`] requests (paired with
/// [`StartTunnelWithLabelResp`] via `rpc_req!`).
pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7);
/// Stream type for [`SrvInfo`] requests (paired with [`SrvInfoResp`] via `rpc_req!`).
pub const SRV_INFO_REQ: StreamType = StreamType::clamp(8);
/// Stream type `9`. NOTE(review): not paired by any `rpc_req!` in this section;
/// presumably carries [`StopTunnel`] — confirm.
pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9);
42
/// Protocol version strings, listed as `"3"` then `"2"`.
/// NOTE(review): presumably the versions offered in [`Auth`], most preferred
/// first — confirm.
pub const VERSION: &[&str] = &["3", "2"];

/// Extension of [`std::error::Error`] exposing an optional error code and a
/// message string (implemented by [`ErrResp`], for example).
pub trait Error: error::Error {
    /// Machine-readable error code such as `ERR_NGROK_123`, if there is one.
    ///
    /// The default implementation has no code and returns `None`.
    fn error_code(&self) -> Option<&str> {
        None
    }

    /// Human-readable description of the error.
    ///
    /// The default implementation returns the value's `Display` output.
    fn msg(&self) -> String {
        self.to_string()
    }
}

/// A boxed error reports exactly what the boxed value reports.
impl<E: Error> Error for Box<E> {
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}

/// A shared error reports exactly what the shared value reports.
impl<E: Error> Error for Arc<E> {
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}

/// A borrowed error reports exactly what the referenced value reports.
impl<E: Error> Error for &E {
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}
95
/// Error response made of a human-readable message and an optional ngrok
/// error code. When built from text, the code is the line that starts with
/// `ERR_NGROK_`.
///
/// No `rename_all` is applied, so the wire field names are `msg` and `error_code`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ErrResp {
    pub msg: String,
    pub error_code: Option<String>,
}
101
102impl<'a> From<&'a str> for ErrResp {
103 fn from(value: &'a str) -> Self {
104 let mut error_code = None;
105 let mut msg_lines = vec![];
106 for line in value.lines().filter(|l| !l.is_empty()) {
107 if line.starts_with("ERR_NGROK_") {
108 error_code = Some(line.trim().into());
109 } else {
110 msg_lines.push(line);
111 }
112 }
113 ErrResp {
114 error_code,
115 msg: msg_lines.join("\n"),
116 }
117 }
118}
119
120impl error::Error for ErrResp {}
121
122const ERR_URL: &str = "https://ngrok.com/docs/errors";
123
124impl fmt::Display for ErrResp {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 self.msg.fmt(f)?;
127 if let Some(code) = self.error_code.as_ref().map(|s| s.to_lowercase()) {
128 write!(f, "\n\n{ERR_URL}/{code}")?;
129 }
130 Ok(())
131 }
132}
133
134impl Error for ErrResp {
135 fn error_code(&self) -> Option<&str> {
136 self.error_code.as_deref()
137 }
138 fn msg(&self) -> String {
139 self.msg.clone()
140 }
141}
142
/// Authentication request. It is paired with [`AuthResp`] on [`AUTH_REQ`]
/// via `rpc_req!`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Auth {
    /// Protocol versions offered. NOTE(review): presumably filled from
    /// [`VERSION`] — confirm at the call site.
    pub version: Vec<String>,
    pub client_id: String,
    /// Additional client details; see [`AuthExtra`].
    pub extra: AuthExtra,
}
150
151#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
152#[serde(transparent)]
153pub struct SecretBytes(#[serde(with = "base64bytes")] Vec<u8>);
154
155impl Deref for SecretBytes {
156 type Target = Vec<u8>;
157 fn deref(&self) -> &Self::Target {
158 &self.0
159 }
160}
161
162impl DerefMut for SecretBytes {
163 fn deref_mut(&mut self) -> &mut Self::Target {
164 &mut self.0
165 }
166}
167
168impl<'a> From<&'a [u8]> for SecretBytes {
169 fn from(other: &'a [u8]) -> Self {
170 SecretBytes(other.into())
171 }
172}
173
174impl From<Vec<u8>> for SecretBytes {
175 fn from(other: Vec<u8>) -> Self {
176 SecretBytes(other)
177 }
178}
179
180impl fmt::Display for SecretBytes {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 write!(f, "********")
183 }
184}
185
186impl fmt::Debug for SecretBytes {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 write!(f, "********")
189 }
190}
191
192#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
193#[serde(transparent)]
194pub struct SecretString(String);
195
196impl Deref for SecretString {
197 type Target = String;
198 fn deref(&self) -> &Self::Target {
199 &self.0
200 }
201}
202
203impl DerefMut for SecretString {
204 fn deref_mut(&mut self) -> &mut Self::Target {
205 &mut self.0
206 }
207}
208
209impl<'a> From<&'a str> for SecretString {
210 fn from(other: &'a str) -> Self {
211 SecretString(other.into())
212 }
213}
214
215impl From<String> for SecretString {
216 fn from(other: String) -> Self {
217 SecretString(other)
218 }
219}
220
221impl fmt::Display for SecretString {
222 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223 write!(f, "********")
224 }
225}
226
227impl fmt::Debug for SecretString {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 write!(f, "********")
230 }
231}
232
/// Client details sent as [`Auth::extra`].
///
/// Fields serialize in PascalCase. Wire names containing acronyms (`OS`,
/// `MutualTLS`, `CustomCAs`) are renamed explicitly. Secret values use
/// [`SecretString`], so they are redacted in `Debug` output.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthExtra {
    #[serde(rename = "OS")]
    pub os: String,
    pub arch: String,
    pub auth_token: SecretString,
    pub version: String,
    pub hostname: String,
    pub user_agent: String,
    pub metadata: String,
    pub cookie: SecretString,
    /// NOTE(review): the unit of this interval is not visible here — confirm.
    pub heartbeat_interval: i64,
    /// NOTE(review): the unit of this tolerance is not visible here — confirm.
    pub heartbeat_tolerance: i64,

    // NOTE(review): presumably, when `Some`, each of these gives the reason the
    // client cannot handle the corresponding update/stop/restart command — confirm.
    pub update_unsupported_error: Option<String>,
    pub stop_unsupported_error: Option<String>,
    pub restart_unsupported_error: Option<String>,

    pub proxy_type: String,
    #[serde(rename = "MutualTLS")]
    pub mutual_tls: bool,
    pub service_run: bool,
    pub config_version: String,
    pub custom_interface: bool,
    #[serde(rename = "CustomCAs")]
    pub custom_cas: bool,

    pub client_type: String,
}
274
/// Response to an [`Auth`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct AuthResp {
    pub version: String,
    pub client_id: String,
    /// Falls back to `AuthRespExtra::default()` when absent from the input.
    #[serde(default)]
    pub extra: AuthRespExtra,
}

// Pairs `Auth` (request) with `AuthResp` (response) on the `AUTH_REQ` stream
// type. `rpc_req!` is not defined in this section.
rpc_req!(Auth, AuthResp, AUTH_REQ);
285
/// Optional extra fields of an [`AuthResp`]. Every field is an `Option`, so a
/// missing field deserializes as `None`. `cookie` is a [`SecretString`] and is
/// redacted in `Debug` output.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthRespExtra {
    pub version: Option<String>,
    pub region: Option<String>,
    pub cookie: Option<SecretString>,
    pub account_name: Option<String>,
    /// NOTE(review): unit not visible here — confirm.
    pub session_duration: Option<i64>,
    pub plan_name: Option<String>,
    pub banner: Option<String>,
}
297
/// Request to bind an endpoint. It is paired with [`BindResp`] on
/// [`BIND_REQ`] via `rpc_req!`.
///
/// `T` holds the protocol-specific options. NOTE(review): presumably one of
/// [`HttpEndpoint`], [`TcpEndpoint`] or [`TlsEndpoint`] (the [`BindOpts`]
/// variants) — confirm at call sites.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Bind<T> {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    pub proto: String,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub opts: T,
    pub extra: BindExtra,
}
309
/// Protocol-specific bind options, one variant per supported endpoint kind.
#[derive(Debug, Clone)]
// `HttpEndpoint` has far more fields than the TCP/TLS options, so it dominates
// the enum's size. The lint is silenced instead of boxing that variant.
#[allow(clippy::large_enum_variant)]
pub enum BindOpts {
    Http(HttpEndpoint),
    Tcp(TcpEndpoint),
    Tls(TlsEndpoint),
}
318
/// Extra fields sent with a [`Bind`] request. `token` is a [`SecretString`]
/// and is redacted in `Debug` output.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct BindExtra {
    pub token: SecretString,
    #[serde(rename = "IPPolicyRef")]
    pub ip_policy_ref: String,
    pub metadata: String,
    pub bindings: Vec<String>,
    #[serde(rename = "PoolingEnabled")]
    pub pooling_enabled: bool,
}
330
/// Response to a [`Bind`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindResp<T> {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    /// Serialized as `URL`.
    #[serde(rename = "URL")]
    pub url: String,
    pub proto: String,
    /// Serialized as `Opts`.
    #[serde(rename = "Opts")]
    pub bind_opts: T,
    pub extra: BindRespExtra,
}

/// Extra fields of a [`BindResp`]. The token is redacted in `Debug` output.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindRespExtra {
    pub token: SecretString,
}

// Pairs `Bind<T>` with `BindResp<T>` on `BIND_REQ` for any options type `T`
// that meets the listed bounds. `rpc_req!` is not defined in this section.
rpc_req!(Bind<T>, BindResp<T>, BIND_REQ; T: std::fmt::Debug + Serialize + DeserializeOwned + Clone);
351
/// Request to start a tunnel identified by a set of labels.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabel {
    pub labels: HashMap<String, String>,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub metadata: String,
}

/// Response to a [`StartTunnelWithLabel`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabelResp {
    pub id: String,
}

// Pairs the labeled-tunnel request and response on `BIND_LABELED_REQ`.
rpc_req!(
    StartTunnelWithLabel,
    StartTunnelWithLabelResp,
    BIND_LABELED_REQ
);
372
/// Request to unbind the tunnel identified by `client_id`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Unbind {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
}

/// Response to an [`Unbind`] request. It carries no fields.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct UnbindResp {
}

// Pairs `Unbind` with `UnbindResp` on `UNBIND_REQ`.
rpc_req!(Unbind, UnbindResp, UNBIND_REQ);
388
/// Header describing a proxied connection. It is decoded from length-prefixed
/// JSON by [`ProxyHeader::read_from_stream`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct ProxyHeader {
    pub id: String,
    pub client_addr: String,
    pub proto: String,
    /// Encoded on the wire as the string codes `"0"` to `"3"`; see [`EdgeType`].
    pub edge_type: EdgeType,
    /// Serialized as `PassthroughTLS`.
    #[serde(rename = "PassthroughTLS")]
    pub passthrough_tls: bool,
}
399
/// Errors returned by [`ProxyHeader::read_from_stream`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ReadHeaderError {
    /// An I/O error occurred while reading the header from the stream.
    #[error("error reading proxy header")]
    Io(#[from] io::Error),
    /// The header bytes were not valid UTF-8.
    #[error("invalid utf-8 in proxy header")]
    InvalidUtf8(#[from] FromUtf8Error),
    /// The header text was not valid JSON for a [`ProxyHeader`].
    #[error("invalid proxy header json")]
    InvalidHeader(#[from] serde_json::Error),
}
410
411impl ProxyHeader {
412 pub async fn read_from_stream(
413 mut stream: impl AsyncRead + Unpin,
414 ) -> Result<Self, ReadHeaderError> {
415 let size = stream.read_i64_le().await?;
416 let mut buf = vec![0u8; size as usize];
417
418 stream.read_exact(&mut buf).await?;
419
420 let header = String::from_utf8(buf)?;
421
422 debug!(?header, "read header");
423
424 Ok(serde_json::from_str(&header)?)
425 }
426}
427
/// Edge type of a proxied connection. On the wire it is the string codes
/// `"0"` to `"3"` produced by [`EdgeType::as_str`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum EdgeType {
    /// Code `"0"`. Also used for any unrecognised code when parsing.
    Undefined,
    /// Code `"1"`.
    Tcp,
    /// Code `"2"`.
    Tls,
    /// Code `"3"`.
    Https,
}

/// Parsing is infallible: any string that is not a known code maps to
/// [`EdgeType::Undefined`].
impl FromStr for EdgeType {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const KNOWN: [EdgeType; 3] = [EdgeType::Tcp, EdgeType::Tls, EdgeType::Https];
        let found = KNOWN.iter().copied().find(|edge| edge.as_str() == s);
        Ok(found.unwrap_or(EdgeType::Undefined))
    }
}

impl EdgeType {
    /// Wire code for this edge type.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Tcp => "1",
            Self::Tls => "2",
            Self::Https => "3",
            Self::Undefined => "0",
        }
    }
}
463
464impl Serialize for EdgeType {
465 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
466 where
467 S: serde::Serializer,
468 {
469 serializer.serialize_str(self.as_str())
470 }
471}
472
473struct EdgeTypeVisitor;
474
475impl<'de> Visitor<'de> for EdgeTypeVisitor {
476 type Value = EdgeType;
477 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
478 formatter.write_str(r#""0", "1", "2", or "3""#)
479 }
480
481 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
482 where
483 E: serde::de::Error,
484 {
485 Ok(EdgeType::from_str(v).unwrap())
486 }
487}
488
489impl<'de> Deserialize<'de> for EdgeType {
490 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
491 where
492 D: serde::Deserializer<'de>,
493 {
494 deserializer.deserialize_str(EdgeTypeVisitor)
495 }
496}
497
/// Stop command. It carries no payload.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Stop {}

/// Shared response type for the stop, restart and update commands.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct CommandResp {
    /// Error text, if any. It is omitted from the output when `None` and
    /// defaults to `None` when absent. NOTE(review): presumably `None` means
    /// the command succeeded — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// Response to a [`Stop`] command.
pub type StopResp = CommandResp;

// Pairs `Stop` with `StopResp` on `STOP_REQ`.
rpc_req!(Stop, StopResp, STOP_REQ);
516
/// Restart command. It carries no payload.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Restart {}

/// Response to a [`Restart`] command.
pub type RestartResp = CommandResp;
// Pairs `Restart` with `RestartResp` on `RESTART_REQ`.
rpc_req!(Restart, RestartResp, RESTART_REQ);
524
/// Update command.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Update {
    /// NOTE(review): presumably the version to update to — confirm.
    pub version: String,
    /// NOTE(review): presumably whether a major-version jump is allowed — confirm.
    pub permit_major_version: bool,
}

/// Request to stop the single tunnel identified by `client_id`, with a message
/// and an optional error code. NOTE(review): no `rpc_req!` in this section
/// pairs this type; presumably it is sent on `STOP_TUNNEL_REQ` — confirm.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct StopTunnel {
    /// Serialized as `Id`.
    #[serde(rename = "Id")]
    pub client_id: String,
    pub message: String,
    pub error_code: Option<String>,
}

/// Response to an [`Update`] command.
pub type UpdateResp = CommandResp;
// Pairs `Update` with `UpdateResp` on `UPDATE_REQ`.
rpc_req!(Update, UpdateResp, UPDATE_REQ);
550
/// Server-info request. It carries no payload.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct SrvInfo {}

/// Response to a [`SrvInfo`] request, carrying a `region` string.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct SrvInfoResp {
    pub region: String,
}

// Pairs `SrvInfo` with `SrvInfoResp` on `SRV_INFO_REQ`.
rpc_req!(SrvInfo, SrvInfoResp, SRV_INFO_REQ);
562
/// PROXY protocol version to apply to forwarded connections.
/// Its integer form is `0` (none), `1` or `2`.
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub enum ProxyProto {
    /// No PROXY protocol header (integer `0`). This is the default.
    #[default]
    None,
    /// Version 1 (integer `1`).
    V1,
    /// Version 2 (integer `2`).
    V2,
}

impl From<ProxyProto> for i64 {
    fn from(other: ProxyProto) -> Self {
        match other {
            ProxyProto::V2 => 2,
            ProxyProto::V1 => 1,
            ProxyProto::None => 0,
        }
    }
}

/// Lenient conversion: any integer other than `1` or `2` (including `0` and
/// negative values) becomes [`ProxyProto::None`].
impl From<i64> for ProxyProto {
    fn from(other: i64) -> Self {
        // Invert the `ProxyProto -> i64` mapping defined above.
        [ProxyProto::V1, ProxyProto::V2]
            .iter()
            .copied()
            .find(|proto| i64::from(*proto) == other)
            .unwrap_or_default()
    }
}
599
600#[derive(Debug, Clone, Error)]
601#[error("invalid proxyproto string: {}", .0)]
602pub struct InvalidProxyProtoString(String);
603
604impl FromStr for ProxyProto {
605 type Err = InvalidProxyProtoString;
606 fn from_str(s: &str) -> Result<Self, Self::Err> {
607 use ProxyProto::*;
608 Ok(match s {
609 "" => None,
610 "1" => V1,
611 "2" => V2,
612 _ => return Err(InvalidProxyProtoString(s.into())),
613 })
614 }
615}
616
617impl Serialize for ProxyProto {
618 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
619 where
620 S: serde::Serializer,
621 {
622 serializer.serialize_i64(i64::from(*self))
623 }
624}
625
626struct ProxyProtoVisitor;
627
628impl<'de> Visitor<'de> for ProxyProtoVisitor {
629 type Value = ProxyProto;
630 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
631 formatter.write_str("0, 1, or 2")
632 }
633
634 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
635 where
636 E: serde::de::Error,
637 {
638 Ok(ProxyProto::from(v))
639 }
640
641 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
642 where
643 E: serde::de::Error,
644 {
645 Ok(ProxyProto::from(v as i64))
646 }
647}
648
649impl<'de> Deserialize<'de> for ProxyProto {
650 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
651 where
652 D: serde::Deserializer<'de>,
653 {
654 deserializer.deserialize_i64(ProxyProtoVisitor)
655 }
656}
657
/// A traffic policy, given either as a structured [`Policy`] or as a string.
///
/// The enum is `untagged`, so deserialization tries the variants in
/// declaration order: input accepted by [`Policy`] becomes `Policy`, and a
/// plain JSON string falls through to `String`. Serialization of the `Policy`
/// variant goes through `serialize_policy`, which emits the policy's JSON as a
/// *string*. Such output therefore deserializes back as the `String` variant.
/// Do not reorder the variants.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PolicyWrapper {
    #[serde(serialize_with = "serialize_policy")]
    Policy(Policy),
    String(String),
}

/// Wraps an already-encoded policy string.
impl From<String> for PolicyWrapper {
    fn from(value: String) -> Self {
        PolicyWrapper::String(value)
    }
}
671
/// Bind options for an HTTP endpoint (the [`BindOpts::Http`] variant).
///
/// Field names are PascalCase on the wire. Wire names containing acronyms are
/// renamed explicitly. Each optional module is an `Option` field.
/// NOTE(review): presumably `None` leaves that module disabled — confirm.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct HttpEndpoint {
    pub hostname: String,
    /// NOTE(review): this is a plain `String`, so the derived `Debug` prints
    /// it verbatim. If it holds credentials, consider redacting it as
    /// `SecretString` does.
    pub auth: String,
    pub subdomain: String,
    pub host_header_rewrite: bool,
    pub local_url_scheme: Option<String>,
    pub proxy_proto: ProxyProto,

    pub compression: Option<Compression>,
    pub circuit_breaker: Option<CircuitBreaker>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    pub basic_auth: Option<BasicAuth>,
    #[serde(rename = "OAuth")]
    pub oauth: Option<Oauth>,
    #[serde(rename = "OIDC")]
    pub oidc: Option<Oidc>,
    pub webhook_verification: Option<WebhookVerification>,
    #[serde(rename = "MutualTLSCA")]
    pub mutual_tls_ca: Option<MutualTls>,
    #[serde(default)]
    pub request_headers: Option<Headers>,
    #[serde(default)]
    pub response_headers: Option<Headers>,
    #[serde(rename = "WebsocketTCPConverter")]
    pub websocket_tcp_converter: Option<WebsocketTcpConverter>,
    #[serde(rename = "UserAgentFilter")]
    pub user_agent_filter: Option<UserAgentFilter>,
    /// See [`PolicyWrapper`] for how the structured and string forms are encoded.
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
705
/// HTTP compression module. It has no options.
/// NOTE(review): presumably enabled by setting `HttpEndpoint::compression` to
/// `Some` — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Compression {}
708
/// Returns `true` when `v` equals `T::default()`.
///
/// Used with `#[serde(skip_serializing_if = "is_default")]` to leave
/// empty, zero or false fields out of the serialized output.
fn is_default<T>(v: &T) -> bool
where
    T: PartialEq<T> + Default,
{
    let baseline = T::default();
    baseline.eq(v)
}
715
/// Circuit-breaker module options. `error_threshold` is omitted when it equals
/// `0.0` and defaults to `0.0` when absent.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct CircuitBreaker {
    #[serde(default, skip_serializing_if = "is_default")]
    pub error_threshold: f64,
}

/// Basic-auth module options. `credentials` is omitted when empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicAuth {
    #[serde(default, skip_serializing_if = "is_default")]
    pub credentials: Vec<BasicAuthCredential>,
}
727
728#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
729pub struct BasicAuthCredential {
730 pub username: String,
731 #[serde(default, skip_serializing_if = "is_default")]
732 pub cleartext_password: String,
733 #[serde(default, skip_serializing_if = "is_default")]
734 #[serde(with = "base64bytes")]
735 pub hashed_password: Vec<u8>,
736}
737
/// IP restriction module: CIDR allow and deny lists. Each list is omitted
/// when empty and defaults to empty when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpRestriction {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_cidrs: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny_cidrs: Vec<String>,
}
745
/// OAuth module options.
///
/// Every field except `provider` is omitted when empty or default.
/// `client_secret` is a [`SecretString`] (redacted in `Debug`), and
/// `sealed_client_secret` is base64-encoded on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oauth {
    pub provider: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}

/// OpenID Connect module options. The layout mirrors [`Oauth`], with
/// `issuer_url` in place of `provider`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oidc {
    pub issuer_url: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}
781
/// Webhook verification module options. `secret` is redacted in `Debug`, and
/// `sealed_secret` is base64-encoded. Both are omitted when empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookVerification {
    pub provider: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_secret: Vec<u8>,
}

/// Mutual-TLS CA material, base64-encoded on the wire and omitted when empty.
/// NOTE(review): the certificate encoding (PEM vs. DER) is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MutualTls {
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub mutual_tls_ca: Vec<u8>,
}
799
/// Header rewrite rules. The wire names are camelCase (`add`, `remove`,
/// `addParsed`), and each field is omitted when empty.
/// NOTE(review): most types here use PascalCase; confirm camelCase is what
/// the server expects.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Headers {
    #[serde(default, skip_serializing_if = "is_default")]
    pub add: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub remove: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub add_parsed: HashMap<String, String>,
}

/// WebSocket-to-TCP converter module. It has no options.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct WebsocketTcpConverter {}

/// User-agent filter: allow and deny lists, each omitted when empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserAgentFilter {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny: Vec<String>,
}
821
/// Bind options for a TCP endpoint (the [`BindOpts::Tcp`] variant).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TcpEndpoint {
    pub addr: String,
    pub proxy_proto: ProxyProto,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    /// See [`PolicyWrapper`] for how the structured and string forms are encoded.
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
832
/// Bind options for a TLS endpoint (the [`BindOpts::Tls`] variant).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TlsEndpoint {
    pub hostname: String,
    pub subdomain: String,
    pub proxy_proto: ProxyProto,
    #[serde(rename = "MutualTLSAtAgent")]
    pub mutual_tls_at_agent: bool,

    #[serde(rename = "MutualTLSAtEdge")]
    pub mutual_tls_at_edge: Option<MutualTls>,
    #[serde(rename = "TLSTermination")]
    pub tls_termination: Option<TlsTermination>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    /// See [`PolicyWrapper`] for how the structured and string forms are encoded.
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
851
/// TLS termination material. Each field is omitted when empty.
///
/// `cert` and `sealed_key` are base64-encoded through `base64bytes`. `key` is
/// a [`SecretBytes`], which is also base64-encoded and is redacted in `Debug`.
/// No `rename_all` applies, so the wire names are `cert`, `key` and `sealed_key`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TlsTermination {
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub cert: Vec<u8>,
    #[serde(skip_serializing_if = "is_default", default)]
    pub key: SecretBytes,
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub sealed_key: Vec<u8>,
}
861
/// Endpoint identified by a set of key/value labels.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct LabelEndpoint {
    pub labels: HashMap<String, String>,
}
867
/// Structured traffic policy: inbound and outbound rule lists. Missing fields
/// take their defaults (container-level `default`).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Policy {
    pub inbound: Vec<Rule>,
    pub outbound: Vec<Rule>,
}

/// A named policy rule: match expressions plus the actions to run.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Rule {
    pub name: String,
    pub expressions: Vec<String>,
    pub actions: Vec<Action>,
}

/// A single policy action.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Action {
    /// Action type, serialized as `Type` (`type` is a Rust keyword).
    #[serde(rename = "Type")]
    pub type_: String,
    /// Raw JSON bytes, emitted as structured JSON via `vec_to_json` and omitted
    /// when empty.
    #[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")]
    pub config: Vec<u8>,
}
891
892fn serialize_policy<S: Serializer>(v: &Policy, s: S) -> Result<S::Ok, S::Error> {
895 let abc = match serde_json::to_string(v) {
896 Ok(t) => t,
897 Err(_) => {
898 return Err(serde::ser::Error::custom(
899 "policy could not be converted to valid json",
900 ))
901 }
902 };
903 s.serialize_str(&abc)
904}
905
906mod vec_to_json {
909 use serde::{
910 Deserialize,
911 Deserializer,
912 Serialize,
913 Serializer,
914 };
915
916 pub fn serialize<S: Serializer>(v: &[u8], s: S) -> Result<S::Ok, S::Error> {
917 let u: serde_json::Value = match serde_json::from_slice(v) {
918 Ok(k) => k,
919 Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")),
920 };
921
922 u.serialize(s)
923 }
924
925 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
926 let s = serde_json::Map::deserialize(d)?;
927 let v = serde_json::to_vec(&s).unwrap();
928 Ok(v)
929 }
930}
931
932mod base64bytes {
935 use base64::prelude::*;
936 use serde::{
937 Deserialize,
938 Deserializer,
939 Serialize,
940 Serializer,
941 };
942
943 pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
944 BASE64_STANDARD.encode(v).serialize(s)
945 }
946
947 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
948 let s = String::deserialize(d)?;
949 BASE64_STANDARD
950 .decode(s.as_bytes())
951 .map_err(serde::de::Error::custom)
952 }
953}
954
#[cfg(test)]
mod test {

    use super::*;

    /// `ProxyProto` round-trips through a bare JSON integer.
    #[test]
    fn test_proxy_proto_serde() {
        let decoded: ProxyProto = serde_json::from_str("2").unwrap();
        assert_eq!(decoded, ProxyProto::V2);

        let encoded = serde_json::to_string(&decoded).unwrap();
        assert_eq!(encoded, "2");
    }

    /// Sample policy with one inbound rule and one outbound rule. The outbound
    /// action carries a JSON `Config`.
    pub(crate) const POLICY_JSON: &str = r###"{"Inbound":[{"Name":"test_in","Expressions":["req.Method == 'PUT'"],"Actions":[{"Type":"deny"}]}],"Outbound":[{"Name":"test_out","Expressions":["res.StatusCode == '200'"],"Actions":[{"Type":"custom-response","Config":{"status_code":201}}]}]}"###;

    /// Action configs are kept as raw JSON bytes, and re-serializing
    /// reproduces the input exactly.
    #[test]
    fn test_policy_proto_serde() {
        let policy: Policy = serde_json::from_str(POLICY_JSON).unwrap();

        assert_eq!(policy.outbound.len(), 1);
        let actions = &policy.outbound[0].actions;
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].config, br#"{"status_code":201}"#.to_vec());

        let reencoded = serde_json::to_string(&policy).unwrap();
        assert_eq!(reencoded, POLICY_JSON);
    }
}