1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 future::Future,
5 io,
6 ops::{
7 Deref,
8 DerefMut,
9 },
10 sync::Arc,
11};
12
13use async_trait::async_trait;
14use muxado::{
15 heartbeat::{
16 HeartbeatConfig,
17 HeartbeatCtl,
18 },
19 typed::{
20 StreamType,
21 TypedAccept,
22 TypedOpenClose,
23 TypedSession,
24 TypedStream,
25 },
26 Error as MuxadoError,
27 SessionBuilder,
28};
29use serde::{
30 de::DeserializeOwned,
31 Deserialize,
32};
33use thiserror::Error;
34use tokio::{
35 io::{
36 AsyncRead,
37 AsyncReadExt,
38 AsyncWrite,
39 AsyncWriteExt,
40 },
41 runtime::Handle,
42};
43use tokio_util::either::Either;
44use tracing::{
45 debug,
46 instrument,
47 warn,
48};
49
50use super::{
51 proto::{
52 Auth,
53 AuthExtra,
54 AuthResp,
55 Bind,
56 BindExtra,
57 BindOpts,
58 BindResp,
59 CommandResp,
60 ErrResp,
61 Error,
62 ProxyHeader,
63 ReadHeaderError,
64 Restart,
65 StartTunnelWithLabel,
66 StartTunnelWithLabelResp,
67 Stop,
68 StopTunnel,
69 Unbind,
70 UnbindResp,
71 Update,
72 PROXY_REQ,
73 RESTART_REQ,
74 STOP_REQ,
75 STOP_TUNNEL_REQ,
76 UPDATE_REQ,
77 VERSION,
78 },
79 rpc::RpcRequest,
80};
81use crate::{
82 tunnel::AcceptError::ListenerClosed,
83 Session,
84};
85
86#[derive(Error, Debug)]
88#[non_exhaustive]
89pub enum RpcError {
90 #[error("failed to open muxado stream")]
92 Open(#[source] MuxadoError),
93 #[error("transport error")]
95 Transport(#[source] MuxadoError),
96 #[error("error sending rpc request")]
98 Send(#[source] io::Error),
99 #[error("error reading rpc response")]
101 Receive(#[source] io::Error),
102 #[error("failed to deserialize rpc response")]
104 InvalidResponse(#[from] serde_json::Error),
105 #[error("rpc error response:\n{0}")]
107 Response(ErrResp),
108}
109
110impl Error for RpcError {
111 fn error_code(&self) -> Option<&str> {
112 match self {
113 RpcError::Response(resp) => resp.error_code(),
114 _ => None,
115 }
116 }
117
118 fn msg(&self) -> String {
119 match self {
120 RpcError::Response(resp) => resp.msg(),
121 _ => format!("{self}"),
122 }
123 }
124}
125
126#[derive(Error, Debug)]
127#[non_exhaustive]
128pub enum StartSessionError {
129 #[error("failed to start heartbeat task")]
130 StartHeartbeat(#[from] io::Error),
131}
132
133#[derive(Error, Debug)]
134#[non_exhaustive]
135pub enum AcceptError {
136 #[error("transport error when accepting connection")]
137 Transport(#[from] MuxadoError),
138 #[error(transparent)]
139 Header(#[from] ReadHeaderError),
140 #[error("invalid stream type: {0}")]
141 InvalidType(StreamType),
142}
143
/// Client half of a [`RawSession`]: issues RPC requests by opening typed streams.
pub struct RpcClient {
    // Heartbeat control handle. It is never read in this module; it is held so it lives
    // as long as the client. NOTE(review): presumably dropping it stops heartbeating —
    // confirm against muxado's `HeartbeatCtl`.
    _heartbeat: HeartbeatCtl,
    // Opens a new typed stream per RPC and closes the session in `close`.
    open: Box<dyn TypedOpenClose + Send>,
}
150
/// Accepting half of a [`RawSession`]: yields proxied connections and dispatches
/// remote commands.
pub struct IncomingStreams {
    // Runtime on which command handlers are spawned.
    runtime: Handle,
    // Callbacks for restart/update/stop commands sent by the remote end.
    handlers: CommandHandlers,
    // When set, stop-tunnel requests close the named tunnel on this session; when `None`,
    // such streams are dropped without being read.
    pub(crate) session: Option<Session>,
    // Source of incoming typed streams.
    accept: Box<dyn TypedAccept + Send>,
}
157
/// A heartbeated muxado session, split logically into an RPC client and an
/// incoming-stream acceptor. Use [`RawSession::split`] to separate the two halves.
pub struct RawSession {
    // Outgoing RPC half; also reachable through `Deref`/`DerefMut`.
    client: RpcClient,
    // Incoming stream half.
    incoming: IncomingStreams,
}
162
163impl Deref for RawSession {
164 type Target = RpcClient;
165 fn deref(&self) -> &Self::Target {
166 &self.client
167 }
168}
169
170impl DerefMut for RawSession {
171 fn deref_mut(&mut self) -> &mut Self::Target {
172 &mut self.client
173 }
174}
175
/// Callback invoked when the remote end sends a command of type `T`.
///
/// Returning `Err(msg)` causes `msg` to be sent back to the peer as the command's error.
/// Any `Fn(T) -> impl Future<Output = Result<(), String>>` closure implements this trait
/// through the blanket impl that follows.
#[async_trait]
pub trait CommandHandler<T>: Send + Sync + 'static {
    /// Handles one command request.
    async fn handle_command(&self, req: T) -> Result<(), String>;
}
182
183#[async_trait]
184impl<R, T, F> CommandHandler<R> for T
185where
186 R: Send + 'static,
187 T: Fn(R) -> F + Send + Sync + 'static,
188 F: Future<Output = Result<(), String>> + Send,
189{
190 async fn handle_command(&self, req: R) -> Result<(), String> {
191 self(req).await
192 }
193}
194
/// Optional callbacks for commands the remote end may send.
///
/// A `None` entry means the peer receives [`NOT_IMPLEMENTED`] as the error for that command.
#[derive(Default, Clone)]
pub struct CommandHandlers {
    /// Handler for restart requests.
    pub on_restart: Option<Arc<dyn CommandHandler<Restart>>>,
    /// Handler for update requests.
    pub on_update: Option<Arc<dyn CommandHandler<Update>>>,
    /// Handler for stop requests.
    pub on_stop: Option<Arc<dyn CommandHandler<Stop>>>,
}
201
202impl RawSession {
203 pub async fn start<S, H>(
204 io_stream: S,
205 heartbeat: HeartbeatConfig,
206 handlers: H,
207 ) -> Result<Self, StartSessionError>
208 where
209 S: AsyncRead + AsyncWrite + Send + 'static,
210 H: Into<Option<CommandHandlers>>,
211 {
212 let mux_sess = SessionBuilder::new(io_stream).start();
213
214 let handlers = handlers.into().unwrap_or_default();
215
216 let typed = muxado::typed::Typed::new(mux_sess);
217 let (heartbeat, hbctl) = muxado::heartbeat::Heartbeat::start(typed, heartbeat).await?;
218 let (open, accept) = heartbeat.split_typed();
219
220 let runtime = Handle::current();
221
222 let sess = RawSession {
223 client: RpcClient {
224 _heartbeat: hbctl,
225 open: Box::new(open),
226 },
227 incoming: IncomingStreams {
228 runtime,
229 handlers,
230 session: None,
231 accept: Box::new(accept),
232 },
233 };
234
235 Ok(sess)
236 }
237
238 pub fn split(self) -> (RpcClient, IncomingStreams) {
239 (self.client, self.incoming)
240 }
241}
242
243impl RpcClient {
244 #[instrument(level = "debug", skip(self))]
245 async fn rpc<R: RpcRequest>(&mut self, req: R) -> Result<R::Response, RpcError> {
246 let mut stream = self
247 .open
248 .open_typed(R::TYPE)
249 .await
250 .map_err(RpcError::Open)?;
251 let s = serde_json::to_string(&req)
252 .map_err(io::Error::other)
256 .map_err(RpcError::Send)?;
257
258 stream
259 .write_all(s.as_bytes())
260 .await
261 .map_err(RpcError::Send)?;
262
263 let mut buf = Vec::new();
264 stream
265 .read_to_end(&mut buf)
266 .await
267 .map_err(RpcError::Receive)?;
268
269 #[derive(Debug, Deserialize)]
270 struct ErrResp {
271 #[serde(rename = "Error")]
272 error: String,
273 }
274
275 let ok_resp = serde_json::from_slice::<R::Response>(&buf);
276 let err_resp = serde_json::from_slice::<ErrResp>(&buf);
277
278 if let Ok(err) = err_resp {
279 if !err.error.is_empty() {
280 debug!(?err, "decoded rpc error response");
281 return Err(RpcError::Response(err.error.as_str().into()));
282 }
283 }
284
285 debug!(resp = ?ok_resp, "decoded rpc response");
286
287 Ok(ok_resp?)
288 }
289
290 pub async fn close(&mut self) -> Result<(), RpcError> {
292 self.open
293 .close(MuxadoError::None, "".into())
294 .await
295 .map_err(RpcError::Transport)?;
296 Ok(())
297 }
298
299 #[instrument(level = "debug", skip(self))]
300 pub async fn auth(
301 &mut self,
302 id: impl Into<String> + Debug,
303 extra: AuthExtra,
304 ) -> Result<AuthResp, RpcError> {
305 let id = id.into();
306 let req = Auth {
307 client_id: id.clone(),
308 extra,
309 version: VERSION.iter().map(|&x| x.into()).collect(),
310 };
311
312 let resp = self.rpc(req).await?;
313
314 Ok(resp)
315 }
316
317 #[instrument(level = "debug", skip(self))]
318 pub async fn listen(
319 &mut self,
320 protocol: impl Into<String> + Debug,
321 opts: BindOpts,
322 extra: BindExtra,
323 id: impl Into<String> + Debug,
324 forwards_to: impl Into<String> + Debug,
325 forwards_proto: impl Into<String> + Debug,
326 ) -> Result<BindResp<BindOpts>, RpcError> {
327 macro_rules! match_variant {
336 ($v:expr, $($var:tt),*) => {
337 match opts {
338 $(BindOpts::$var (opts) => {
339 let req = Bind {
340 client_id: id.into(),
341 proto: protocol.into(),
342 forwards_to: forwards_to.into(),
343 forwards_proto: forwards_proto.into(),
344 opts,
345 extra,
346 };
347
348 let resp = self.rpc(req).await?;
349 BindResp {
350 bind_opts: BindOpts::$var(resp.bind_opts),
351 client_id: resp.client_id,
352 url: resp.url,
353 extra: resp.extra,
354 proto: resp.proto,
355 }
356 })*
357 }
358 };
359 }
360 Ok(match_variant!(opts, Http, Tcp, Tls))
361 }
362
363 #[instrument(level = "debug", skip(self))]
364 pub async fn listen_label(
365 &mut self,
366 labels: HashMap<String, String>,
367 metadata: impl Into<String> + Debug,
368 forwards_to: impl Into<String> + Debug,
369 forwards_proto: impl Into<String> + Debug,
370 ) -> Result<StartTunnelWithLabelResp, RpcError> {
371 let req = StartTunnelWithLabel {
372 labels,
373 metadata: metadata.into(),
374 forwards_to: forwards_to.into(),
375 forwards_proto: forwards_proto.into(),
376 };
377
378 self.rpc(req).await
379 }
380
381 #[instrument(level = "debug", skip(self))]
382 pub async fn unlisten(
383 &mut self,
384 id: impl Into<String> + Debug,
385 ) -> Result<UnbindResp, RpcError> {
386 self.rpc(Unbind {
387 client_id: id.into(),
388 })
389 .await
390 }
391}
392
/// Error message sent back to the peer for a command that has no registered handler.
pub const NOT_IMPLEMENTED: &str = "the agent has not defined a callback for this operation";
394
395async fn read_req<T>(stream: &mut TypedStream) -> Result<T, Either<io::Error, serde_json::Error>>
396where
397 T: DeserializeOwned + Debug + 'static,
398{
399 debug!("reading request from stream");
400 let mut buf = vec![];
401 let req = serde_json::from_value(loop {
402 let mut tmp = vec![0u8; 256];
403 let bytes = stream.read(&mut tmp).await.map_err(Either::Left)?;
404 buf.extend_from_slice(&tmp[..bytes]);
405
406 if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(&buf) {
407 break obj;
408 }
409 })
410 .map_err(Either::Right)?;
411 debug!(?req, "read request from stream");
412 Ok(req)
413}
414
415async fn handle_req<T>(
416 handler: Option<Arc<dyn CommandHandler<T>>>,
417 mut stream: TypedStream,
418) -> Result<(), Either<io::Error, serde_json::Error>>
419where
420 T: DeserializeOwned + Debug + 'static,
421{
422 let res = async {
423 let req = read_req(&mut stream).await?;
424 let resp = if let Some(handler) = handler {
425 debug!("running command handler");
426 handler.handle_command(req).await.err()
427 } else {
428 Some(NOT_IMPLEMENTED.into())
429 };
430
431 debug!(?resp, "writing response to stream");
432
433 let resp_json = serde_json::to_vec(&CommandResp { error: resp }).map_err(Either::Right)?;
434
435 stream
436 .write_all(resp_json.as_slice())
437 .await
438 .map_err(Either::Left)?;
439
440 Ok(())
441 }
442 .await;
443
444 if let Err(e) = &res {
445 warn!(?e, "error when handling dashboard command");
446 }
447
448 res
449}
450
451impl IncomingStreams {
452 pub async fn accept(&mut self) -> Result<TunnelStream, AcceptError> {
453 Ok(loop {
454 let mut stream = self.accept.accept_typed().await?;
455
456 match stream.typ() {
457 RESTART_REQ => {
458 self.runtime
459 .spawn(handle_req(self.handlers.on_restart.clone(), stream));
460 }
461 UPDATE_REQ => {
462 self.runtime
463 .spawn(handle_req(self.handlers.on_update.clone(), stream));
464 }
465 STOP_REQ => {
466 self.runtime
467 .spawn(handle_req(self.handlers.on_stop.clone(), stream));
468 }
469 STOP_TUNNEL_REQ => {
470 if let Some(session) = &self.session {
472 let req =
473 read_req::<StopTunnel>(&mut stream)
474 .await
475 .map_err(|e| match e {
476 Either::Left(err) => ReadHeaderError::from(err),
477 Either::Right(err) => ReadHeaderError::from(err),
478 })?;
479 session
480 .close_tunnel_with_error(
481 req.client_id,
482 ListenerClosed {
483 message: req.message,
484 error_code: req.error_code,
485 },
486 )
487 .await;
488 }
489 }
490 PROXY_REQ => {
491 let header = ProxyHeader::read_from_stream(&mut *stream).await?;
492
493 break TunnelStream { header, stream };
494 }
495 t => return Err(AcceptError::InvalidType(t)),
496 }
497 })
498 }
499}
500
/// A proxied connection returned by [`IncomingStreams::accept`].
pub struct TunnelStream {
    /// Header read from the stream before it was returned.
    pub header: ProxyHeader,
    /// The underlying stream. NOTE(review): presumably positioned just past the header —
    /// confirm against `ProxyHeader::read_from_stream`.
    pub stream: TypedStream,
}