203 changes: 113 additions & 90 deletions turbopack/crates/turbo-tasks-backend/src/backend/storage.rs
@@ -213,7 +213,6 @@ impl Storage {
if modified_count == 0 {
return None;
}
let mut direct_snapshots: Vec<(TaskId, Box<TaskStorage>)> = Vec::new();
let mut modified = Vec::with_capacity(modified_count as usize);
{
let shard_guard = shard.read();
@@ -229,44 +228,26 @@
// accompanied by modified flags (set_persistent_task_type calls
// track_modification), so any_modified() is sufficient.
if flags.any_modified() {
debug_assert!(
!key.is_transient(),
"found a modified transient task: {:?}",
shared_value.get().get_persistent_task_type()
);

if flags.any_modified_during_snapshot() {
// Task was modified during snapshot mode, so a snapshot
// copy must exist in the snapshots map (created by the
// (true, true) case in track_modification_internal).
// Remove the entry entirely so end_snapshot doesn't
// double-process this task. When iterating in `next` we will
// re-synchronize the task flags.
let (_, snapshot) = self.snapshots.remove(key).expect(
"task with modified_during_snapshot must have a snapshots entry",
if key.is_transient() {
debug_assert!(
false,
"found a modified transient task: {:?}",
shared_value.get().get_persistent_task_type()
);
let snapshot = snapshot.expect(
"snapshot entry for modified_during_snapshot task must contain a \
value",
);
direct_snapshots.push((*key, snapshot));
} else {
modified.push(*key);
continue;
}

modified.push(*key);
}
}
// Safety: shard_guard must outlive the iterator.
drop(shard_guard);
}

// Early return for shards with no entries at all
if direct_snapshots.is_empty() && modified.is_empty() {
return None;
}
debug_assert!(!modified.is_empty());

Some(SnapshotShard {
shard_idx,
direct_snapshots,
modified,
storage: self,
process,
@@ -568,7 +549,6 @@ impl Drop for SnapshotGuard<'_> {

pub struct SnapshotShard<'l, P> {
shard_idx: usize,
direct_snapshots: Vec<(TaskId, Box<TaskStorage>)>,
modified: Vec<TaskId>,
storage: &'l Storage,
process: &'l P,
@@ -606,16 +586,27 @@
type Item = SnapshotItem;

fn next(&mut self) -> Option<Self::Item> {
// direct_snapshots: these tasks had a snapshot copy created by
// track_modification. We encode from the owned snapshot copy,
// clear the stale modified flags, and promote any _during_snapshot
// flags so the task stays dirty for the next cycle.
if let Some((task_id, snapshot)) = self.shard.direct_snapshots.pop() {
let item = (self.shard.process)(task_id, &snapshot, &mut self.buffer);
// Clear pre-snapshot flags. Since we removed this task's entry from the
// snapshots map in take_snapshot, end_snapshot won't see it, so we must
// promote here.
if let Some(task_id) = self.shard.modified.pop() {
let mut inner = self.shard.storage.map.get_mut(&task_id).unwrap();
// If the task was re-modified during snapshot, the snapshots map may
// hold a pre-modification copy we must serialize instead of the live
// data. Remove the entry so end_snapshot doesn't double-promote it;
// we promote manually below.
let item = if inner.flags.any_modified_during_snapshot() {
match self.shard.storage.snapshots.remove(&task_id) {
Some((_, Some(snapshot))) => {
(self.shard.process)(task_id, &snapshot, &mut self.buffer)
}
Some((_, None)) | None => {
(self.shard.process)(task_id, &inner, &mut self.buffer)
}
}
} else {
(self.shard.process)(task_id, &inner, &mut self.buffer)
};
// Clear the modified flags that were captured into the snapshot copy,
// then promote modified_during_snapshot → modified so the task stays
// dirty for the next snapshot cycle.
inner.flags.set_data_modified(false);
inner.flags.set_meta_modified(false);
inner.flags.set_new_task(false);
@@ -624,45 +615,6 @@
.promote_during_snapshot_flags(&mut inner, self.shard.shard_idx);
return Some(item);
}
// modified tasks: acquire a write lock to encode and clear flags in one pass.
if let Some(task_id) = self.shard.modified.pop() {
let mut inner = self.shard.storage.map.get_mut(&task_id).unwrap();
if !inner.flags.any_modified_during_snapshot() {
let item = (self.shard.process)(task_id, &inner, &mut self.buffer);
inner.flags.set_data_modified(false);
inner.flags.set_meta_modified(false);
inner.flags.set_new_task(false);
return Some(item);
} else {
// Task was modified again during snapshot mode. A snapshot copy was
// created in track_modification_internal. Remove it and encode it.
// end_snapshot must not also process it, so we take it out of the map.
// snapshots is a separate DashMap from map, so holding `inner` across
// the remove and encode is safe — no lock ordering issue.
let snapshot = self
.shard
.storage
.snapshots
.remove(&task_id)
.expect("The snapshot bit was set, so it must be in Snapshot state")
.1
.expect(
"snapshot entry for modified_during_snapshot task must contain a value",
);

let item = (self.shard.process)(task_id, &snapshot, &mut self.buffer);
// Clear the modified flags that were captured into the snapshot copy,
// then promote modified_during_snapshot → modified so the task stays
// dirty for the next snapshot cycle.
inner.flags.set_data_modified(false);
inner.flags.set_meta_modified(false);
inner.flags.set_new_task(false);
self.shard
.storage
.promote_during_snapshot_flags(&mut inner, self.shard.shard_idx);
return Some(item);
}
}
None
}
}
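
For orientation, here is a minimal, self-contained model of the during-snapshot bookkeeping that the comments above and the test walkthroughs below refer to. It is an illustrative sketch only: the names (ModelTask, SnapshotSlot, track_modification_model) are hypothetical, the real flags are tracked per category (data and meta), and this is not the actual signature of track_modification_internal.

use std::collections::HashMap;

#[derive(Clone, Debug, Default)]
struct ModelTask {
    data: u32,
    data_modified: bool,
    data_modified_during_snapshot: bool,
}

// None = "modified during the snapshot, but the live value is still the
//         pre-snapshot value for every previously-clean category"
// Some = "serialize this pre-modification copy instead of the live value"
type SnapshotSlot = Option<Box<ModelTask>>;

// Called *before* the caller mutates the task, mirroring how
// track_modification is invoked ahead of the actual write.
fn track_modification_model(
    id: u32,
    task: &mut ModelTask,
    snapshot_active: bool,
    snapshots: &mut HashMap<u32, SnapshotSlot>,
) {
    match (snapshot_active, task.data_modified) {
        // Outside snapshot mode: just mark the category dirty.
        (false, _) => task.data_modified = true,
        // (true, false): category was clean before the snapshot, so a None
        // placeholder is enough; the live value is still what the snapshot
        // should serialize.
        (true, false) => {
            snapshots.entry(id).or_insert(None);
            task.data_modified_during_snapshot = true;
        }
        // (true, true): category was already dirty before the snapshot, so
        // keep a copy of the pre-second-modification state for the iterator.
        (true, true) => {
            snapshots
                .entry(id)
                .or_insert_with(|| Some(Box::new(task.clone())));
            task.data_modified_during_snapshot = true;
        }
    }
}
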
@@ -704,20 +656,22 @@ mod tests {
}

/// Regression test: a task modified before a snapshot and then modified *again* during
/// snapshot iteration must not trigger `debug_assert!(!inner.flags.any_modified())` in
/// `SnapshotShardIter::next`.
/// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
/// modification forward to the next cycle.
///
/// Sequence of events:
/// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
/// 2. `start_snapshot` puts us in snapshot mode.
/// 3. `take_snapshot` scans the shard: task has `any_modified()=true` and
/// `any_modified_during_snapshot()=false` → task goes into the `modified` list.
/// 4. **Between scan and iteration**: `track_modification` is called on the task again. This is
/// the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of the
/// pre-snapshot state is created (carrying the modified bits) and stored in `snapshots`.
/// 5. `SnapshotShardIter::next` processes the task from the `modified` list, finds
/// `any_modified_during_snapshot()=true`, clears the live modified flags (which were
/// captured into the snapshot), then asserts `!any_modified()` before promoting.
/// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
/// `modified` list.
/// 4. **Between scan and iteration**: `track_modification` is called on the same category. This
/// is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of
/// the pre-second-modification state is stored in `snapshots` as `Some(copy)`, and
/// `data_modified_during_snapshot` is set.
/// 5. `SnapshotShardIter::next` processes the task from the `modified` list, detects
/// `any_modified_during_snapshot()=true`, finds the `Some(copy)` in `snapshots`, encodes the
/// pre-snapshot copy, clears the live modified flags, removes the snapshots entry, and
/// promotes `data_modified_during_snapshot → data_modified` for the next cycle.
// `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
// requiring a multi-threaded Tokio runtime.
#[tokio::test(flavor = "multi_thread")]
@@ -751,8 +705,8 @@ mod tests {
assert!(guard.flags.data_modified_during_snapshot())
}

// Step 5: consume the iterator. The iterator clears the live modified flags
// before the assert, encodes the snapshot copy, and promotes
// Step 5: consume the iterator. The iterator encodes from the pre-snapshot copy,
// clears the live modified flags, removes the snapshots entry, and promotes
// `data_modified_during_snapshot → data_modified` for the next cycle.
let items: Vec<_> = shards
.into_iter()
@@ -765,7 +719,7 @@

{
let guard = storage.access_mut(task_id);
// Ending the snapshot should have promoted modified_during_snapshot → modified.
// The iterator should have promoted modified_during_snapshot → modified.
assert!(guard.flags.data_modified());
}

@@ -777,4 +731,73 @@
"shard_modified_counts must be non-zero after promoting modified_during_snapshot"
);
}
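
As a companion to the walkthrough above, a compact model of what step 5 does in the rewritten next(), reusing the ModelTask and SnapshotSlot types from the earlier sketch (again hypothetical names; encode stands in for the process callback):

// One task from the shard's `modified` list: consult `snapshots` only when the
// during-snapshot flag is set, prefer the pre-modification copy when present,
// then clear the captured flag and carry the newer modification forward.
fn snapshot_item_model(
    id: u32,
    task: &mut ModelTask,
    snapshots: &mut HashMap<u32, SnapshotSlot>,
    encode: impl Fn(&ModelTask) -> Vec<u8>,
) -> Vec<u8> {
    let item = if task.data_modified_during_snapshot {
        match snapshots.remove(&id) {
            // A pre-snapshot copy exists: serialize it instead of the live data.
            Some(Some(copy)) => encode(&*copy),
            // No copy: the live data is still the right thing to serialize.
            Some(None) | None => encode(task),
        }
    } else {
        encode(task)
    };
    // Clear the flag that was captured into this snapshot...
    task.data_modified = false;
    // ...and promote modified_during_snapshot → modified for the next cycle.
    if task.data_modified_during_snapshot {
        task.data_modified_during_snapshot = false;
        task.data_modified = true;
    }
    item
}
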

/// Regression test for the `(true, false)` during-snapshot case: a task modified in one
/// category before a snapshot, then modified in a *different* category during snapshot
/// iteration, must not panic and must carry both modifications forward correctly.
///
/// Sequence of events:
/// 1. Task meta is modified (meta_modified = true).
/// 2. `start_snapshot` puts us in snapshot mode.
/// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
/// 4. Task data is modified during snapshot → `(true, false)` branch: data was not previously
/// modified, so `snapshots` gets a `None` entry and `data_modified_during_snapshot` is set.
/// 5. `SnapshotShardIter::next` processes the task: finds `any_modified_during_snapshot()`,
/// sees `None` in snapshots, encodes from live data (correct — live data for the
/// unmodified-before-snapshot category is still the pre-snapshot state), clears pre-snapshot
/// flags, and promotes `data_modified_during_snapshot → data_modified`.
#[tokio::test(flavor = "multi_thread")]
async fn modify_different_category_during_snapshot() {
let storage = Storage::new(2, true);
let task_id = non_transient_task(1);

// Step 1: modify meta only, outside snapshot mode.
{
let mut guard = storage.access_mut(task_id);
guard.track_modification(SpecificTaskDataCategory::Meta, "test");
assert!(guard.flags.meta_modified());
assert!(!guard.flags.data_modified());
}

// Step 2: enter snapshot mode.
let (snapshot_guard, has_modifications) = storage.start_snapshot();
assert!(has_modifications);

// Step 3: take_snapshot — task goes into modified list (meta_modified = true).
let shards = storage.take_snapshot(snapshot_guard, &dummy_process);

// Step 4: modify data during snapshot. The `(true, false)` branch fires:
// data was not previously modified, so snapshots gets a None entry.
{
let mut guard = storage.access_mut(task_id);
guard.track_modification(SpecificTaskDataCategory::Data, "test");
assert!(guard.flags.data_modified_during_snapshot());
assert!(!guard.flags.meta_modified_during_snapshot());
}

// Step 5: consume the iterator — must not panic.
let items: Vec<_> = shards
.into_iter()
.flat_map(|shard| shard.into_iter())
.collect();

assert_eq!(items.len(), 1);
assert_eq!(items[0].task_id, task_id);

{
let guard = storage.access_mut(task_id);
// meta_modified was cleared by the iterator (it was the pre-snapshot flag).
assert!(!guard.flags.meta_modified());
// data_modified_during_snapshot was promoted to data_modified.
assert!(guard.flags.data_modified());
assert!(!guard.flags.data_modified_during_snapshot());
}

// Next snapshot cycle must pick up the promoted data_modified.
let (_guard2, has_modifications) = storage.start_snapshot();
assert!(
has_modifications,
"shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
);
}
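
A small driver (purely illustrative) running the first test's (true, true) sequence through the two sketches above; the None-slot path exercised by the second test corresponds to the Some(None) arm in snapshot_item_model, where serializing the live data is already correct:

fn main() {
    let mut snapshots: HashMap<u32, SnapshotSlot> = HashMap::new();
    let mut task = ModelTask {
        data: 1,
        data_modified: true, // dirty before the snapshot, so it lands in `modified`
        data_modified_during_snapshot: false,
    };

    // Modified again while the snapshot is running: (true, true) branch,
    // so a copy of data == 1 is stashed before the write below.
    track_modification_model(7, &mut task, true, &mut snapshots);
    task.data = 2;

    // The iterator serializes the pre-snapshot copy (1), not the live value (2)...
    let bytes = snapshot_item_model(7, &mut task, &mut snapshots, |t| {
        t.data.to_le_bytes().to_vec()
    });
    assert_eq!(bytes, 1u32.to_le_bytes().to_vec());

    // ...and the during-snapshot modification is carried into the next cycle.
    assert!(task.data_modified);
    assert!(!task.data_modified_during_snapshot);
    assert!(snapshots.is_empty());
}
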
}