diff --git a/src/bastion/src/dispatcher.rs b/src/bastion/src/dispatcher.rs index fcb447d9..247d3dbc 100644 --- a/src/bastion/src/dispatcher.rs +++ b/src/bastion/src/dispatcher.rs @@ -9,8 +9,8 @@ use crate::{ }; use crate::{distributor::Distributor, envelope::SignedMessage}; use anyhow::Result as AnyResult; +use futures::Future; use lever::prelude::*; -use std::hash::{Hash, Hasher}; use std::sync::RwLock; use std::sync::{ atomic::{AtomicUsize, Ordering}, @@ -20,6 +20,10 @@ use std::{ collections::HashMap, fmt::{self, Debug}, }; +use std::{ + hash::{Hash, Hasher}, + task::Poll, +}; use tracing::{debug, trace}; /// Type alias for the concurrency hashmap. Each key-value pair stores @@ -71,6 +75,29 @@ pub trait Recipient { /// A `RecipientHandler` is a `Recipient` implementor, that can be stored in the dispatcher pub trait RecipientHandler: Recipient + Send + Sync + Debug {} +impl Future for RoundRobinHandler { + type Output = Vec; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let recipients = self.public_recipients(); + if !recipients.is_empty() { + return Poll::Ready(recipients); + } + + self.waker.register(cx.waker()); + + let recipients = self.public_recipients(); + if !recipients.is_empty() { + Poll::Ready(recipients) + } else { + Poll::Pending + } + } +} + impl RecipientHandler for RoundRobinHandler {} /// The default handler, which does round-robin. @@ -101,6 +128,7 @@ pub type DefaultDispatcherHandler = RoundRobinHandler; pub struct RoundRobinHandler { index: AtomicUsize, recipients: RecipientMap, + waker: futures::task::AtomicWaker, } impl RoundRobinHandler { @@ -118,6 +146,22 @@ impl RoundRobinHandler { } } +impl RoundRobinHandler { + async fn poll_next(&mut self) -> ChildRef { + let index = self.index.fetch_add(1, Ordering::SeqCst); + let recipients = self.await; + // TODO [igni]: unwrap?! + recipients + .get(index % recipients.len()) + .map(std::clone::Clone::clone) + .unwrap() + } + + async fn poll_all(&mut self) -> Vec { + self.await + } +} + impl Recipient for RoundRobinHandler { fn next(&self) -> Option { let entries = self.public_recipients(); @@ -137,6 +181,7 @@ impl Recipient for RoundRobinHandler { fn register(&self, actor: ChildRef) { let _ = self.recipients.insert(actor, ()); + self.waker.wake(); } fn remove(&self, actor: &ChildRef) {