Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion crates/bevy_tasks/src/futures.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! Utilities for working with [`Future`]s.
use alloc::task::Wake;
use bevy_platform::sync::Arc;
use core::{
future::Future,
pin::pin,
pin::{pin, Pin},
task::{Context, Poll, Waker},
};

Expand All @@ -22,3 +24,66 @@ pub fn now_or_never<F: Future>(future: F) -> Option<F::Output> {
pub fn check_ready<F: Future + Unpin>(future: &mut F) -> Option<F::Output> {
now_or_never(future)
}

/// Wraps a future such that the Waker given to the future also runs the "kicker".
///
/// This allows us to trigger an action (the "kicker") in addition to just waking the future. The
/// kicker is also triggered when the future resolves (i.e., returns [`Poll::Ready`]).
pub(crate) struct KickOnWake<F> {
/// The "kicker" that will be invoked when the future wakes up or resolves.
pub(crate) kicker: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
/// The inner future.
pub(crate) f: F,
}

impl<F: Future> Future for KickOnWake<F> {
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Some(kicker) = self.kicker.clone() else {
#[expect(
unsafe_code,
reason = "We need to manually pin so we can support wrapping any future."
)]
// SAFETY: We don't move out of `this` inside the closure, and we don't move out of `f`
// in any case - we assume that pinning `self` also means pinning `self.f`.
return unsafe { self.map_unchecked_mut(|this| &mut this.f) }.poll(cx);
};
let wrapped_waker = Waker::from(Arc::new(KickThenWake {
kicker,
waker: cx.waker().clone(),
}));
let mut cx = Context::from_waker(&wrapped_waker);
#[expect(
unsafe_code,
reason = "We need to manually pin so we can support wrapping any future."
)]
// SAFETY: We don't move out of `this` inside the closure, and we don't move out of `f`
// in any case - we assume that pinning `self` also means pinning `self.f`.
let result = unsafe { self.map_unchecked_mut(|this| &mut this.f) }.poll(&mut cx);
// Also kick if the future resolves.
if result.is_ready() {
wrapped_waker.wake_by_ref();
}
result
}
}

/// A waker that wraps another waker, but first executing the "kicker".
struct KickThenWake {
/// The "kicker" that will be invoked when the future wakes up or resolves.
kicker: Arc<dyn Fn() + Send + Sync + 'static>,
/// The actual waker to invoke after the kicker.
waker: Waker,
}

impl Wake for KickThenWake {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}

fn wake_by_ref(self: &Arc<Self>) {
(*self.kicker)();
self.waker.wake_by_ref();
}
}
82 changes: 66 additions & 16 deletions crates/bevy_tasks/src/single_threaded_task_pool.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use alloc::{string::String, vec::Vec};
use bevy_platform::sync::Arc;
use core::{cell::{RefCell, Cell}, future::Future, marker::PhantomData, mem};
use bevy_platform::sync::{Arc, PoisonError, RwLock};
use core::{
cell::{Cell, RefCell},
future::Future,
marker::PhantomData,
mem,
};

use crate::executor::LocalExecutor;
use crate::{block_on, Task};
use crate::{executor::LocalExecutor, futures::KickOnWake};

crate::cfg::std! {
if {
Expand Down Expand Up @@ -80,8 +85,16 @@ impl TaskPoolBuilder {

/// A thread pool for executing tasks. Tasks are futures that are being automatically driven by
/// the pool on threads owned by the pool. In this case - main thread only.
#[derive(Debug, Default, Clone)]
pub struct TaskPool {}
#[derive(Default)]
pub struct TaskPool {
kicker: RwLock<Option<Arc<dyn Fn() + Send + Sync + 'static>>>,
}

impl core::fmt::Debug for TaskPool {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("TaskPool").finish()
}
}

impl TaskPool {
/// Just create a new `ThreadExecutor` for wasm
Expand All @@ -95,7 +108,18 @@ impl TaskPool {
}

fn new_internal() -> Self {
Self {}
Self {
kicker: Default::default(),
}
}

/// Sets the "kicker" that futures will invoke when waking.
///
/// This allows event loops to be notified whenever a future resolves. Note changing this at
/// runtime can have **unpredictable results**. Users should set this before spawning any
/// futures to ensure the kicker is invoked.
pub fn set_kicker(&self, kicker: Arc<dyn Fn() + Send + Sync + 'static>) {
*self.kicker.write().unwrap_or_else(PoisonError::into_inner) = Some(kicker);
}

/// Return the number of threads owned by the task pool
Expand Down Expand Up @@ -156,6 +180,11 @@ impl TaskPool {
executor_ref,
pending_tasks,
results_ref,
kicker: self
.kicker
.read()
.unwrap_or_else(PoisonError::into_inner)
.clone(),
scope: PhantomData,
env: PhantomData,
};
Expand Down Expand Up @@ -192,20 +221,25 @@ impl TaskPool {
where
T: 'static + MaybeSend + MaybeSync,
{
let kicker = self
.kicker
.read()
.unwrap_or_else(PoisonError::into_inner)
.clone();
crate::cfg::switch! {{
crate::cfg::web => {
web_task::spawn_local(future)
web_task::spawn_local(KickOnWake { kicker, f: future })
}
crate::cfg::std => {
LOCAL_EXECUTOR.with(|executor| {
let task = executor.spawn(future);
let task = executor.spawn(KickOnWake { kicker, f: future });
// Loop until all tasks are done
while executor.try_tick() {}
task
})
}
_ => {
let task = LOCAL_EXECUTOR.spawn(future);
let task = LOCAL_EXECUTOR.spawn(KickOnWake { kicker, f: future });
// Loop until all tasks are done
while LOCAL_EXECUTOR.try_tick() {}
task
Expand Down Expand Up @@ -253,19 +287,32 @@ impl TaskPool {
/// A `TaskPool` scope for running one or more non-`'static` futures.
///
/// For more information, see [`TaskPool::scope`].
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
executor_ref: &'scope LocalExecutor<'scope>,
// The number of pending tasks spawned on the scope
pending_tasks: &'scope Cell<usize>,
// Vector to gather results of all futures spawned during scope run
results_ref: &'env RefCell<Vec<Option<T>>>,
/// The kicker to wake whenever a future wakes.
kicker: Option<Arc<dyn Fn() + Send + Sync + 'static>>,

// make `Scope` invariant over 'scope and 'env
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope, T: core::fmt::Debug> core::fmt::Debug for Scope<'scope, 'env, T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Scope")
.field("executor_ref", &self.executor_ref)
.field("pending_tasks", &self.pending_tasks)
.field("results_ref", &self.results_ref)
.field("scope", &self.scope)
.field("env", &self.env)
.finish()
}
}

impl<'scope, 'env, T: Send + 'env> Scope<'scope, 'env, T> {
/// Spawns a scoped future onto the executor. The scope *must* outlive
/// the provided future. The results of the future will be returned as a part of
Expand Down Expand Up @@ -320,7 +367,12 @@ impl<'scope, 'env, T: Send + 'env> Scope<'scope, 'env, T> {
};

// spawn the job itself
self.executor_ref.spawn(f).detach();
self.executor_ref
.spawn(KickOnWake {
kicker: self.kicker.clone(),
f,
})
.detach();
}
}

Expand All @@ -342,7 +394,7 @@ crate::cfg::std! {

#[cfg(test)]
mod test {
use std::{time, thread};
use std::{thread, time};

use super::*;

Expand All @@ -355,16 +407,14 @@ mod test {
#[test]
fn scoped_spawn() {
let (sender, receiver) = async_channel::unbounded();
let task_pool = TaskPool {};
let task_pool = TaskPool::new();
let thread = thread::spawn(move || {
let duration = time::Duration::from_millis(50);
thread::sleep(duration);
let _ = sender.send(0);
});
task_pool.scope(|scope| {
scope.spawn(async {
receiver.recv().await
});
scope.spawn(async { receiver.recv().await });
});
}
}
Loading
Loading