ngrok/internals/
proto.rs

1use std::{
2    collections::HashMap,
3    error,
4    fmt,
5    io,
6    ops::{
7        Deref,
8        DerefMut,
9    },
10    str::FromStr,
11    string::FromUtf8Error,
12    sync::Arc,
13};
14
15use muxado::typed::StreamType;
16use serde::{
17    de::{
18        DeserializeOwned,
19        Visitor,
20    },
21    Deserialize,
22    Serialize,
23    Serializer,
24};
25use thiserror::Error;
26use tokio::io::{
27    AsyncRead,
28    AsyncReadExt,
29};
30use tracing::debug;
31
// Stream types tagging each typed muxado stream with the RPC it carries.
// NOTE(review): these numeric values are presumably mirrored by the remote
// end, so they should not be renumbered.

/// Stream type for the [`Auth`] / [`AuthResp`] RPC.
pub const AUTH_REQ: StreamType = StreamType::clamp(0);
/// Stream type for the [`Bind`] / [`BindResp`] RPC.
pub const BIND_REQ: StreamType = StreamType::clamp(1);
/// Stream type for the [`Unbind`] / [`UnbindResp`] RPC.
pub const UNBIND_REQ: StreamType = StreamType::clamp(2);
/// Stream type for proxied connections (presumably each begins with a [`ProxyHeader`]).
pub const PROXY_REQ: StreamType = StreamType::clamp(3);
/// Stream type for the [`Restart`] / [`RestartResp`] RPC.
pub const RESTART_REQ: StreamType = StreamType::clamp(4);
/// Stream type for the [`Stop`] / [`StopResp`] RPC.
pub const STOP_REQ: StreamType = StreamType::clamp(5);
/// Stream type for the [`Update`] / [`UpdateResp`] RPC.
pub const UPDATE_REQ: StreamType = StreamType::clamp(6);
/// Stream type for the [`StartTunnelWithLabel`] / [`StartTunnelWithLabelResp`] RPC.
pub const BIND_LABELED_REQ: StreamType = StreamType::clamp(7);
/// Stream type for the [`SrvInfo`] / [`SrvInfoResp`] RPC.
pub const SRV_INFO_REQ: StreamType = StreamType::clamp(8);
/// Stream type for remote tunnel-stop requests (presumably carrying [`StopTunnel`]).
pub const STOP_TUNNEL_REQ: StreamType = StreamType::clamp(9);

/// Protocol versions this client supports, as integer strings in priority
/// order (most preferred first).
pub const VERSION: &[&str] = &["3", "2"]; // integers in priority order
44
/// An error that may carry an ngrok error code.
///
/// Every ngrok error code is documented at <https://ngrok.com/docs/errors>.
pub trait Error: error::Error {
    /// The ngrok error code attached to this error, if there is one.
    ///
    /// The default implementation reports no code.
    fn error_code(&self) -> Option<&str> {
        None
    }

    /// The error message with any ngrok error code left out.
    ///
    /// For errors without a code this is just the `Display` output, which is
    /// what the default implementation returns.
    fn msg(&self) -> String {
        self.to_string()
    }
}
59
60impl<E> Error for Box<E>
61where
62    E: Error,
63{
64    fn error_code(&self) -> Option<&str> {
65        <E as Error>::error_code(self)
66    }
67    fn msg(&self) -> String {
68        <E as Error>::msg(self)
69    }
70}
71
72impl<E> Error for Arc<E>
73where
74    E: Error,
75{
76    fn error_code(&self) -> Option<&str> {
77        <E as Error>::error_code(self)
78    }
79    fn msg(&self) -> String {
80        <E as Error>::msg(self)
81    }
82}
83
84impl<E> Error for &E
85where
86    E: Error,
87{
88    fn error_code(&self) -> Option<&str> {
89        <E as Error>::error_code(self)
90    }
91    fn msg(&self) -> String {
92        <E as Error>::msg(self)
93    }
94}
95
/// An error split into its human-readable message and an optional ngrok
/// error code (`ERR_NGROK_...`). See the `From<&str>` impl for how a raw
/// error string is split.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ErrResp {
    /// The error message, without the error-code line.
    pub msg: String,
    /// The ngrok error code, if one was present.
    pub error_code: Option<String>,
}
101
102impl<'a> From<&'a str> for ErrResp {
103    fn from(value: &'a str) -> Self {
104        let mut error_code = None;
105        let mut msg_lines = vec![];
106        for line in value.lines().filter(|l| !l.is_empty()) {
107            if line.starts_with("ERR_NGROK_") {
108                error_code = Some(line.trim().into());
109            } else {
110                msg_lines.push(line);
111            }
112        }
113        ErrResp {
114            error_code,
115            msg: msg_lines.join("\n"),
116        }
117    }
118}
119
120impl error::Error for ErrResp {}
121
122const ERR_URL: &str = "https://ngrok.com/docs/errors";
123
124impl fmt::Display for ErrResp {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        self.msg.fmt(f)?;
127        if let Some(code) = self.error_code.as_ref().map(|s| s.to_lowercase()) {
128            write!(f, "\n\n{ERR_URL}/{code}")?;
129        }
130        Ok(())
131    }
132}
133
134impl Error for ErrResp {
135    fn error_code(&self) -> Option<&str> {
136        self.error_code.as_deref()
137    }
138    fn msg(&self) -> String {
139        self.msg.clone()
140    }
141}
142
/// The session authentication request, sent on `AUTH_REQ` streams.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Auth {
    /// Protocol versions supported, ordered by preference.
    pub version: Vec<String>,
    /// Client id; empty for new sessions.
    pub client_id: String,
    /// Extra client data; clients may add whatever data they like to auth
    /// messages.
    pub extra: AuthExtra,
}
150
/// A byte buffer holding secret material.
///
/// Its `Display` and `Debug` output is always masked as `********`. On the
/// wire it is (de)serialized as a base64 string via `base64bytes`.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct SecretBytes(#[serde(with = "base64bytes")] Vec<u8>);
154
155impl Deref for SecretBytes {
156    type Target = Vec<u8>;
157    fn deref(&self) -> &Self::Target {
158        &self.0
159    }
160}
161
162impl DerefMut for SecretBytes {
163    fn deref_mut(&mut self) -> &mut Self::Target {
164        &mut self.0
165    }
166}
167
168impl<'a> From<&'a [u8]> for SecretBytes {
169    fn from(other: &'a [u8]) -> Self {
170        SecretBytes(other.into())
171    }
172}
173
174impl From<Vec<u8>> for SecretBytes {
175    fn from(other: Vec<u8>) -> Self {
176        SecretBytes(other)
177    }
178}
179
180impl fmt::Display for SecretBytes {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        write!(f, "********")
183    }
184}
185
186impl fmt::Debug for SecretBytes {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(f, "********")
189    }
190}
191
/// A string holding secret material, such as authtokens, cookies, and client
/// secrets.
///
/// Its `Display` and `Debug` output is always masked as `********`. Serde
/// (de)serializes it transparently as the inner string.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct SecretString(String);
195
196impl Deref for SecretString {
197    type Target = String;
198    fn deref(&self) -> &Self::Target {
199        &self.0
200    }
201}
202
203impl DerefMut for SecretString {
204    fn deref_mut(&mut self) -> &mut Self::Target {
205        &mut self.0
206    }
207}
208
209impl<'a> From<&'a str> for SecretString {
210    fn from(other: &'a str) -> Self {
211        SecretString(other.into())
212    }
213}
214
215impl From<String> for SecretString {
216    fn from(other: String) -> Self {
217        SecretString(other)
218    }
219}
220
221impl fmt::Display for SecretString {
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        write!(f, "********")
224    }
225}
226
227impl fmt::Debug for SecretString {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        write!(f, "********")
230    }
231}
232
/// Client-supplied details carried in [`Auth::extra`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthExtra {
    #[serde(rename = "OS")]
    pub os: String,
    pub arch: String,
    /// The authtoken; masked in `Debug` output because it is a [`SecretString`].
    pub auth_token: SecretString,
    pub version: String,
    pub hostname: String,
    pub user_agent: String,
    pub metadata: String,
    /// Session cookie, masked in `Debug` output.
    /// NOTE(review): presumably the cookie previously returned in
    /// `AuthRespExtra::cookie` — confirm.
    pub cookie: SecretString,
    // NOTE(review): the units of the two heartbeat fields are not visible here — confirm.
    pub heartbeat_interval: i64,
    pub heartbeat_tolerance: i64,

    // For each remote operation, these fields define whether the ngrok
    // client is capable of executing that operation. Each capability is
    // transmitted as an optional string (a string pointer on the Go side),
    // with the following meanings:
    //
    // null (None)        -> the operation is disallowed because the ngrok agent
    //                       version is too old. This holds because older
    //                       clients never set this value.
    //
    // "" (empty string)  -> the operation is supported.
    //
    // non-empty string   -> the operation is not supported, and this value is
    //                       the user-facing error message describing why.
    pub update_unsupported_error: Option<String>,
    pub stop_unsupported_error: Option<String>,
    pub restart_unsupported_error: Option<String>,

    pub proxy_type: String,
    #[serde(rename = "MutualTLS")]
    pub mutual_tls: bool,
    pub service_run: bool,
    pub config_version: String,
    pub custom_interface: bool,
    #[serde(rename = "CustomCAs")]
    pub custom_cas: bool,

    pub client_type: String,
}
274
/// The response to an [`Auth`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct AuthResp {
    /// NOTE(review): presumably the protocol version selected from
    /// `Auth::version` — confirm.
    pub version: String,
    pub client_id: String,
    /// Extra session details; an empty `AuthRespExtra` if the field is absent.
    #[serde(default)]
    pub extra: AuthRespExtra,
}

// `Auth` request / `AuthResp` response pairing on `AUTH_REQ` streams
// (`rpc_req!` is defined elsewhere in the crate).
rpc_req!(Auth, AuthResp, AUTH_REQ);
285
/// Optional session details returned in [`AuthResp::extra`]. Every field may
/// be absent.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct AuthRespExtra {
    pub version: Option<String>,
    pub region: Option<String>,
    /// Session cookie, masked in `Debug` output.
    pub cookie: Option<SecretString>,
    pub account_name: Option<String>,
    // NOTE(review): the unit of `session_duration` is not visible here — confirm.
    pub session_duration: Option<i64>,
    pub plan_name: Option<String>,
    pub banner: Option<String>,
}
297
/// A request to bind a tunnel, sent on `BIND_REQ` streams.
///
/// `T` carries the protocol-specific options (presumably one of
/// [`HttpEndpoint`], [`TcpEndpoint`] or [`TlsEndpoint`]).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Bind<T> {
    /// Sent as `Id` on the wire.
    #[serde(rename = "Id")]
    pub client_id: String,
    pub proto: String,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub opts: T,
    pub extra: BindExtra,
}
309
/// Protocol-specific options for binding a tunnel, one variant per endpoint
/// kind.
#[derive(Debug, Clone)]
// The variants differ in size, but these values aren't persistent, so boxing
// the large variant isn't worth it.
#[allow(clippy::large_enum_variant)]
pub enum BindOpts {
    Http(HttpEndpoint),
    Tcp(TcpEndpoint),
    Tls(TlsEndpoint),
}
318
/// Additional, protocol-independent fields of a [`Bind`] request.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct BindExtra {
    /// Masked in `Debug` output.
    pub token: SecretString,
    #[serde(rename = "IPPolicyRef")]
    pub ip_policy_ref: String,
    pub metadata: String,
    pub bindings: Vec<String>,
    // Explicit rename; same result as `rename_all = "PascalCase"` would give.
    #[serde(rename = "PoolingEnabled")]
    pub pooling_enabled: bool,
}
330
/// The response to a [`Bind`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindResp<T> {
    /// Sent as `Id` on the wire.
    #[serde(rename = "Id")]
    pub client_id: String,
    /// The tunnel's URL.
    #[serde(rename = "URL")]
    pub url: String,
    pub proto: String,
    /// The protocol-specific options; sent as `Opts` on the wire.
    #[serde(rename = "Opts")]
    pub bind_opts: T,
    pub extra: BindRespExtra,
}

/// Additional fields of a [`BindResp`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct BindRespExtra {
    /// Masked in `Debug` output.
    pub token: SecretString,
}

// `Bind<T>` request / `BindResp<T>` response pairing on `BIND_REQ` streams,
// for any option type `T` meeting these bounds.
rpc_req!(Bind<T>, BindResp<T>, BIND_REQ; T: std::fmt::Debug + Serialize + DeserializeOwned + Clone);
351
/// A request to start a labeled tunnel, sent on `BIND_LABELED_REQ` streams.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabel {
    pub labels: HashMap<String, String>,
    pub forwards_to: String,
    pub forwards_proto: String,
    pub metadata: String,
}

/// The response to a [`StartTunnelWithLabel`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct StartTunnelWithLabelResp {
    pub id: String,
}

// Labeled-tunnel request/response pairing on `BIND_LABELED_REQ` streams.
rpc_req!(
    StartTunnelWithLabel,
    StartTunnelWithLabelResp,
    BIND_LABELED_REQ
);
372
/// A request to unbind a tunnel, sent on `UNBIND_REQ` streams.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct Unbind {
    /// Sent as `Id` on the wire.
    #[serde(rename = "Id")]
    pub client_id: String,
    // extra: not sure what this field actually contains
}

/// The (currently empty) response to an [`Unbind`] request.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct UnbindResp {
    // extra: not sure what this field actually contains
}

// `Unbind` request / `UnbindResp` response pairing on `UNBIND_REQ` streams.
rpc_req!(Unbind, UnbindResp, UNBIND_REQ);
388
/// Metadata for a proxied connection, read from the start of a stream by
/// [`ProxyHeader::read_from_stream`].
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "PascalCase")]
pub struct ProxyHeader {
    pub id: String,
    /// NOTE(review): presumably the remote client's address — confirm format.
    pub client_addr: String,
    pub proto: String,
    pub edge_type: EdgeType,
    #[serde(rename = "PassthroughTLS")]
    pub passthrough_tls: bool,
}
399
/// Errors returned by [`ProxyHeader::read_from_stream`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ReadHeaderError {
    /// Reading the length prefix or header bytes from the stream failed.
    #[error("error reading proxy header")]
    Io(#[from] io::Error),
    /// The header bytes were not valid UTF-8.
    #[error("invalid utf-8 in proxy header")]
    InvalidUtf8(#[from] FromUtf8Error),
    /// The header was not valid JSON for a [`ProxyHeader`].
    #[error("invalid proxy header json")]
    InvalidHeader(#[from] serde_json::Error),
}
410
411impl ProxyHeader {
412    pub async fn read_from_stream(
413        mut stream: impl AsyncRead + Unpin,
414    ) -> Result<Self, ReadHeaderError> {
415        let size = stream.read_i64_le().await?;
416        let mut buf = vec![0u8; size as usize];
417
418        stream.read_exact(&mut buf).await?;
419
420        let header = String::from_utf8(buf)?;
421
422        debug!(?header, "read header");
423
424        Ok(serde_json::from_str(&header)?)
425    }
426}
427
/// The edge type for an incoming connection.
///
/// On the wire it is encoded as a string: `"0"` (undefined), `"1"` (TCP),
/// `"2"` (TLS) or `"3"` (HTTPS).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum EdgeType {
    /// The edge type is undefined or unrecognized.
    Undefined,
    /// A TCP edge.
    Tcp,
    /// A TLS edge.
    Tls,
    /// An HTTPS edge.
    Https,
}
440
441impl FromStr for EdgeType {
442    type Err = ();
443    fn from_str(s: &str) -> Result<Self, Self::Err> {
444        Ok(match s {
445            "1" => EdgeType::Tcp,
446            "2" => EdgeType::Tls,
447            "3" => EdgeType::Https,
448            _ => EdgeType::Undefined,
449        })
450    }
451}
452
453impl EdgeType {
454    pub(crate) fn as_str(self) -> &'static str {
455        match self {
456            EdgeType::Undefined => "0",
457            EdgeType::Tcp => "1",
458            EdgeType::Tls => "2",
459            EdgeType::Https => "3",
460        }
461    }
462}
463
464impl Serialize for EdgeType {
465    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
466    where
467        S: serde::Serializer,
468    {
469        serializer.serialize_str(self.as_str())
470    }
471}
472
473struct EdgeTypeVisitor;
474
475impl<'de> Visitor<'de> for EdgeTypeVisitor {
476    type Value = EdgeType;
477    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
478        formatter.write_str(r#""0", "1", "2", or "3""#)
479    }
480
481    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
482    where
483        E: serde::de::Error,
484    {
485        Ok(EdgeType::from_str(v).unwrap())
486    }
487}
488
489impl<'de> Deserialize<'de> for EdgeType {
490    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
491    where
492        D: serde::Deserializer<'de>,
493    {
494        deserializer.deserialize_str(EdgeTypeVisitor)
495    }
496}
497
/// A request from the ngrok dashboard for the agent to stop.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Stop {}

/// Common response structure for all remote commands originating from the ngrok
/// dashboard.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct CommandResp {
    /// The error arising from command handling, if any. It is omitted from
    /// the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}

/// The response to a [`Stop`] request.
pub type StopResp = CommandResp;

// `Stop` request / `StopResp` response pairing on `STOP_REQ` streams.
rpc_req!(Stop, StopResp, STOP_REQ);
516
/// A request from the ngrok dashboard for the agent to restart.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Restart {}

/// The response to a [`Restart`] request.
pub type RestartResp = CommandResp;
// `Restart` request / `RestartResp` response pairing on `RESTART_REQ` streams.
rpc_req!(Restart, RestartResp, RESTART_REQ);
524
/// A request from the ngrok dashboard for the agent to update itself.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct Update {
    /// The version that the agent is requested to update to.
    pub version: String,
    /// Whether or not updating to the same major version is sufficient.
    /// NOTE(review): the field name suggests "may cross a major version";
    /// confirm the exact semantics against the server.
    pub permit_major_version: bool,
}

/// A request from the remote end to stop a tunnel.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct StopTunnel {
    /// The id of the tunnel to stop (sent as `Id` on the wire).
    #[serde(rename = "Id")]
    pub client_id: String,
    /// The reason this tunnel was stopped.
    pub message: String,
    /// An optional ngrok error code.
    pub error_code: Option<String>,
}

/// The response to an [`Update`] request.
pub type UpdateResp = CommandResp;
// `Update` request / `UpdateResp` response pairing on `UPDATE_REQ` streams.
rpc_req!(Update, UpdateResp, UPDATE_REQ);
550
/// A request for server information, sent on `SRV_INFO_REQ` streams.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
#[serde(rename_all = "PascalCase")]
pub struct SrvInfo {}

/// The response to a [`SrvInfo`] request.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct SrvInfoResp {
    pub region: String,
}

// `SrvInfo` request / `SrvInfoResp` response pairing on `SRV_INFO_REQ` streams.
rpc_req!(SrvInfo, SrvInfoResp, SRV_INFO_REQ);
562
/// The version of [PROXY protocol](https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
/// to use with this tunnel.
///
/// [ProxyProto::None] disables PROXY protocol support.
///
/// Serialized as an integer: `0` (none), `1` (v1) or `2` (v2).
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq)]
pub enum ProxyProto {
    /// No PROXY protocol
    #[default]
    None,
    /// PROXY protocol v1
    V1,
    /// PROXY protocol v2
    V2,
}
577
578impl From<ProxyProto> for i64 {
579    fn from(other: ProxyProto) -> Self {
580        use ProxyProto::*;
581        match other {
582            None => 0,
583            V1 => 1,
584            V2 => 2,
585        }
586    }
587}
588
589impl From<i64> for ProxyProto {
590    fn from(other: i64) -> Self {
591        use ProxyProto::*;
592        match other {
593            1 => V1,
594            2 => V2,
595            _ => None,
596        }
597    }
598}
599
/// Error returned when parsing a [`ProxyProto`] from an unrecognized string.
/// Holds the rejected input.
#[derive(Debug, Clone, Error)]
#[error("invalid proxyproto string: {}", .0)]
pub struct InvalidProxyProtoString(String);
603
604impl FromStr for ProxyProto {
605    type Err = InvalidProxyProtoString;
606    fn from_str(s: &str) -> Result<Self, Self::Err> {
607        use ProxyProto::*;
608        Ok(match s {
609            "" => None,
610            "1" => V1,
611            "2" => V2,
612            _ => return Err(InvalidProxyProtoString(s.into())),
613        })
614    }
615}
616
617impl Serialize for ProxyProto {
618    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
619    where
620        S: serde::Serializer,
621    {
622        serializer.serialize_i64(i64::from(*self))
623    }
624}
625
626struct ProxyProtoVisitor;
627
628impl<'de> Visitor<'de> for ProxyProtoVisitor {
629    type Value = ProxyProto;
630    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
631        formatter.write_str("0, 1, or 2")
632    }
633
634    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
635    where
636        E: serde::de::Error,
637    {
638        Ok(ProxyProto::from(v))
639    }
640
641    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
642    where
643        E: serde::de::Error,
644    {
645        Ok(ProxyProto::from(v as i64))
646    }
647}
648
649impl<'de> Deserialize<'de> for ProxyProto {
650    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
651    where
652        D: serde::Deserializer<'de>,
653    {
654        deserializer.deserialize_i64(ProxyProtoVisitor)
655    }
656}
657
/// A traffic policy, supplied either as a structured [`Policy`] or as a raw
/// string.
///
/// Both forms serialize as a string: the `Policy` variant is converted to a
/// JSON string by `serialize_policy`. Because the enum is untagged,
/// deserialization tries `Policy` first and falls back to `String`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum PolicyWrapper {
    #[serde(serialize_with = "serialize_policy")]
    Policy(Policy),
    String(String),
}
665
666impl From<String> for PolicyWrapper {
667    fn from(value: String) -> Self {
668        PolicyWrapper::String(value)
669    }
670}
671
/// Options for an HTTP endpoint, carried by [`BindOpts::Http`].
///
/// The `Option` module fields are presumably left unconfigured when `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct HttpEndpoint {
    pub hostname: String,
    pub auth: String,
    pub subdomain: String,
    pub host_header_rewrite: bool,
    pub local_url_scheme: Option<String>,
    /// PROXY protocol version, sent as an integer.
    pub proxy_proto: ProxyProto,

    pub compression: Option<Compression>,
    pub circuit_breaker: Option<CircuitBreaker>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    pub basic_auth: Option<BasicAuth>,
    #[serde(rename = "OAuth")]
    pub oauth: Option<Oauth>,
    #[serde(rename = "OIDC")]
    pub oidc: Option<Oidc>,
    pub webhook_verification: Option<WebhookVerification>,
    #[serde(rename = "MutualTLSCA")]
    pub mutual_tls_ca: Option<MutualTls>,
    #[serde(default)]
    pub request_headers: Option<Headers>,
    #[serde(default)]
    pub response_headers: Option<Headers>,
    #[serde(rename = "WebsocketTCPConverter")]
    pub websocket_tcp_converter: Option<WebsocketTcpConverter>,
    #[serde(rename = "UserAgentFilter")]
    pub user_agent_filter: Option<UserAgentFilter>,
    /// Either a structured policy or a raw policy string; sent as a string
    /// either way.
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
705
/// An empty marker type. NOTE(review): presumably its presence (`Some`) on
/// `HttpEndpoint::compression` enables compression — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Compression {}
708
/// Returns `true` if `v` equals `T::default()`.
///
/// Used with `#[serde(skip_serializing_if = "is_default")]` to omit fields
/// that still hold their default value.
fn is_default<T>(v: &T) -> bool
where
    T: PartialEq<T> + Default,
{
    let fresh = T::default();
    PartialEq::eq(&fresh, v)
}
715
// The module structs below have no `rename_all`, so their fields keep their
// snake_case names on the wire. Fields marked `skip_serializing_if =
// "is_default"` are omitted when they hold their default value.

/// Circuit-breaker module settings.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct CircuitBreaker {
    #[serde(default, skip_serializing_if = "is_default")]
    pub error_threshold: f64,
}

/// Basic-auth module settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicAuth {
    #[serde(default, skip_serializing_if = "is_default")]
    pub credentials: Vec<BasicAuthCredential>,
}

/// A single basic-auth username/password pair. The password is given either
/// in cleartext or hashed; the hash is base64-encoded on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct BasicAuthCredential {
    pub username: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cleartext_password: String,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub hashed_password: Vec<u8>,
}

/// IP restriction settings: CIDR ranges to allow and to deny.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpRestriction {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_cidrs: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny_cidrs: Vec<String>,
}
745
/// OAuth module settings. The client secret is given either in cleartext
/// (masked in `Debug` output) or sealed; the sealed bytes are base64-encoded
/// on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oauth {
    pub provider: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}

/// OIDC module settings. Same secret handling as [`Oauth`], but keyed by
/// issuer URL instead of provider name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Oidc {
    pub issuer_url: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_id: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub client_secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_client_secret: Vec<u8>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_emails: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow_domains: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub scopes: Vec<String>,
}
781
/// Webhook-verification module settings. The secret is given either in
/// cleartext (masked in `Debug` output) or sealed; the sealed bytes are
/// base64-encoded on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookVerification {
    pub provider: String,
    #[serde(default, skip_serializing_if = "is_default")]
    pub secret: SecretString,
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    pub sealed_secret: Vec<u8>,
}

/// Mutual-TLS settings: CA certificate bytes, base64-encoded on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MutualTls {
    #[serde(default, skip_serializing_if = "is_default")]
    #[serde(with = "base64bytes")]
    // this is snake-case on the wire
    pub mutual_tls_ca: Vec<u8>,
}
799
/// Header add/remove rules. Unlike the neighbouring module structs, fields
/// are camelCase on the wire (`add`, `remove`, `addParsed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Headers {
    #[serde(default, skip_serializing_if = "is_default")]
    pub add: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub remove: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub add_parsed: HashMap<String, String>,
}

/// An empty marker type for the websocket-to-TCP converter module.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct WebsocketTcpConverter {}

/// User-agent filter settings: allow and deny lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserAgentFilter {
    #[serde(default, skip_serializing_if = "is_default")]
    pub allow: Vec<String>,
    #[serde(default, skip_serializing_if = "is_default")]
    pub deny: Vec<String>,
}
821
/// Options for a TCP endpoint, carried by [`BindOpts::Tcp`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TcpEndpoint {
    pub addr: String,
    /// PROXY protocol version, sent as an integer.
    pub proxy_proto: ProxyProto,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}

/// Options for a TLS endpoint, carried by [`BindOpts::Tls`].
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct TlsEndpoint {
    pub hostname: String,
    pub subdomain: String,
    /// PROXY protocol version, sent as an integer.
    pub proxy_proto: ProxyProto,
    #[serde(rename = "MutualTLSAtAgent")]
    pub mutual_tls_at_agent: bool,

    #[serde(rename = "MutualTLSAtEdge")]
    pub mutual_tls_at_edge: Option<MutualTls>,
    #[serde(rename = "TLSTermination")]
    pub tls_termination: Option<TlsTermination>,
    #[serde(rename = "IPRestriction")]
    pub ip_restriction: Option<IpRestriction>,
    #[serde(rename = "TrafficPolicy")]
    pub traffic_policy: Option<PolicyWrapper>,
}
851
/// TLS termination settings. All byte fields are base64-encoded on the wire:
/// `cert` and `sealed_key` via `base64bytes`, and `key` via [`SecretBytes`]'s
/// own serde impl. Empty fields are omitted.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct TlsTermination {
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub cert: Vec<u8>,
    #[serde(skip_serializing_if = "is_default", default)]
    pub key: SecretBytes,
    #[serde(default, with = "base64bytes", skip_serializing_if = "is_default")]
    pub sealed_key: Vec<u8>,
}

/// Options for a labeled endpoint: the labels to attach.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase")]
pub struct LabelEndpoint {
    pub labels: HashMap<String, String>,
}
867
/// A structured traffic policy: rules applied to inbound and outbound
/// traffic. Missing fields deserialize to their defaults.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Policy {
    pub inbound: Vec<Rule>,
    pub outbound: Vec<Rule>,
}

/// A named policy rule: expressions to match and actions to run.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Rule {
    pub name: String,
    pub expressions: Vec<String>,
    pub actions: Vec<Action>,
}

/// A policy action. `config` holds raw JSON bytes and is written inline as
/// real JSON (not an escaped string) via `vec_to_json`; it is omitted when
/// empty.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "PascalCase", default)]
pub struct Action {
    // Explicit rename: `type` is a keyword, and `rename_all` would turn
    // `type_` into `Type_`.
    #[serde(rename = "Type")]
    pub type_: String,
    #[serde(default, with = "vec_to_json", skip_serializing_if = "is_default")]
    pub config: Vec<u8>,
}
891
892// This function converts a Policy into a valid JSON string. This is used so legacy configurations will still work
893// using the new string "TrafficPolicy" field.
894fn serialize_policy<S: Serializer>(v: &Policy, s: S) -> Result<S::Ok, S::Error> {
895    let abc = match serde_json::to_string(v) {
896        Ok(t) => t,
897        Err(_) => {
898            return Err(serde::ser::Error::custom(
899                "policy could not be converted to valid json",
900            ))
901        }
902    };
903    s.serialize_str(&abc)
904}
905
906// These are helpers to convert base64 strings to full, real json. The serialize helper also ensures that the resulting
907// representation isn't a string-escaped string.
908mod vec_to_json {
909    use serde::{
910        Deserialize,
911        Deserializer,
912        Serialize,
913        Serializer,
914    };
915
916    pub fn serialize<S: Serializer>(v: &[u8], s: S) -> Result<S::Ok, S::Error> {
917        let u: serde_json::Value = match serde_json::from_slice(v) {
918            Ok(k) => k,
919            Err(_) => return Err(serde::ser::Error::custom("Config is invalid JSON")),
920        };
921
922        u.serialize(s)
923    }
924
925    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
926        let s = serde_json::Map::deserialize(d)?;
927        let v = serde_json::to_vec(&s).unwrap();
928        Ok(v)
929    }
930}
931
932// These are helpers to facilitate the Vec<u8> <-> base64-encoded bytes
933// representation that the Go messages use
934mod base64bytes {
935    use base64::prelude::*;
936    use serde::{
937        Deserialize,
938        Deserializer,
939        Serialize,
940        Serializer,
941    };
942
943    pub fn serialize<S: Serializer>(v: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
944        BASE64_STANDARD.encode(v).serialize(s)
945    }
946
947    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
948        let s = String::deserialize(d)?;
949        BASE64_STANDARD
950            .decode(s.as_bytes())
951            .map_err(serde::de::Error::custom)
952    }
953}
954
#[cfg(test)]
mod test {

    use super::*;

    /// `ProxyProto` round-trips through its integer JSON encoding.
    #[test]
    fn test_proxy_proto_serde() {
        let parsed: ProxyProto = serde_json::from_str("2").unwrap();
        assert_eq!(parsed, ProxyProto::V2);

        let encoded = serde_json::to_string(&parsed).unwrap();
        assert_eq!(encoded, "2");
    }

    pub(crate) const POLICY_JSON: &str = r###"{"Inbound":[{"Name":"test_in","Expressions":["req.Method == 'PUT'"],"Actions":[{"Type":"deny"}]}],"Outbound":[{"Name":"test_out","Expressions":["res.StatusCode == '200'"],"Actions":[{"Type":"custom-response","Config":{"status_code":201}}]}]}"###;

    /// A `Policy` survives a JSON round trip, and action configs are kept as
    /// raw JSON bytes.
    #[test]
    fn test_policy_proto_serde() {
        let policy: Policy = serde_json::from_str(POLICY_JSON).unwrap();

        // The outbound rule is the interesting half: its action carries a
        // config, which goes through the special `vec_to_json` serialization.
        assert_eq!(policy.outbound.len(), 1);
        let actions = &policy.outbound[0].actions;
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].config, br#"{"status_code":201}"#);

        let reencoded = serde_json::to_string(&policy).unwrap();
        assert_eq!(reencoded, POLICY_JSON);
    }
}