// ngrok/internals/raw_session.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    future::Future,
5    io,
6    ops::{
7        Deref,
8        DerefMut,
9    },
10    sync::Arc,
11};
12
13use async_trait::async_trait;
14use muxado::{
15    heartbeat::{
16        HeartbeatConfig,
17        HeartbeatCtl,
18    },
19    typed::{
20        StreamType,
21        TypedAccept,
22        TypedOpenClose,
23        TypedSession,
24        TypedStream,
25    },
26    Error as MuxadoError,
27    SessionBuilder,
28};
29use serde::{
30    de::DeserializeOwned,
31    Deserialize,
32};
33use thiserror::Error;
34use tokio::{
35    io::{
36        AsyncRead,
37        AsyncReadExt,
38        AsyncWrite,
39        AsyncWriteExt,
40    },
41    runtime::Handle,
42};
43use tokio_util::either::Either;
44use tracing::{
45    debug,
46    instrument,
47    warn,
48};
49
50use super::{
51    proto::{
52        Auth,
53        AuthExtra,
54        AuthResp,
55        Bind,
56        BindExtra,
57        BindOpts,
58        BindResp,
59        CommandResp,
60        ErrResp,
61        Error,
62        ProxyHeader,
63        ReadHeaderError,
64        Restart,
65        StartTunnelWithLabel,
66        StartTunnelWithLabelResp,
67        Stop,
68        StopTunnel,
69        Unbind,
70        UnbindResp,
71        Update,
72        PROXY_REQ,
73        RESTART_REQ,
74        STOP_REQ,
75        STOP_TUNNEL_REQ,
76        UPDATE_REQ,
77        VERSION,
78    },
79    rpc::RpcRequest,
80};
81use crate::{
82    tunnel::AcceptError::ListenerClosed,
83    Session,
84};
85
86/// Errors arising from tunneling protocol RPC calls.
87#[derive(Error, Debug)]
88#[non_exhaustive]
89pub enum RpcError {
90    /// Failed to open a new stream to start the RPC call.
91    #[error("failed to open muxado stream")]
92    Open(#[source] MuxadoError),
93    /// Some non-Open transport error occurred
94    #[error("transport error")]
95    Transport(#[source] MuxadoError),
96    /// Failed to send the request over the stream.
97    #[error("error sending rpc request")]
98    Send(#[source] io::Error),
99    /// Failed to read the RPC response from the stream.
100    #[error("error reading rpc response")]
101    Receive(#[source] io::Error),
102    /// The RPC response was invalid.
103    #[error("failed to deserialize rpc response")]
104    InvalidResponse(#[from] serde_json::Error),
105    /// There was an error in the RPC response.
106    #[error("rpc error response:\n{0}")]
107    Response(ErrResp),
108}
109
110impl Error for RpcError {
111    fn error_code(&self) -> Option<&str> {
112        match self {
113            RpcError::Response(resp) => resp.error_code(),
114            _ => None,
115        }
116    }
117
118    fn msg(&self) -> String {
119        match self {
120            RpcError::Response(resp) => resp.msg(),
121            _ => format!("{self}"),
122        }
123    }
124}
125
126#[derive(Error, Debug)]
127#[non_exhaustive]
128pub enum StartSessionError {
129    #[error("failed to start heartbeat task")]
130    StartHeartbeat(#[from] io::Error),
131}
132
133#[derive(Error, Debug)]
134#[non_exhaustive]
135pub enum AcceptError {
136    #[error("transport error when accepting connection")]
137    Transport(#[from] MuxadoError),
138    #[error(transparent)]
139    Header(#[from] ReadHeaderError),
140    #[error("invalid stream type: {0}")]
141    InvalidType(StreamType),
142}
143
144pub struct RpcClient {
145    // This is held so that the heartbeat task doesn't get shutdown. Eventually
146    // we may use it to request heartbeats via the `Session`.
147    _heartbeat: HeartbeatCtl,
148    open: Box<dyn TypedOpenClose + Send>,
149}
150
151pub struct IncomingStreams {
152    runtime: Handle,
153    handlers: CommandHandlers,
154    pub(crate) session: Option<Session>,
155    accept: Box<dyn TypedAccept + Send>,
156}
157
158pub struct RawSession {
159    client: RpcClient,
160    incoming: IncomingStreams,
161}
162
163impl Deref for RawSession {
164    type Target = RpcClient;
165    fn deref(&self) -> &Self::Target {
166        &self.client
167    }
168}
169
170impl DerefMut for RawSession {
171    fn deref_mut(&mut self) -> &mut Self::Target {
172        &mut self.client
173    }
174}
175
176/// Trait for a type that can handle a command from the ngrok dashboard.
177#[async_trait]
178pub trait CommandHandler<T>: Send + Sync + 'static {
179    /// Handle the remote command.
180    async fn handle_command(&self, req: T) -> Result<(), String>;
181}
182
183#[async_trait]
184impl<R, T, F> CommandHandler<R> for T
185where
186    R: Send + 'static,
187    T: Fn(R) -> F + Send + Sync + 'static,
188    F: Future<Output = Result<(), String>> + Send,
189{
190    async fn handle_command(&self, req: R) -> Result<(), String> {
191        self(req).await
192    }
193}
194
195#[derive(Default, Clone)]
196pub struct CommandHandlers {
197    pub on_restart: Option<Arc<dyn CommandHandler<Restart>>>,
198    pub on_update: Option<Arc<dyn CommandHandler<Update>>>,
199    pub on_stop: Option<Arc<dyn CommandHandler<Stop>>>,
200}
201
202impl RawSession {
203    pub async fn start<S, H>(
204        io_stream: S,
205        heartbeat: HeartbeatConfig,
206        handlers: H,
207    ) -> Result<Self, StartSessionError>
208    where
209        S: AsyncRead + AsyncWrite + Send + 'static,
210        H: Into<Option<CommandHandlers>>,
211    {
212        let mux_sess = SessionBuilder::new(io_stream).start();
213
214        let handlers = handlers.into().unwrap_or_default();
215
216        let typed = muxado::typed::Typed::new(mux_sess);
217        let (heartbeat, hbctl) = muxado::heartbeat::Heartbeat::start(typed, heartbeat).await?;
218        let (open, accept) = heartbeat.split_typed();
219
220        let runtime = Handle::current();
221
222        let sess = RawSession {
223            client: RpcClient {
224                _heartbeat: hbctl,
225                open: Box::new(open),
226            },
227            incoming: IncomingStreams {
228                runtime,
229                handlers,
230                session: None,
231                accept: Box::new(accept),
232            },
233        };
234
235        Ok(sess)
236    }
237
238    pub fn split(self) -> (RpcClient, IncomingStreams) {
239        (self.client, self.incoming)
240    }
241}
242
243impl RpcClient {
244    #[instrument(level = "debug", skip(self))]
245    async fn rpc<R: RpcRequest>(&mut self, req: R) -> Result<R::Response, RpcError> {
246        let mut stream = self
247            .open
248            .open_typed(R::TYPE)
249            .await
250            .map_err(RpcError::Open)?;
251        let s = serde_json::to_string(&req)
252            // This should never happen, since we control the request types and
253            // know that they will always serialize correctly. Just in case
254            // though, call them "Send" errors.
255            .map_err(io::Error::other)
256            .map_err(RpcError::Send)?;
257
258        stream
259            .write_all(s.as_bytes())
260            .await
261            .map_err(RpcError::Send)?;
262
263        let mut buf = Vec::new();
264        stream
265            .read_to_end(&mut buf)
266            .await
267            .map_err(RpcError::Receive)?;
268
269        #[derive(Debug, Deserialize)]
270        struct ErrResp {
271            #[serde(rename = "Error")]
272            error: String,
273        }
274
275        let ok_resp = serde_json::from_slice::<R::Response>(&buf);
276        let err_resp = serde_json::from_slice::<ErrResp>(&buf);
277
278        if let Ok(err) = err_resp {
279            if !err.error.is_empty() {
280                debug!(?err, "decoded rpc error response");
281                return Err(RpcError::Response(err.error.as_str().into()));
282            }
283        }
284
285        debug!(resp = ?ok_resp, "decoded rpc response");
286
287        Ok(ok_resp?)
288    }
289
290    /// Close the raw ngrok session with a "None" muxado error.
291    pub async fn close(&mut self) -> Result<(), RpcError> {
292        self.open
293            .close(MuxadoError::None, "".into())
294            .await
295            .map_err(RpcError::Transport)?;
296        Ok(())
297    }
298
299    #[instrument(level = "debug", skip(self))]
300    pub async fn auth(
301        &mut self,
302        id: impl Into<String> + Debug,
303        extra: AuthExtra,
304    ) -> Result<AuthResp, RpcError> {
305        let id = id.into();
306        let req = Auth {
307            client_id: id.clone(),
308            extra,
309            version: VERSION.iter().map(|&x| x.into()).collect(),
310        };
311
312        let resp = self.rpc(req).await?;
313
314        Ok(resp)
315    }
316
317    #[instrument(level = "debug", skip(self))]
318    pub async fn listen(
319        &mut self,
320        protocol: impl Into<String> + Debug,
321        opts: BindOpts,
322        extra: BindExtra,
323        id: impl Into<String> + Debug,
324        forwards_to: impl Into<String> + Debug,
325        forwards_proto: impl Into<String> + Debug,
326    ) -> Result<BindResp<BindOpts>, RpcError> {
327        // Sorry, this is awful. Serde untagged unions are pretty fraught and
328        // hard to debug, so we're using this macro to specialize this call
329        // based on the enum variant. It drops down to the type wrapped in the
330        // enum for the actual request/response, and then re-wraps it on the way
331        // back out in the same variant.
332        // It's probably an artifact of the go -> rust translation, and could be
333        // fixed with enough refactoring and rearchitecting. But it works well
334        // enough for now and is pretty localized.
335        macro_rules! match_variant {
336            ($v:expr, $($var:tt),*) => {
337                match opts {
338                    $(BindOpts::$var (opts) => {
339                        let req = Bind {
340                            client_id: id.into(),
341                            proto: protocol.into(),
342                            forwards_to: forwards_to.into(),
343                            forwards_proto: forwards_proto.into(),
344                            opts,
345                            extra,
346                        };
347
348                        let resp = self.rpc(req).await?;
349                        BindResp {
350                            bind_opts: BindOpts::$var(resp.bind_opts),
351                            client_id: resp.client_id,
352                            url: resp.url,
353                            extra: resp.extra,
354                            proto: resp.proto,
355                        }
356                    })*
357                }
358            };
359        }
360        Ok(match_variant!(opts, Http, Tcp, Tls))
361    }
362
363    #[instrument(level = "debug", skip(self))]
364    pub async fn listen_label(
365        &mut self,
366        labels: HashMap<String, String>,
367        metadata: impl Into<String> + Debug,
368        forwards_to: impl Into<String> + Debug,
369        forwards_proto: impl Into<String> + Debug,
370    ) -> Result<StartTunnelWithLabelResp, RpcError> {
371        let req = StartTunnelWithLabel {
372            labels,
373            metadata: metadata.into(),
374            forwards_to: forwards_to.into(),
375            forwards_proto: forwards_proto.into(),
376        };
377
378        self.rpc(req).await
379    }
380
381    #[instrument(level = "debug", skip(self))]
382    pub async fn unlisten(
383        &mut self,
384        id: impl Into<String> + Debug,
385    ) -> Result<UnbindResp, RpcError> {
386        self.rpc(Unbind {
387            client_id: id.into(),
388        })
389        .await
390    }
391}
392
/// Error message sent back when a dashboard command arrives and no
/// [`CommandHandler`] is registered for it.
pub const NOT_IMPLEMENTED: &str = "the agent has not defined a callback for this operation";
394
395async fn read_req<T>(stream: &mut TypedStream) -> Result<T, Either<io::Error, serde_json::Error>>
396where
397    T: DeserializeOwned + Debug + 'static,
398{
399    debug!("reading request from stream");
400    let mut buf = vec![];
401    let req = serde_json::from_value(loop {
402        let mut tmp = vec![0u8; 256];
403        let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?;
404        buf.extend_from_slice(&tmp[..bytes]);
405
406        if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(&buf) {
407            break obj;
408        }
409    })
410    .map_err(Either::Right)?;
411    debug!(?req, "read request from stream");
412    Ok(req)
413}
414
415async fn handle_req<T>(
416    handler: Option<Arc<dyn CommandHandler<T>>>,
417    mut stream: TypedStream,
418) -> Result<(), Either<io::Error, serde_json::Error>>
419where
420    T: DeserializeOwned + Debug + 'static,
421{
422    let res = async {
423        let req = read_req(&mut stream).await?;
424        let resp = if let Some(handler) = handler {
425            debug!("running command handler");
426            handler.handle_command(req).await.err()
427        } else {
428            Some(NOT_IMPLEMENTED.into())
429        };
430
431        debug!(?resp, "writing response to stream");
432
433        let resp_json = serde_json::to_vec(&CommandResp { error: resp }).map_err(Either::Right)?;
434
435        stream
436            .write_all(resp_json.as_slice())
437            .await
438            .map_err(Either::Left)?;
439
440        Ok(())
441    }
442    .await;
443
444    if let Err(e) = &res {
445        warn!(?e, "error when handling dashboard command");
446    }
447
448    res
449}
450
451impl IncomingStreams {
452    pub async fn accept(&mut self) -> Result<TunnelStream, AcceptError> {
453        Ok(loop {
454            let mut stream = self.accept.accept_typed().await?;
455
456            match stream.typ() {
457                RESTART_REQ => {
458                    self.runtime
459                        .spawn(handle_req(self.handlers.on_restart.clone(), stream));
460                }
461                UPDATE_REQ => {
462                    self.runtime
463                        .spawn(handle_req(self.handlers.on_update.clone(), stream));
464                }
465                STOP_REQ => {
466                    self.runtime
467                        .spawn(handle_req(self.handlers.on_stop.clone(), stream));
468                }
469                STOP_TUNNEL_REQ => {
470                    // close the tunnel through the session
471                    if let Some(session) = &self.session {
472                        let req =
473                            read_req::<StopTunnel>(&mut stream)
474                                .await
475                                .map_err(|e| match e {
476                                    Either::Left(err) => ReadHeaderError::from(err),
477                                    Either::Right(err) => ReadHeaderError::from(err),
478                                })?;
479                        session
480                            .close_tunnel_with_error(
481                                req.client_id,
482                                ListenerClosed {
483                                    message: req.message,
484                                    error_code: req.error_code,
485                                },
486                            )
487                            .await;
488                    }
489                }
490                PROXY_REQ => {
491                    let header = ProxyHeader::read_from_stream(&mut *stream).await?;
492
493                    break TunnelStream { header, stream };
494                }
495                t => return Err(AcceptError::InvalidType(t)),
496            }
497        })
498    }
499}
500
501pub struct TunnelStream {
502    pub header: ProxyHeader,
503    pub stream: TypedStream,
504}