diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 3ab9a2ca44c..37c76d715ff 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -366,8 +366,8 @@ struct Tail { /// Number of active receivers. rx_cnt: usize, - /// True if the channel is closed. - closed: bool, + /// True if there are any strong senders. + has_senders: bool, /// Receivers waiting for a value. waiters: LinkedList::Target>, @@ -566,7 +566,7 @@ impl Sender { tail: Mutex::new(Tail { pos: 0, rx_cnt: receiver_count, - closed: receiver_count == 0, + has_senders: true, waiters: LinkedList::new(), }), num_tx: AtomicUsize::new(1), @@ -887,24 +887,12 @@ impl Sender { /// # } /// ``` pub async fn closed(&self) { - loop { - let notified = self.shared.notify_last_rx_drop.notified(); - - { - // Ensure the lock drops if the channel isn't closed - let tail = self.shared.tail.lock(); - if tail.closed { - return; - } - } - - notified.await; - } + self.shared.closed_for_senders().await; } fn close_channel(&self) { let mut tail = self.shared.tail.lock(); - tail.closed = true; + tail.has_senders = false; self.shared.notify_rx(tail); } @@ -926,13 +914,6 @@ fn new_receiver(shared: Arc>) -> Receiver { assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers"); - if tail.rx_cnt == 0 { - // Potentially need to re-open the channel, if a new receiver has been added between calls - // to poll(). Note that we use rx_cnt == 0 instead of is_closed since is_closed also - // applies if the sender has been dropped - tail.closed = false; - } - tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); let next = tail.pos; @@ -1053,6 +1034,27 @@ impl Shared { wakers.wake_all(); } + + async fn closed_for_senders(&self) { + cooperative(async { + crate::trace::async_trace_leaf().await; + + loop { + let notified = self.notify_last_rx_drop.notified(); + + { + // Ensure the lock drops if the channel isn't closed + let tail = self.tail.lock(); + if tail.rx_cnt == 0 { + return; + } + } + + notified.await; + } + }) + .await; + } } impl Clone for Sender { @@ -1102,6 +1104,37 @@ impl WeakSender { } } + /// A future which completes when the number of [Receiver]s subscribed to this channel reaches + /// zero, regardless of whether strong senders still exist. + /// + /// # Examples + /// + /// ``` + /// use futures::FutureExt; + /// use tokio::sync::broadcast; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let (tx, mut rx1) = broadcast::channel::(16); + /// let mut rx2 = tx.subscribe(); + /// + /// let _ = tx.send(10); + /// let weak = tx.downgrade(); + /// drop(tx); + /// + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// drop(rx1); + /// assert!(weak.closed().now_or_never().is_none()); + /// + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// drop(rx2); + /// assert!(weak.closed().now_or_never().is_some()); + /// # } + /// ``` + pub async fn closed(&self) { + self.shared.closed_for_senders().await; + } + /// Returns the number of [`Sender`] handles. pub fn strong_count(&self) -> usize { self.shared.num_tx.load(Acquire) @@ -1256,7 +1289,7 @@ impl Receiver { // At this point the channel is empty for *this* receiver. If // it's been closed, then that's what we return, otherwise we // set a waker and return empty. - if tail.closed { + if !tail.has_senders { return Err(TryRecvError::Closed); } @@ -1555,7 +1588,6 @@ impl Drop for Receiver { if remaining_rx == 0 { self.shared.notify_last_rx_drop.notify_waiters(); - tail.closed = true; } drop(tail); diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index b742a6f6161..c81008c60a5 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -673,38 +673,72 @@ fn broadcast_sender_closed() { assert_ready!(task.poll()); } +#[test] +fn broadcast_weak_sender_closed() { + let (tx, rx) = broadcast::channel::<()>(1); + let rx2 = tx.subscribe(); + let weak = tx.downgrade(); + drop(tx); + + let mut task = task::spawn(weak.closed()); + assert_pending!(task.poll()); + + drop(rx); + assert!(!task.is_woken()); + assert_pending!(task.poll()); + + drop(rx2); + assert!(task.is_woken()); + assert_ready!(task.poll()); +} + #[test] fn broadcast_sender_closed_with_extra_subscribe() { let (tx, rx) = broadcast::channel::<()>(1); let rx2 = tx.subscribe(); + let weak = tx.downgrade(); let mut task = task::spawn(tx.closed()); + let mut weak_task = task::spawn(weak.closed()); assert_pending!(task.poll()); + assert_pending!(weak_task.poll()); drop(rx); assert!(!task.is_woken()); + assert!(!weak_task.is_woken()); assert_pending!(task.poll()); + assert_pending!(weak_task.poll()); drop(rx2); assert!(task.is_woken()); + assert!(weak_task.is_woken()); let rx3 = tx.subscribe(); assert_pending!(task.poll()); + assert_pending!(weak_task.poll()); drop(rx3); assert!(task.is_woken()); + assert!(weak_task.is_woken()); assert_ready!(task.poll()); + assert_ready!(weak_task.poll()); let mut task2 = task::spawn(tx.closed()); assert_ready!(task2.poll()); + let mut weak_task2 = task::spawn(weak.closed()); + assert_ready!(weak_task2.poll()); let rx4 = tx.subscribe(); let mut task3 = task::spawn(tx.closed()); assert_pending!(task3.poll()); + let mut weak_task3 = task::spawn(weak.closed()); + assert_pending!(weak_task3.poll()); drop(rx4); assert!(task3.is_woken()); assert_ready!(task3.poll()); + assert!(weak_task3.is_woken()); + assert_ready!(weak_task3.poll()); } #[tokio::test]