Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};
use tokio::task::{spawn_local, JoinHandle};

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// Internally the local pool uses a [`tokio::runtime::LocalRuntime`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::runtime::LocalRuntime`]: tokio::runtime::LocalRuntime
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
Expand Down Expand Up @@ -238,10 +237,10 @@ impl LocalPool {
let _abort_guard = AbortGuard(abort_handle);

// Inside the future we can't run spawn_local yet because we're not
// in the context of a LocalSet. We need to send create_task to the
// LocalSet task for spawning.
// in the context of a LocalRuntime. We need to send create_task to the
// LocalRuntime task for spawning.
let spawn_task = Box::new(move || {
// Once we're in the LocalSet context we can call spawn_local
// Once we're in the LocalRuntime context we can call spawn_local
let join_handle =
spawn_local(
async move { Abortable::new(create_task(), abort_registration).await },
Expand All @@ -255,7 +254,7 @@ impl LocalPool {
}
});

// Send the callback to the LocalSet task
// Send the callback to the LocalRuntime task
if let Err(e) = worker_spawner.send(spawn_task) {
// Propagate the error as a panic in the join handle.
panic!("Failed to send job to worker: {e}");
Expand Down Expand Up @@ -379,15 +378,17 @@ impl LocalWorkerHandle {
/// Create a new worker for executing pinned tasks
fn new_worker() -> LocalWorkerHandle {
let (sender, receiver) = unbounded_channel();
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to start a pinned worker thread runtime");
let runtime_handle = runtime.handle().clone();
let (handle_sender, handle_receiver) = std::sync::mpsc::channel();

let task_count = Arc::new(AtomicUsize::new(0));
let task_count_clone = Arc::clone(&task_count);

std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
std::thread::spawn(|| Self::run(handle_sender, receiver, task_count_clone));

let runtime_handle = handle_receiver
.recv()
.expect("Failed to recv local runtime init result")
.expect("Failed to start local runtime");

LocalWorkerHandle {
runtime_handle,
Expand All @@ -397,28 +398,37 @@ impl LocalWorkerHandle {
}

fn run(
runtime: tokio::runtime::Runtime,
handle_sender: std::sync::mpsc::Sender<std::io::Result<tokio::runtime::Handle>>,
mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
task_count: Arc<AtomicUsize>,
) {
let local_set = LocalSet::new();
local_set.block_on(&runtime, async {
let runtime = match tokio::runtime::LocalRuntime::new() {
Ok(runtime) => runtime,
Err(err) => {
let _ = handle_sender.send(Err(err));
return;
}
};

let runtime_handle = runtime.handle().clone();

handle_sender
.send(Ok(runtime_handle))
.expect("Failed to send local runtime handle");

runtime.block_on(async {
while let Some(spawn_task) = task_receiver.recv().await {
// Calls spawn_local(future)
(spawn_task)();
}
});

// If there are any tasks on the runtime associated with a LocalSet task
// that has already completed, but whose output has not yet been
// reported, let that task complete.
// If there are any tasks on the runtime that has already completed,
// but whose output has not yet been reported, let that task complete.
//
// Since the task_count is decremented when the runtime task exits,
// reading that counter lets us know if any such tasks completed during
// the call to `block_on`.
//
// Tasks on the LocalSet can't complete during this loop since they're
// stored on the LocalSet and we aren't accessing it.
let mut previous_task_count = task_count.load(Ordering::SeqCst);
loop {
// This call will also run tasks spawned on the runtime.
Expand All @@ -431,15 +441,10 @@ impl LocalWorkerHandle {
}
}

// It's now no longer possible for a task on the runtime to be
// associated with a LocalSet task that has completed. Drop both the
// LocalSet and runtime to let tasks on the runtime be cancelled if and
// only if they are still on the LocalSet.
//
// Drop the LocalSet task first so that anyone awaiting the runtime
// JoinHandle will see the cancelled error after the LocalSet task
// destructor has completed.
drop(local_set);
// It's now no longer possible for a task on the local runtime
// associated with task that has completed. Drop both
// local runtime to let tasks on the runtime be cancelled if and
// only if they are still on the runtime.
drop(runtime);
}
}