muxado/
heartbeat.rs

1//! Heartbeating [TypedSession] wrapper.
2//!
3//! This can be used to wrap a [TypedSession] to provide heartbeating
4//! functionality. The wrapper will start a background task to send heartbeats
5//! to the remote via a dedicated heartbeat stream. It will also accept incoming
6//! heartbeat streams and start a task to reply to them.
7
8use std::{
9    error::Error as StdError,
10    io,
11    sync::{
12        atomic::{
13            AtomicU64,
14            Ordering,
15        },
16        Arc,
17    },
18    time::Duration,
19};
20
21use async_trait::async_trait;
22use futures::{
23    future::select,
24    prelude::*,
25};
26use tokio::{
27    io::{
28        AsyncReadExt,
29        AsyncWriteExt,
30    },
31    runtime::Handle,
32    select,
33    sync::{
34        mpsc,
35        oneshot,
36    },
37};
38
39use crate::{
40    errors::Error,
41    typed::{
42        StreamType,
43        TypedAccept,
44        TypedOpenClose,
45        TypedSession,
46        TypedStream,
47    },
48};
49
50const HEARTBEAT_TYPE: StreamType = StreamType::clamp(0xFFFFFFFF);
51
52/// Wrapper for a muxado [TypedSession] that adds heartbeating over a dedicated
53/// typed stream.
54pub struct Heartbeat<S> {
55    runtime: Handle,
56    drop_waiter: awaitdrop::Waiter,
57    typ: StreamType,
58    inner: S,
59}
60
61/// Controller for the heartbeat task.
62///
63/// Allows owners to change the heartbeat timing at runtime and to explicitly
64/// request heartbeats. When dropped, cancels the heartbeat tasks.
65pub struct HeartbeatCtl {
66    // Implicitly used to cancel the heartbeat tasks.
67    #[allow(dead_code)]
68    dropref: awaitdrop::Ref,
69    durations: Arc<(AtomicU64, AtomicU64)>,
70    on_demand: mpsc::Sender<oneshot::Sender<Duration>>,
71}
72
73/// A handler called on every heartbeat with the latency for that beat.
74#[async_trait]
75pub trait HeartbeatHandler: Send + Sync + 'static {
76    /// Handle the heartbeat
77    ///
78    /// A `None` latency implies that the timeout was reached before the
79    /// heartbeat reply was received.
80    ///
81    /// If this returns an error, the heartbeat task will exit.
82    async fn handle_heartbeat(&self, latency: Option<Duration>) -> Result<(), Box<dyn StdError>>;
83}
84
85#[async_trait]
86impl<T, F> HeartbeatHandler for T
87where
88    T: Fn(Option<Duration>) -> F + Send + Sync + 'static,
89    F: Future<Output = Result<(), Box<dyn StdError>>> + Send,
90{
91    async fn handle_heartbeat(&self, latency: Option<Duration>) -> Result<(), Box<dyn StdError>> {
92        self(latency).await
93    }
94}
95
96/// The heartbeat task configuration.
97pub struct HeartbeatConfig {
98    /// The interval on which heartbeats will be sent.
99    pub interval: Duration,
100    /// The amount of time past a missed heartbeat that the other side will be
101    /// considered dead.
102    pub tolerance: Duration,
103    /// An optional callback to run when a heartbeat is received.
104    pub handler: Option<Arc<dyn HeartbeatHandler>>,
105}
106
107impl Default for HeartbeatConfig {
108    fn default() -> Self {
109        HeartbeatConfig {
110            interval: Duration::from_secs(10),
111            tolerance: Duration::from_secs(15),
112            handler: None,
113        }
114    }
115}
116
117impl<S> Heartbeat<S>
118where
119    S: TypedSession + 'static,
120{
121    /// Wrap a typed session and start the heartbeat task.
122    /// Returns an error if the stream can't be opened.
123    pub async fn start(sess: S, cfg: HeartbeatConfig) -> Result<(Self, HeartbeatCtl), io::Error> {
124        let (dropref, drop_waiter) = awaitdrop::awaitdrop();
125
126        let mut hb = Heartbeat {
127            runtime: Handle::current(),
128            drop_waiter: drop_waiter.clone(),
129            typ: HEARTBEAT_TYPE,
130            inner: sess,
131        };
132
133        let (dtx, drx) = mpsc::channel(1);
134        let (mtx, mrx) = mpsc::channel(1);
135        let mut ctl = HeartbeatCtl {
136            dropref,
137            durations: Arc::new((
138                (cfg.interval.as_nanos() as u64).into(),
139                (cfg.tolerance.as_nanos() as u64).into(),
140            )),
141            on_demand: dtx,
142        };
143
144        let stream = hb
145            .inner
146            .open_typed(hb.typ)
147            .await
148            .map_err(|_| io::ErrorKind::ConnectionReset)?;
149
150        ctl.start_requester(stream, drx, mtx, drop_waiter.wait())
151            .await?;
152        ctl.start_check(mrx, cfg.handler, drop_waiter.wait())?;
153
154        Ok((hb, ctl))
155    }
156}
157
158impl HeartbeatCtl {
159    /// Explicitly request a heartbeat and return the latency.
160    pub async fn beat(&self) -> Result<Duration, io::Error> {
161        let (tx, rx) = oneshot::channel();
162        self.on_demand
163            .send(tx)
164            .await
165            .map_err(|_| io::ErrorKind::NotConnected)?;
166        rx.await.map_err(|_| io::ErrorKind::ConnectionReset.into())
167    }
168
169    /// Change the heartbeat interval.
170    pub fn set_interval(&self, interval: Duration) {
171        self.durations
172            .0
173            .store(interval.as_nanos() as u64, Ordering::Relaxed);
174    }
175
176    /// Change the heartbeat tolerance.
177    pub fn set_tolerance(&self, tolerance: Duration) {
178        self.durations
179            .1
180            .store(tolerance.as_nanos() as u64, Ordering::Relaxed);
181    }
182
183    fn start_check(
184        &mut self,
185        mut mark: mpsc::Receiver<Duration>,
186        cb: Option<Arc<dyn HeartbeatHandler>>,
187        dropped: awaitdrop::WaitFuture,
188    ) -> Result<(), io::Error> {
189        let (mut interval, mut tolerance) = self.get_durations();
190        let durations = self.durations.clone();
191
192        tokio::spawn(
193            select(
194                async move {
195                    let mut deadline = tokio::time::Instant::now() + interval + tolerance;
196                    loop {
197                        match tokio::time::timeout_at(deadline, mark.recv()).await {
198                            Err(_e) => {
199                                if let Some(cb) = cb.as_ref() {
200                                    cb.handle_heartbeat(None).await?;
201                                }
202                            }
203                            Ok(Some(lat)) => {
204                                if let Some(cb) = cb.as_ref() {
205                                    cb.handle_heartbeat(lat.into()).await?;
206                                }
207                            }
208                            Ok(None) => {
209                                return Result::<(), Box<dyn StdError>>::Ok(());
210                            }
211                        };
212
213                        // Slight divergence from Go implementation: this didn't
214                        // previously happen in the "timeout" case, which did noting but
215                        // the callback. Presumably, this usually killed the connection,
216                        // causing the goroutine to exit *anyway*. If we didn't reset
217                        // the deadline here, it would timeout immediately rather than
218                        // blocking indefinitely as in Go.
219                        (interval, tolerance) = get_durations(&durations);
220                        deadline = tokio::time::Instant::now() + interval + tolerance;
221                    }
222                }
223                .boxed(),
224                dropped,
225            )
226            .then(|_| async move {
227                tracing::debug!("check exited");
228            }),
229        );
230
231        Ok(())
232    }
233
234    async fn start_requester(
235        &mut self,
236        mut stream: TypedStream,
237        mut on_demand: mpsc::Receiver<oneshot::Sender<Duration>>,
238        mark: mpsc::Sender<Duration>,
239        drop_waiter: awaitdrop::WaitFuture,
240    ) -> Result<(), io::Error> {
241        let (interval, _) = self.get_durations();
242        let mut ticker = tokio::time::interval(interval);
243
244        tokio::spawn(
245            select(
246                async move {
247                    loop {
248                        let mut resp_chan: Option<oneshot::Sender<Duration>> = None;
249
250                        select! {
251                            // If on_demand is closed, this will return None
252                            // immediately. In that case, wait on the next tick instead.
253                            c = on_demand.recv() => if c.is_none() {
254                                ticker.tick().await;
255                            } else {
256                                resp_chan = c;
257                            },
258                            _ = ticker.tick() => {},
259                        }
260
261                        tracing::debug!("sending heartbeat");
262
263                        let start = std::time::Instant::now();
264                        let id: i32 = rand::random();
265
266                        if stream.write_all(&id.to_be_bytes()[..]).await.is_err() {
267                            return;
268                        }
269
270                        let mut resp_bytes = [0u8; 4];
271
272                        tracing::debug!("waiting for response");
273
274                        if stream.read_exact(&mut resp_bytes[..]).await.is_err() {
275                            tracing::debug!("error reading response");
276                            return;
277                        }
278
279                        tracing::debug!("got response");
280
281                        let resp_id = i32::from_be_bytes(resp_bytes);
282
283                        if id != resp_id {
284                            return;
285                        }
286
287                        let latency = std::time::Instant::now() - start;
288
289                        if let Some(resp_chan) = resp_chan {
290                            let _ = resp_chan.send(latency);
291                        } else {
292                            let _ = mark.send(latency).await;
293                        }
294                    }
295                }
296                .boxed(),
297                drop_waiter,
298            )
299            .then(|_| async move {
300                tracing::debug!("requester exited");
301            }),
302        );
303
304        Ok(())
305    }
306
307    fn get_durations(&self) -> (Duration, Duration) {
308        get_durations(&self.durations)
309    }
310}
311
312fn start_responder(rt: &Handle, mut stream: TypedStream, drop_waiter: awaitdrop::WaitFuture) {
313    rt.spawn(select(
314        async move {
315            loop {
316                let mut buf = [0u8; 4];
317                if let Err(e) = stream.read(&mut buf[..]).await {
318                    tracing::debug!(?e, "heartbeat responder exiting");
319                    return;
320                }
321                if let Err(e) = stream.write_all(&buf[..]).await {
322                    tracing::debug!(?e, "heartbeat responder exiting");
323                    return;
324                }
325            }
326        }
327        .boxed(),
328        drop_waiter,
329    ));
330}
331
332#[async_trait]
333impl<S> TypedAccept for Heartbeat<S>
334where
335    S: TypedAccept + Send,
336{
337    async fn accept_typed(&mut self) -> Result<TypedStream, Error> {
338        loop {
339            let stream = self.inner.accept_typed().await?;
340            let typ = stream.typ();
341
342            if typ == self.typ {
343                start_responder(&self.runtime, stream, self.drop_waiter.wait());
344                continue;
345            }
346
347            return Ok(stream);
348        }
349    }
350}
351
352#[async_trait]
353impl<S> TypedOpenClose for Heartbeat<S>
354where
355    S: TypedOpenClose + Send,
356{
357    async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error> {
358        // Don't open a heartbeat stream manually
359        if typ == self.typ {
360            return Err(Error::StreamRefused);
361        }
362
363        self.inner.open_typed(typ).await
364    }
365
366    async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
367        self.inner.close(error, msg).await
368    }
369}
370
371impl<S> TypedSession for Heartbeat<S>
372where
373    S: TypedSession + Send,
374    S::TypedAccept: Send,
375    S::TypedOpen: Send,
376{
377    type TypedAccept = Heartbeat<S::TypedAccept>;
378    type TypedOpen = Heartbeat<S::TypedOpen>;
379
380    fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept) {
381        let drop_waiter = self.drop_waiter;
382        let typ = self.typ;
383        let runtime = self.runtime;
384        let (open, accept) = self.inner.split_typed();
385        (
386            Heartbeat {
387                runtime: runtime.clone(),
388                drop_waiter: drop_waiter.clone(),
389                typ,
390                inner: open,
391            },
392            Heartbeat {
393                runtime,
394                drop_waiter,
395                typ,
396                inner: accept,
397            },
398        )
399    }
400}
401
/// Load the shared `(interval, tolerance)` pair, converting each stored
/// nanosecond count back into a [Duration].
fn get_durations(durations: &Arc<(AtomicU64, AtomicU64)>) -> (Duration, Duration) {
    let (interval, tolerance) = &**durations;
    let read = |cell: &AtomicU64| Duration::from_nanos(cell.load(Ordering::Relaxed));
    (read(interval), read(tolerance))
}