// ngrok/internals/proto.rs

use std::{
    collections::HashMap,
    error,
    fmt,
    io,
    ops::{
        Deref,
        DerefMut,
    },
    str::FromStr,
    string::FromUtf8Error,
    sync::Arc,
};

use muxado::typed::StreamType;
use serde::{
    de::{
        DeserializeOwned,
        Visitor,
    },
    Deserialize,
    Serialize,
    Serializer,
};
use thiserror::Error;
use tokio::io::{
    AsyncRead,
    AsyncReadExt,
};
use tracing::debug;

32pub const AUTH_REQ: StreamType = StreamType::clamp(0);
33pub const BIND_REQ: StreamType = StreamType::clamp(1);
34pub const UNBIND_REQ: StreamType = StreamType::clamp(2);
35pub const PROXY_REQ: StreamType = StreamType::clamp(3);
36pub const RESTART_REQ: StreamType = StreamType::clamp(4);
37pub const STOP_REQ: StreamType = StreamType::clamp(5);
38pub const UPDATE_REQ: StreamType = StreamType::clamp(6);
39pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7);
40pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9);
41
42pub const VERSION: &[&str] = &["3", "2"]; // integers in priority order
43
/// An error that may carry an ngrok error code.
///
/// All ngrok error codes are documented at <https://ngrok.com/docs/errors>.
pub trait Error: error::Error {
    /// Returns the ngrok error code, if this error has one.
    ///
    /// The default implementation reports no code.
    fn error_code(&self) -> Option<&str> {
        None
    }

    /// Returns the error message without the ngrok error code.
    ///
    /// The default implementation returns the `Display` output, i.e. the same
    /// as `format!("{error}")`, which is correct for errors with no code.
    fn msg(&self) -> String {
        self.to_string()
    }
}

/// Boxed errors report the code and message of the error they contain.
impl<E> Error for Box<E>
where
    E: Error,
{
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}

/// Shared errors report the code and message of the error they contain.
impl<E> Error for Arc<E>
where
    E: Error,
{
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}

/// Borrowed errors report the code and message of the referenced error.
impl<E> Error for &E
where
    E: Error,
{
    fn error_code(&self) -> Option<&str> {
        (**self).error_code()
    }

    fn msg(&self) -> String {
        (**self).msg()
    }
}

95#[derive(Serialize, Deserialize, Debug, Clone, Default)]
96pub struct ErrResp {
97    pub msg: String,
98    pub error_code: Option<String>,
99}
100
101impl<'a> From<&'a str> for ErrResp {
102    fn from(value: &'a str) -> Self {
103        let mut error_code = None;
104        let mut msg_lines = vec![];
105        for line in value.lines().filter(|l| !l.is_empty()) {
106            if line.starts_with("ERR_NGROK_") {
107                error_code = Some(line.trim().into());
108            } else {
109                msg_lines.push(line);
110            }
111        }
112        ErrResp {
113            error_code,
114            msg: msg_lines.join("\n"),
115        }
116    }
117}
118
119impl error::Error for ErrResp {}
120
121const ERR_URL: &str = "https://ngrok.com/docs/errors";
122
123impl fmt::Display for ErrResp {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        self.msg.fmt(f)?;
126        if let Some(code) = self.error_code.as_ref().map(|s| s.to_lowercase()) {
127            write!(f, "\n\n{ERR_URL}/{code}")?;
128        }
129        Ok(())
130    }
131}
132
133impl Error for ErrResp {
134    fn error_code(&self) -> Option<&str> {
135        self.error_code.as_deref()
136    }
137    fn msg(&self) -> String {
138        self.msg.clone()
139    }
140}
141
142#[derive(Serialize, Deserialize, Debug, Clone, Default)]
143#[serde(rename_all = "PascalCase")]
144pub struct Auth {
145    pub version: Vec<String>, // protocol versions supported, ordered by preference
146    pub client_id: String,    // empty for new sessions
147    pub extra: AuthExtra,     // clients may add whatever data the like to auth messages
148}
149
150#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
151#[serde(transparent)]
152pub struct SecretBytes(#[serde(with = "base64bytes")] Vec<u8>);
153
154impl Deref for SecretBytes {
155    type Target = Vec<u8>;
156    fn deref(&self) -> &Self::Target {
157        &self.0
158    }
159}
160
161impl DerefMut for SecretBytes {
162    fn deref_mut(&mut self) -> &mut Self::Target {
163        &mut self.0
164    }
165}
166
167impl<'a> From<&'a [u8]> for SecretBytes {
168    fn from(other: &'a [u8]) -> Self {
169        SecretBytes(other.into())
170    }
171}
172
173impl From<Vec<u8>> for SecretBytes {
174    fn from(other: Vec<u8>) -> Self {
175        SecretBytes(other)
176    }
177}
178
179impl fmt::Display for SecretBytes {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        write!(f, "********")
182    }
183}
184
185impl fmt::Debug for SecretBytes {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        write!(f, "********")
188    }
189}
190
191#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
192#[serde(transparent)]
193pub struct SecretString(String);
194
195impl Deref for SecretString {
196    type Target = String;
197    fn deref(&self) -> &Self::Target {
198        &self.0
199    }
200}
201
202impl DerefMut for SecretString {
203    fn deref_mut(&mut self) -> &mut Self::Target {
204        &mut self.0
205    }
206}
207
208impl<'a> From<&'a str> for SecretString {
209    fn from(other: &'a str) -> Self {
210        SecretString(other.into())
211    }
212}
213
214impl From<String> for SecretString {
215    fn from(other: String) -> Self {
216        SecretString(other)
217    }
218}
219
220impl fmt::Display for SecretString {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        write!(f, "********")
223    }
224}
225
226impl fmt::Debug for SecretString {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        write!(f, "********")
229    }
230}
231
232#[derive(Serialize, Deserialize, Debug, Clone, Default)]
233#[serde(rename_all = "PascalCase")]
234pub struct AuthExtra {
235    #[serde(rename = "OS")]
236    pub os: String,
237    pub arch: String,
238    pub auth_token: SecretString,
239    pub version: String,
240    pub hostname: String,
241    pub user_agent: String,
242    pub metadata: String,
243    pub cookie: SecretString,
244    pub heartbeat_interval: i64,
245    pub heartbeat_tolerance: i64,
246
247    // for each remote operation, these variables define whether the ngrok
248    // client is capable of executing that operation. each capability
249    // is transmitted as a pointer to String, with the following meanings:
250    //
251    // null ->               operation disallow beause the ngrok agent version is too old.
252    //                       this is true because older clients will never set this value
253    //
254    // "" (empty String)  -> the operation is supported
255    //
256    // non-empty String   -> the operation is not supported and this value is the  user-facing
257    //                       error message describing why it is not supported
258    pub update_unsupported_error: Option<String>,
259    pub stop_unsupported_error: Option<String>,
260    pub restart_unsupported_error: Option<String>,
261
262    pub proxy_type: String,
263    #[serde(rename = "MutualTLS")]
264    pub mutual_tls: bool,
265    pub service_run: bool,
266    pub config_version: String,
267    pub custom_interface: bool,
268    #[serde(rename = "CustomCAs")]
269    pub custom_cas: bool,
270
271    pub client_type: String,
272}
273
274#[derive(Serialize, Deserialize, Debug, Clone)]
275#[serde(rename_all = "PascalCase")]
276pub struct AuthResp {
277    pub version: String,
278    pub client_id: String,
279    #[serde(default)]
280    pub extra: AuthRespExtra,
281}
282
283rpc_req!(Auth, AuthResp, AUTH_REQ);
284
285#[derive(Serialize, Deserialize, Debug, Clone, Default)]
286#[serde(rename_all = "PascalCase")]
287pub struct AuthRespExtra {
288    pub version: Option<String>,
289    pub region: Option<String>,
290    pub cookie: Option<SecretString>,
291    pub account_name: Option<String>,
292    pub session_duration: Option<i64>,
293    pub plan_name: Option<String>,
294    pub banner: Option<String>,
295}
296
297#[derive(Serialize, Deserialize, Debug, Clone)]
298#[serde(rename_all = "PascalCase")]
299pub struct Bind<T> {
300    #[serde(rename = "Id")]
301    pub client_id: String,
302    pub proto: String,
303    pub forwards_to: String,
304    pub forwards_proto: String,
305    pub opts: T,
306    pub extra: BindExtra,
307}
308
309#[derive(Debug, Clone)]
310// allowing this since these aren't persistent values.
311#[allow(clippy::large_enum_variant)]
312pub enum BindOpts {
313    Http(HttpEndpoint),
314    Tcp(TcpEndpoint),
315    Tls(TlsEndpoint),
316}
317
318#[derive(Serialize, Deserialize, Debug, Clone, Default)]
319#[serde(rename_all = "PascalCase")]
320pub struct BindExtra {
321    pub token: SecretString,
322    #[serde(rename = "IPPolicyRef")]
323    pub ip_policy_ref: String,
324    pub metadata: String,
325    pub bindings: Vec<String>,
326    #[serde(rename = "PoolingEnabled")]
327    pub pooling_enabled: bool,
328}
329
330#[derive(Serialize, Deserialize, Debug, Clone)]
331#[serde(rename_all = "PascalCase")]
332pub struct BindResp<T> {
333    #[serde(rename = "Id")]
334    pub client_id: String,
335    #[serde(rename = "URL")]
336    pub url: String,
337    pub proto: String,
338    #[serde(rename = "Opts")]
339    pub bind_opts: T,
340    pub extra: BindRespExtra,
341}
342
343#[derive(Serialize, Deserialize, Debug, Clone)]
344#[serde(rename_all = "PascalCase")]
345pub struct BindRespExtra {
346    pub token: SecretString,
347}
348
349rpc_req!(Bind<T>, BindResp<T>, BIND_REQ; T: std::fmt::Debug + Serialize + DeserializeOwned + Clone);
350
351#[derive(Serialize, Deserialize, Debug, Clone)]
352#[serde(rename_all = "PascalCase")]
353pub struct StartTunnelWithLabel {
354    pub labels: HashMap<String, String>,
355    pub forwards_to: String,
356    pub forwards_proto: String,
357    pub metadata: String,
358}
359
360#[derive(Serialize, Deserialize, Debug, Clone)]
361#[serde(rename_all = "PascalCase")]
362pub struct StartTunnelWithLabelResp {
363    pub id: String,
364}
365
366rpc_req!(
367    StartTunnelWithLabel,
368    StartTunnelWithLabelResp,
369    BIND_LABELED_REQ
370);
371
372#[derive(Serialize, Deserialize, Debug, Clone)]
373#[serde(rename_all = "PascalCase")]
374pub struct Unbind {
375    #[serde(rename = "Id")]
376    pub client_id: String,
377    // extra: not sure what this field actually contains
378}
379
380#[derive(Serialize, Deserialize, Debug, Clone)]
381#[serde(rename_all = "PascalCase")]
382pub struct UnbindResp {
383    // extra: not sure what this field actually contains
384}
385
386rpc_req!(Unbind, UnbindResp, UNBIND_REQ);
387
388#[derive(Serialize, Deserialize, Debug, Clone)]
389#[serde(rename_all = "PascalCase")]
390pub struct ProxyHeader {
391    pub id: String,
392    pub client_addr: String,
393    pub proto: String,
394    pub edge_type: EdgeType,
395    #[serde(rename = "PassthroughTLS")]
396    pub passthrough_tls: bool,
397}
398
399#[derive(Error, Debug)]
400#[non_exhaustive]
401pub enum ReadHeaderError {
402    #[error("error reading proxy header")]
403    Io(#[from] io::Error),
404    #[error("invalid utf-8 in proxy header")]
405    InvalidUtf8(#[from] FromUtf8Error),
406    #[error("invalid proxy header json")]
407    InvalidHeader(#[from] serde_json::Error),
408}
409
410impl ProxyHeader {
411    pub async fn read_from_stream(
412        mut stream: impl AsyncRead + Unpin,
413    ) -> Result<Self, ReadHeaderError> {
414        let size = stream.read_i64_le().await?;
415        let mut buf = vec![0u8; size as usize];
416
417        stream.read_exact(&mut buf).await?;
418
419        let header = String::from_utf8(buf)?;
420
421        debug!(?header, "read header");
422
423        Ok(serde_json::from_str(&header)?)
424    }
425}
426
/// The edge type for an incoming connection.
///
/// Encoded on the wire as a numeric string (see [`EdgeType::as_str`]).
/// Unrecognized strings parse as [`EdgeType::Undefined`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum EdgeType {
    /// The edge type is undefined.
    Undefined,
    /// A TCP edge.
    Tcp,
    /// A TLS edge.
    Tls,
    /// An HTTPS edge.
    Https,
}

impl FromStr for EdgeType {
    type Err = ();

    /// Parses `"1"`, `"2"`, or `"3"`; any other input yields
    /// [`EdgeType::Undefined`], so this never returns an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const KNOWN: [EdgeType; 3] = [EdgeType::Tcp, EdgeType::Tls, EdgeType::Https];
        let edge = KNOWN
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
            .unwrap_or(EdgeType::Undefined);
        Ok(edge)
    }
}

impl EdgeType {
    /// Returns the wire encoding of this edge type: `"0"` through `"3"`.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Undefined => "0",
            Self::Tcp => "1",
            Self::Tls => "2",
            Self::Https => "3",
        }
    }
}

463impl Serialize for EdgeType {
464    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
465    where
466        S: serde::Serializer,
467    {
468        serializer.serialize_str(self.as_str())
469    }
470}
471
472struct EdgeTypeVisitor;
473
474impl<'de> Visitor<'de> for EdgeTypeVisitor {
475    type Value = EdgeType;
476    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
477        formatter.write_str(r#""0", "1", "2", or "3""#)
478    }
479
480    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
481    where
482        E: serde::de::Error,
483    {
484        Ok(EdgeType::from_str(v).unwrap())
485    }
486}
487
488impl<'de> Deserialize<'de> for EdgeType {
489    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
490    where
491        D: serde::Deserializer<'de>,
492    {
493        deserializer.deserialize_str(EdgeTypeVisitor)
494    }
495}
496
497/// A request from the ngrok dashboard for the agent to stop.
498#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
499#[serde(rename_all = "PascalCase")]
500pub struct Stop {}
501
502/// Common response structure for all remote commands originating from the ngrok
503/// dashboard.
504#[derive(Serialize, Deserialize, Debug, Clone, Default)]
505#[serde(rename_all = "PascalCase")]
506pub struct CommandResp {
507    /// The error arising from command handling, if any.
508    #[serde(default, skip_serializing_if = "Option::is_none")]
509    pub error: Option<String>,
510}
511
512pub type StopResp = CommandResp;
513
514rpc_req!(Stop, StopResp, STOP_REQ);
515
516/// A request from the ngrok dashboard for the agent to restart.
517#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
518#[serde(rename_all = "PascalCase")]
519pub struct Restart {}
520
521pub type RestartResp = CommandResp;
522rpc_req!(Restart, RestartResp, RESTART_REQ);
523
524/// A request from the ngrok dashboard for the agent to update itself.
525#[derive(Serialize, Deserialize, Debug, Clone, Default)]
526#[serde(rename_all = "PascalCase")]
527pub struct Update {
528    /// The version that the agent is requested to update to.
529    pub version: String,
530    /// Whether or not updating to the same major version is sufficient.
531    pub permit_major_version: bool,
532}
533
534/// A request from remote to stop a tunnel
535#[derive(Serialize, Deserialize, Debug, Clone, Default)]
536#[serde(rename_all = "PascalCase")]
537pub struct StopTunnel {
538    /// The id of the tunnel to stop
539    #[serde(rename = "Id")]
540    pub client_id: String,
541    /// The message on why this tunnel was stopped
542    pub message: String,
543    /// An optional ngrok error code
544    pub error_code: Option<String>,
545}
546
547pub type UpdateResp = CommandResp;
548rpc_req!(Update, UpdateResp, UPDATE_REQ);
549
/// The version of [PROXY protocol](https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
/// to use with this tunnel.
///
/// [`ProxyProto::None`] disables PROXY protocol support. The integer
/// encoding is `0` (none), `1` (v1), and `2` (v2).
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub enum ProxyProto {
    /// No PROXY protocol.
    #[default]
    None,
    /// PROXY protocol v1.
    V1,
    /// PROXY protocol v2.
    V2,
}

impl From<ProxyProto> for i64 {
    fn from(other: ProxyProto) -> Self {
        match other {
            ProxyProto::None => 0,
            ProxyProto::V1 => 1,
            ProxyProto::V2 => 2,
        }
    }
}

impl From<i64> for ProxyProto {
    /// Maps `1` to [`ProxyProto::V1`] and `2` to [`ProxyProto::V2`]; every
    /// other value (including negatives) becomes [`ProxyProto::None`].
    fn from(other: i64) -> Self {
        if other == 1 {
            ProxyProto::V1
        } else if other == 2 {
            ProxyProto::V2
        } else {
            ProxyProto::None
        }
    }
}

587#[derive(Debug, Clone, Error)]
588#[error("invalid proxyproto string: {}", .0)]
589pub struct InvalidProxyProtoString(String);
590
591impl FromStr for ProxyProto {
592    type Err = InvalidProxyProtoString;
593    fn from_str(s: &str) -> Result<Self, Self::Err> {
594        use ProxyProto::*;
595        Ok(match s {
596            "" => None,
597            "1" => V1,
598            "2" => V2,
599            _ => return Err(InvalidProxyProtoString(s.into())),
600        })
601    }
602}
603
604impl Serialize for ProxyProto {
605    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
606    where
607        S: serde::Serializer,
608    {
609        serializer.serialize_i64(i64::from(*self))
610    }
611}
612
613struct ProxyProtoVisitor;
614
615impl<'de> Visitor<'de> for ProxyProtoVisitor {
616    type Value = ProxyProto;
617    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
618        formatter.write_str("0, 1, or 2")
619    }
620
621    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
622    where
623        E: serde::de::Error,
624    {
625        Ok(ProxyProto::from(v))
626    }
627
628    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
629    where
630        E: serde::de::Error,
631    {
632        Ok(ProxyProto::from(v as i64))
633    }
634}
635
636impl<'de> Deserialize<'de> for ProxyProto {
637    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
638    where
639        D: serde::Deserializer<'de>,
640    {
641        deserializer.deserialize_i64(ProxyProtoVisitor)
642    }
643}
644
645#[derive(Debug, Serialize, Deserialize, Clone)]
646#[serde(untagged)]
647pub enum PolicyWrapper {
648    #[serde(serialize_with = "serialize_policy")]
649    Policy(Policy),
650    String(String),
651}
652
653impl From<String> for PolicyWrapper {
654    fn from(value: String) -> Self {
655        PolicyWrapper::String(value)
656    }
657}
658
659#[derive(Serialize, Deserialize, Debug, Clone, Default)]
660#[serde(rename_all = "PascalCase")]
661pub struct HttpEndpoint {
662    #[serde(default)]
663    pub domain: String,
664    pub hostname: String,
665    pub auth: String,
666    pub subdomain: String,
667    pub host_header_rewrite: bool,
668    pub local_url_scheme: Option<String>,
669    pub proxy_proto: ProxyProto,
670
671    pub compression: Option<Compression>,
672    pub circuit_breaker: Option<CircuitBreaker>,
673    #[serde(rename = "IPRestriction")]
674    pub ip_restriction: Option<IpRestriction>,
675    pub basic_auth: Option<BasicAuth>,
676    #[serde(rename = "OAuth")]
677    pub oauth: Option<Oauth>,
678    #[serde(rename = "OIDC")]
679    pub oidc: Option<Oidc>,
680    pub webhook_verification: Option<WebhookVerification>,
681    #[serde(rename = "MutualTLSCA")]
682    pub mutual_tls_ca: Option<MutualTls>,
683    #[serde(default)]
684    pub request_headers: Option<Headers>,
685    #[serde(default)]
686    pub response_headers: Option<Headers>,
687    #[serde(rename = "WebsocketTCPConverter")]
688    pub websocket_tcp_converter: Option<WebsocketTcpConverter>,
689    #[serde(rename = "UserAgentFilter")]
690    pub user_agent_filter: Option<UserAgentFilter>,
691    #[serde(rename = "TrafficPolicy")]
692    pub traffic_policy: Option<PolicyWrapper>,
693}
694
695#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
696pub struct Compression {}
697
/// Returns `true` if `v` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so default-valued fields
/// are left out of the serialized output.
fn is_default<T>(v: &T) -> bool
where
    T: PartialEq<T> + Default,
{
    T::default().eq(v)
}

705#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
706pub struct CircuitBreaker {
707    #[serde(default, skip_serializing_if = "is_default")]
708    pub error_threshold: f64,
709}
710
711#[derive(Debug, Clone, Serialize, Deserialize)]
712pub struct BasicAuth {
713    #[serde(default, skip_serializing_if = "is_default")]
714    pub credentials: Vec<BasicAuthCredential>,
715}
716
717#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
718pub struct BasicAuthCredential {
719    pub username: String,
720    #[serde(default, skip_serializing_if = "is_default")]
721    pub cleartext_password: String,
722    #[serde(default, skip_serializing_if = "is_default")]
723    #[serde(with = "base64bytes")]
724    pub hashed_password: Vec<u8>,
725}
726
727#[derive(Debug, Clone, Serialize, Deserialize)]
728pub struct IpRestriction {
729    #[serde(default, skip_serializing_if = "is_default")]
730    pub allow_cidrs: Vec<String>,
731    #[serde(default, skip_serializing_if = "is_default")]
732    pub deny_cidrs: Vec<String>,
733}
734
735#[derive(Debug, Clone, Serialize, Deserialize)]
736pub struct Oauth {
737    pub provider: String,
738    #[serde(default, skip_serializing_if = "is_default")]
739    pub client_id: String,
740    #[serde(default, skip_serializing_if = "is_default")]
741    pub client_secret: SecretString,
742    #[serde(default, skip_serializing_if = "is_default")]
743    #[serde(with = "base64bytes")]
744    pub sealed_client_secret: Vec<u8>,
745    #[serde(default, skip_serializing_if = "is_default")]
746    pub allow_emails: Vec<String>,
747    #[serde(default, skip_serializing_if = "is_default")]
748    pub allow_domains: Vec<String>,
749    #[serde(default, skip_serializing_if = "is_default")]
750    pub scopes: Vec<String>,
751}
752
753#[derive(Debug, Clone, Serialize, Deserialize)]
754pub struct Oidc {
755    pub issuer_url: String,
756    #[serde(default, skip_serializing_if = "is_default")]
757    pub client_id: String,
758    #[serde(default, skip_serializing_if = "is_default")]
759    pub client_secret: SecretString,
760    #[serde(default, skip_serializing_if = "is_default")]
761    #[serde(with = "base64bytes")]
762    pub sealed_client_secret: Vec<u8>,
763    #[serde(default, skip_serializing_if = "is_default")]
764    pub allow_emails: Vec<String>,
765    #[serde(default, skip_serializing_if = "is_default")]
766    pub allow_domains: Vec<String>,
767    #[serde(default, skip_serializing_if = "is_default")]
768    pub scopes: Vec<String>,
769}
770
771#[derive(Debug, Clone, Serialize, Deserialize)]
772pub struct WebhookVerification {
773    pub provider: String,
774    #[serde(default, skip_serializing_if = "is_default")]
775    pub secret: SecretString,
776    #[serde(default, skip_serializing_if = "is_default")]
777    #[serde(with = "base64bytes")]
778    pub sealed_secret: Vec<u8>,
779}
780
781#[derive(Debug, Clone, Serialize, Deserialize)]
782pub struct MutualTls {
783    #[serde(default, skip_serializing_if = "is_default")]
784    #[serde(with = "base64bytes")]
785    // this is snake-case on the wire
786    pub mutual_tls_ca: Vec<u8>,
787}
788
789#[derive(Debug, Clone, Serialize, Deserialize)]
790#[serde(rename_all = "camelCase")]
791pub struct Headers {
792    #[serde(default, skip_serializing_if = "is_default")]
793    pub add: Vec<String>,
794    #[serde(default, skip_serializing_if = "is_default")]
795    pub remove: Vec<String>,
796    #[serde(default, skip_serializing_if = "is_default")]
797    pub add_parsed: HashMap<String, String>,
798}
799
800#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
801pub struct WebsocketTcpConverter {}
802
803#[derive(Debug, Clone, Serialize, Deserialize)]
804pub struct UserAgentFilter {
805    #[serde(default, skip_serializing_if = "is_default")]
806    pub allow: Vec<String>,
807    #[serde(default, skip_serializing_if = "is_default")]
808    pub deny: Vec<String>,
809}
810
811#[derive(Serialize, Deserialize, Debug, Clone, Default)]
812#[serde(rename_all = "PascalCase")]
813pub struct TcpEndpoint {
814    pub addr: String,
815    pub proxy_proto: ProxyProto,
816    #[serde(rename = "IPRestriction")]
817    pub ip_restriction: Option<IpRestriction>,
818    #[serde(rename = "TrafficPolicy")]
819    pub traffic_policy: Option<PolicyWrapper>,
820}
821
822#[derive(Serialize, Deserialize, Debug, Clone, Default)]
823#[serde(rename_all = "PascalCase")]
824pub struct TlsEndpoint {
825    #[serde(default)]
826    pub domain: String,
827    pub hostname: String,
828    pub subdomain: String,
829    pub proxy_proto: ProxyProto,
830    #[serde(rename = "MutualTLSAtAgent")]
831    pub mutual_tls_at_agent: bool,
832
833    #[serde(rename = "MutualTLSAtEdge")]
834    pub mutual_tls_at_edge: Option<MutualTls>,
835    #[serde(rename = "TLSTermination")]
836    pub tls_termination: Option<TlsTermination>,
837    #[serde(rename = "IPRestriction")]
838    pub ip_restriction: Option<IpRestriction>,
839    #[serde(rename = "TrafficPolicy")]
840    pub traffic_policy: Option<PolicyWrapper>,
841}
842
843#[derive(Serialize, Deserialize, Debug, Clone, Default)]
844pub struct TlsTermination {
845    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
846    pub cert: Vec<u8>,
847    #[serde(skip_serializing_if = "is_default", default)]
848    pub key: SecretBytes,
849    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
850    pub sealed_key: Vec<u8>,
851}
852
853#[derive(Serialize, Deserialize, Debug, Clone, Default)]
854#[serde(rename_all = "PascalCase", default)]
855pub struct Policy {
856    pub inbound: Vec<Rule>,
857    pub outbound: Vec<Rule>,
858}
859
860#[derive(Serialize, Deserialize, Debug, Clone, Default)]
861#[serde(rename_all = "PascalCase", default)]
862pub struct Rule {
863    pub name: String,
864    pub expressions: Vec<String>,
865    pub actions: Vec<Action>,
866}
867
868#[derive(Serialize, Deserialize, Debug, Clone, Default)]
869#[serde(rename_all = "PascalCase", default)]
870pub struct Action {
871    #[serde(rename = "Type")]
872    pub type_: String,
873    #[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")]
874    pub config: Vec<u8>,
875}
876
877// This function converts a Policy into a valid JSON string. This is used so legacy configurations will still work
878// using the new string "TrafficPolicy" field.
879fn serialize_policy<S: Serializer>(v: &Policy, s: S) -> Result<S::Ok, S::Error> {
880    let abc = match serde_json::to_string(v) {
881        Ok(t) => t,
882        Err(_) => {
883            return Err(serde::ser::Error::custom(
884                "policy could not be converted to valid json",
885            ))
886        }
887    };
888    s.serialize_str(&abc)
889}
890
891// These are helpers to convert base64 strings to full, real json. The serialize helper also ensures that the resulting
892// representation isn't a string-escaped string.
893mod vec_to_json {
894    use serde::{
895        Deserialize,
896        Deserializer,
897        Serialize,
898        Serializer,
899    };
900
901    pub fn serialize<S: Serializer>(v: &[u8], s: S) -> Result<S::Ok, S::Error> {
902        let u: serde_json::Value = match serde_json::from_slice(v) {
903            Ok(k) => k,
904            Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")),
905        };
906
907        u.serialize(s)
908    }
909
910    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
911        let s = serde_json::Map::deserialize(d)?;
912        let v = serde_json::to_vec(&s).unwrap();
913        Ok(v)
914    }
915}
916
917// These are helpers to facilitate the Vec<u8> <-> base64-encoded bytes
918// representation that the Go messages use
919mod base64bytes {
920    use base64::prelude::*;
921    use serde::{
922        Deserialize,
923        Deserializer,
924        Serialize,
925        Serializer,
926    };
927
928    pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
929        BASE64_STANDARD.encode(v).serialize(s)
930    }
931
932    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
933        let s = String::deserialize(d)?;
934        BASE64_STANDARD
935            .decode(s.as_bytes())
936            .map_err(serde::de::Error::custom)
937    }
938}
939
#[cfg(test)]
mod test {
    use super::*;

    /// A policy whose outbound action has a nested `Config` object, which
    /// exercises the `vec_to_json` round trip.
    pub(crate) const POLICY_JSON: &str = r###"{"Inbound":[{"Name":"test_in","Expressions":["req.Method == 'PUT'"],"Actions":[{"Type":"deny"}]}],"Outbound":[{"Name":"test_out","Expressions":["res.StatusCode == '200'"],"Actions":[{"Type":"custom-response","Config":{"status_code":201}}]}]}"###;

    #[test]
    fn test_proxy_proto_serde() {
        let parsed: ProxyProto = serde_json::from_str("2").unwrap();
        assert_eq!(parsed, ProxyProto::V2);

        let encoded = serde_json::to_string(&parsed).unwrap();
        assert_eq!(encoded, "2");
    }

    #[test]
    fn test_policy_proto_serde() {
        let policy: Policy = serde_json::from_str(POLICY_JSON).unwrap();

        // Outbound is the interesting half: its action config goes through
        // the special Vec<u8> <-> JSON (de)serialization.
        assert_eq!(policy.outbound.len(), 1);
        let rule = &policy.outbound[0];
        assert_eq!(rule.actions.len(), 1);
        let config = &rule.actions[0].config;
        assert_eq!(config.as_slice(), r#"{"status_code":201}"#.as_bytes());

        let round_tripped = serde_json::to_string(&policy).unwrap();
        assert_eq!(round_tripped, POLICY_JSON);
    }
}