1use std::{
9 error::Error as StdError,
10 io,
11 sync::{
12 atomic::{
13 AtomicU64,
14 Ordering,
15 },
16 Arc,
17 },
18 time::Duration,
19};
20
21use async_trait::async_trait;
22use futures::{
23 future::select,
24 prelude::*,
25};
26use tokio::{
27 io::{
28 AsyncReadExt,
29 AsyncWriteExt,
30 },
31 runtime::Handle,
32 select,
33 sync::{
34 mpsc,
35 oneshot,
36 },
37};
38
39use crate::{
40 errors::Error,
41 typed::{
42 StreamType,
43 TypedAccept,
44 TypedOpenClose,
45 TypedSession,
46 TypedStream,
47 },
48};
49
/// Stream type reserved for heartbeat traffic on the wrapped session.
///
/// Built from `0xFFFFFFFF` via `StreamType::clamp`; presumably this is the
/// highest representable stream type so it cannot collide with application
/// streams — confirm against `StreamType::clamp`'s semantics.
const HEARTBEAT_TYPE: StreamType = StreamType::clamp(0xFFFFFFFF);
51
/// A session wrapper that reserves one stream type for heartbeats.
///
/// Incoming streams of the reserved type are handed to a background echo task
/// instead of being returned from `accept_typed`, and attempts to open a
/// stream of that type locally are refused.
pub struct Heartbeat<S> {
    // Runtime on which heartbeat responder tasks are spawned.
    runtime: Handle,
    // Passed (via `wait()`) to each responder task so it stops together with
    // the rest of the heartbeat machinery.
    drop_waiter: awaitdrop::Waiter,
    // Stream type reserved for heartbeat traffic (`HEARTBEAT_TYPE`).
    typ: StreamType,
    // The wrapped session (or one half of a split session).
    inner: S,
}
60
/// Control handle for a running heartbeat, returned by `Heartbeat::start`.
///
/// Lets callers trigger an immediate beat and adjust the interval and
/// tolerance while the heartbeat tasks are running.
pub struct HeartbeatCtl {
    // Never read: this field is held only so that dropping the handle drops the
    // `Ref`. Presumably that resolves the `awaitdrop` waiters which stop the
    // heartbeat tasks (verify against the `awaitdrop` crate). The `dead_code`
    // allowance exists because no code reads it.
    #[allow(dead_code)]
    dropref: awaitdrop::Ref,
    // (interval, tolerance) stored as whole nanoseconds. The check task holds
    // a clone of this Arc so that runtime changes are observed without a
    // restart.
    durations: Arc<(AtomicU64, AtomicU64)>,
    // Sends on-demand beat requests to the requester task. Each request
    // carries a oneshot sender that receives the measured latency.
    on_demand: mpsc::Sender<oneshot::Sender<Duration>>,
}
72
/// Callback that receives the outcome of each automatic heartbeat.
#[async_trait]
pub trait HeartbeatHandler: Send + Sync + 'static {
    /// Handles one heartbeat result.
    ///
    /// `latency` is `Some(round_trip)` when a response arrived. It is `None`
    /// when nothing arrived within `interval + tolerance`. Returning an error
    /// stops the task that checks heartbeats, so no further callbacks are made.
    async fn handle_heartbeat(&self, latency: Option<Duration>) -> Result<(), Box<dyn StdError>>;
}
84
85#[async_trait]
86impl<T, F> HeartbeatHandler for T
87where
88 T: Fn(Option<Duration>) -> F + Send + Sync + 'static,
89 F: Future<Output = Result<(), Box<dyn StdError>>> + Send,
90{
91 async fn handle_heartbeat(&self, latency: Option<Duration>) -> Result<(), Box<dyn StdError>> {
92 self(latency).await
93 }
94}
95
96pub struct HeartbeatConfig {
98 pub interval: Duration,
100 pub tolerance: Duration,
103 pub handler: Option<Arc<dyn HeartbeatHandler>>,
105}
106
107impl Default for HeartbeatConfig {
108 fn default() -> Self {
109 HeartbeatConfig {
110 interval: Duration::from_secs(10),
111 tolerance: Duration::from_secs(15),
112 handler: None,
113 }
114 }
115}
116
117impl<S> Heartbeat<S>
118where
119 S: TypedSession + 'static,
120{
121 pub async fn start(sess: S, cfg: HeartbeatConfig) -> Result<(Self, HeartbeatCtl), io::Error> {
124 let (dropref, drop_waiter) = awaitdrop::awaitdrop();
125
126 let mut hb = Heartbeat {
127 runtime: Handle::current(),
128 drop_waiter: drop_waiter.clone(),
129 typ: HEARTBEAT_TYPE,
130 inner: sess,
131 };
132
133 let (dtx, drx) = mpsc::channel(1);
134 let (mtx, mrx) = mpsc::channel(1);
135 let mut ctl = HeartbeatCtl {
136 dropref,
137 durations: Arc::new((
138 (cfg.interval.as_nanos() as u64).into(),
139 (cfg.tolerance.as_nanos() as u64).into(),
140 )),
141 on_demand: dtx,
142 };
143
144 let stream = hb
145 .inner
146 .open_typed(hb.typ)
147 .await
148 .map_err(|_| io::ErrorKind::ConnectionReset)?;
149
150 ctl.start_requester(stream, drx, mtx, drop_waiter.wait())
151 .await?;
152 ctl.start_check(mrx, cfg.handler, drop_waiter.wait())?;
153
154 Ok((hb, ctl))
155 }
156}
157
158impl HeartbeatCtl {
159 pub async fn beat(&self) -> Result<Duration, io::Error> {
161 let (tx, rx) = oneshot::channel();
162 self.on_demand
163 .send(tx)
164 .await
165 .map_err(|_| io::ErrorKind::NotConnected)?;
166 rx.await.map_err(|_| io::ErrorKind::ConnectionReset.into())
167 }
168
169 pub fn set_interval(&self, interval: Duration) {
171 self.durations
172 .0
173 .store(interval.as_nanos() as u64, Ordering::Relaxed);
174 }
175
176 pub fn set_tolerance(&self, tolerance: Duration) {
178 self.durations
179 .1
180 .store(tolerance.as_nanos() as u64, Ordering::Relaxed);
181 }
182
183 fn start_check(
184 &mut self,
185 mut mark: mpsc::Receiver<Duration>,
186 cb: Option<Arc<dyn HeartbeatHandler>>,
187 dropped: awaitdrop::WaitFuture,
188 ) -> Result<(), io::Error> {
189 let (mut interval, mut tolerance) = self.get_durations();
190 let durations = self.durations.clone();
191
192 tokio::spawn(
193 select(
194 async move {
195 let mut deadline = tokio::time::Instant::now() + interval + tolerance;
196 loop {
197 match tokio::time::timeout_at(deadline, mark.recv()).await {
198 Err(_e) => {
199 if let Some(cb) = cb.as_ref() {
200 cb.handle_heartbeat(None).await?;
201 }
202 }
203 Ok(Some(lat)) => {
204 if let Some(cb) = cb.as_ref() {
205 cb.handle_heartbeat(lat.into()).await?;
206 }
207 }
208 Ok(None) => {
209 return Result::<(), Box<dyn StdError>>::Ok(());
210 }
211 };
212
213 (interval, tolerance) = get_durations(&durations);
220 deadline = tokio::time::Instant::now() + interval + tolerance;
221 }
222 }
223 .boxed(),
224 dropped,
225 )
226 .then(|_| async move {
227 tracing::debug!("check exited");
228 }),
229 );
230
231 Ok(())
232 }
233
234 async fn start_requester(
235 &mut self,
236 mut stream: TypedStream,
237 mut on_demand: mpsc::Receiver<oneshot::Sender<Duration>>,
238 mark: mpsc::Sender<Duration>,
239 drop_waiter: awaitdrop::WaitFuture,
240 ) -> Result<(), io::Error> {
241 let (interval, _) = self.get_durations();
242 let mut ticker = tokio::time::interval(interval);
243
244 tokio::spawn(
245 select(
246 async move {
247 loop {
248 let mut resp_chan: Option<oneshot::Sender<Duration>> = None;
249
250 select! {
251 c = on_demand.recv() => if c.is_none() {
254 ticker.tick().await;
255 } else {
256 resp_chan = c;
257 },
258 _ = ticker.tick() => {},
259 }
260
261 tracing::debug!("sending heartbeat");
262
263 let start = std::time::Instant::now();
264 let id: i32 = rand::random();
265
266 if stream.write_all(&id.to_be_bytes()[..]).await.is_err() {
267 return;
268 }
269
270 let mut resp_bytes = [0u8; 4];
271
272 tracing::debug!("waiting for response");
273
274 if stream.read_exact(&mut resp_bytes[..]).await.is_err() {
275 tracing::debug!("error reading response");
276 return;
277 }
278
279 tracing::debug!("got response");
280
281 let resp_id = i32::from_be_bytes(resp_bytes);
282
283 if id != resp_id {
284 return;
285 }
286
287 let latency = std::time::Instant::now() - start;
288
289 if let Some(resp_chan) = resp_chan {
290 let _ = resp_chan.send(latency);
291 } else {
292 let _ = mark.send(latency).await;
293 }
294 }
295 }
296 .boxed(),
297 drop_waiter,
298 )
299 .then(|_| async move {
300 tracing::debug!("requester exited");
301 }),
302 );
303
304 Ok(())
305 }
306
307 fn get_durations(&self) -> (Duration, Duration) {
308 get_durations(&self.durations)
309 }
310}
311
312fn start_responder(rt: &Handle, mut stream: TypedStream, drop_waiter: awaitdrop::WaitFuture) {
313 rt.spawn(select(
314 async move {
315 loop {
316 let mut buf = [0u8; 4];
317 if let Err(e) = stream.read(&mut buf[..]).await {
318 tracing::debug!(?e, "heartbeat responder exiting");
319 return;
320 }
321 if let Err(e) = stream.write_all(&buf[..]).await {
322 tracing::debug!(?e, "heartbeat responder exiting");
323 return;
324 }
325 }
326 }
327 .boxed(),
328 drop_waiter,
329 ));
330}
331
332#[async_trait]
333impl<S> TypedAccept for Heartbeat<S>
334where
335 S: TypedAccept + Send,
336{
337 async fn accept_typed(&mut self) -> Result<TypedStream, Error> {
338 loop {
339 let stream = self.inner.accept_typed().await?;
340 let typ = stream.typ();
341
342 if typ == self.typ {
343 start_responder(&self.runtime, stream, self.drop_waiter.wait());
344 continue;
345 }
346
347 return Ok(stream);
348 }
349 }
350}
351
352#[async_trait]
353impl<S> TypedOpenClose for Heartbeat<S>
354where
355 S: TypedOpenClose + Send,
356{
357 async fn open_typed(&mut self, typ: StreamType) -> Result<TypedStream, Error> {
358 if typ == self.typ {
360 return Err(Error::StreamRefused);
361 }
362
363 self.inner.open_typed(typ).await
364 }
365
366 async fn close(&mut self, error: Error, msg: String) -> Result<(), Error> {
367 self.inner.close(error, msg).await
368 }
369}
370
371impl<S> TypedSession for Heartbeat<S>
372where
373 S: TypedSession + Send,
374 S::TypedAccept: Send,
375 S::TypedOpen: Send,
376{
377 type TypedAccept = Heartbeat<S::TypedAccept>;
378 type TypedOpen = Heartbeat<S::TypedOpen>;
379
380 fn split_typed(self) -> (Self::TypedOpen, Self::TypedAccept) {
381 let drop_waiter = self.drop_waiter;
382 let typ = self.typ;
383 let runtime = self.runtime;
384 let (open, accept) = self.inner.split_typed();
385 (
386 Heartbeat {
387 runtime: runtime.clone(),
388 drop_waiter: drop_waiter.clone(),
389 typ,
390 inner: open,
391 },
392 Heartbeat {
393 runtime,
394 drop_waiter,
395 typ,
396 inner: accept,
397 },
398 )
399 }
400}
401
/// Snapshots the shared `(interval, tolerance)` pair.
///
/// Both values are stored as whole nanoseconds. Each one is loaded separately
/// with `Relaxed` ordering, interval first, so the pair is not read atomically
/// as a unit.
fn get_durations(durations: &Arc<(AtomicU64, AtomicU64)>) -> (Duration, Duration) {
    let (interval_ns, tolerance_ns) = &**durations;
    let load = |cell: &AtomicU64| Duration::from_nanos(cell.load(Ordering::Relaxed));
    (load(interval_ns), load(tolerance_ns))
}