diff --git a/tokio-util/src/task/spawn_pinned.rs b/tokio-util/src/task/spawn_pinned.rs index 3f692d3cf2c..31efafa1673 100644 --- a/tokio-util/src/task/spawn_pinned.rs +++ b/tokio-util/src/task/spawn_pinned.rs @@ -4,19 +4,18 @@ use std::fmt::{Debug, Formatter}; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use tokio::runtime::Builder; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; -use tokio::task::{spawn_local, JoinHandle, LocalSet}; +use tokio::task::{spawn_local, JoinHandle}; /// A cloneable handle to a local pool, used for spawning `!Send` tasks. /// -/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread +/// Internally the local pool uses a [`tokio::runtime::LocalRuntime`] for each worker thread /// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will /// execute on the same thread) inside the Future you supply to the various spawn methods /// of `LocalPoolHandle`. /// -/// [`tokio::task::LocalSet`]: tokio::task::LocalSet +/// [`tokio::runtime::LocalRuntime`]: tokio::runtime::LocalRuntime /// [`tokio::task::spawn_local`]: tokio::task::spawn_local /// /// # Examples @@ -238,10 +237,10 @@ impl LocalPool { let _abort_guard = AbortGuard(abort_handle); // Inside the future we can't run spawn_local yet because we're not - // in the context of a LocalSet. We need to send create_task to the - // LocalSet task for spawning. + // in the context of a LocalRuntime. We need to send create_task to the + // LocalRuntime task for spawning. let spawn_task = Box::new(move || { - // Once we're in the LocalSet context we can call spawn_local + // Once we're in the LocalRuntime context we can call spawn_local let join_handle = spawn_local( async move { Abortable::new(create_task(), abort_registration).await }, @@ -255,7 +254,7 @@ impl LocalPool { } }); - // Send the callback to the LocalSet task + // Send the callback to the LocalRuntime task if let Err(e) = worker_spawner.send(spawn_task) { // Propagate the error as a panic in the join handle. panic!("Failed to send job to worker: {e}"); @@ -379,15 +378,17 @@ impl LocalWorkerHandle { /// Create a new worker for executing pinned tasks fn new_worker() -> LocalWorkerHandle { let (sender, receiver) = unbounded_channel(); - let runtime = Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to start a pinned worker thread runtime"); - let runtime_handle = runtime.handle().clone(); + let (handle_sender, handle_receiver) = std::sync::mpsc::channel(); + let task_count = Arc::new(AtomicUsize::new(0)); let task_count_clone = Arc::clone(&task_count); - std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); + std::thread::spawn(|| Self::run(handle_sender, receiver, task_count_clone)); + + let runtime_handle = handle_receiver + .recv() + .expect("Failed to recv local runtime init result") + .expect("Failed to start local runtime"); LocalWorkerHandle { runtime_handle, @@ -397,28 +398,37 @@ impl LocalWorkerHandle { } fn run( - runtime: tokio::runtime::Runtime, + handle_sender: std::sync::mpsc::Sender>, mut task_receiver: UnboundedReceiver, task_count: Arc, ) { - let local_set = LocalSet::new(); - local_set.block_on(&runtime, async { + let runtime = match tokio::runtime::LocalRuntime::new() { + Ok(runtime) => runtime, + Err(err) => { + let _ = handle_sender.send(Err(err)); + return; + } + }; + + let runtime_handle = runtime.handle().clone(); + + handle_sender + .send(Ok(runtime_handle)) + .expect("Failed to send local runtime handle"); + + runtime.block_on(async { while let Some(spawn_task) = task_receiver.recv().await { // Calls spawn_local(future) (spawn_task)(); } }); - // If there are any tasks on the runtime associated with a LocalSet task - // that has already completed, but whose output has not yet been - // reported, let that task complete. + // If there are any tasks on the runtime that has already completed, + // but whose output has not yet been reported, let that task complete. // // Since the task_count is decremented when the runtime task exits, // reading that counter lets us know if any such tasks completed during // the call to `block_on`. - // - // Tasks on the LocalSet can't complete during this loop since they're - // stored on the LocalSet and we aren't accessing it. let mut previous_task_count = task_count.load(Ordering::SeqCst); loop { // This call will also run tasks spawned on the runtime. @@ -431,15 +441,10 @@ impl LocalWorkerHandle { } } - // It's now no longer possible for a task on the runtime to be - // associated with a LocalSet task that has completed. Drop both the - // LocalSet and runtime to let tasks on the runtime be cancelled if and - // only if they are still on the LocalSet. - // - // Drop the LocalSet task first so that anyone awaiting the runtime - // JoinHandle will see the cancelled error after the LocalSet task - // destructor has completed. - drop(local_set); + // It's now no longer possible for a task on the local runtime + // associated with task that has completed. Drop both + // local runtime to let tasks on the runtime be cancelled if and + // only if they are still on the runtime. drop(runtime); } }