@@ -213,7 +213,6 @@ impl Storage {
213213 if modified_count == 0 {
214214 return None ;
215215 }
216- let mut direct_snapshots: Vec < ( TaskId , Box < TaskStorage > ) > = Vec :: new ( ) ;
217216 let mut modified = Vec :: with_capacity ( modified_count as usize ) ;
218217 {
219218 let shard_guard = shard. read ( ) ;
@@ -229,44 +228,26 @@ impl Storage {
229228 // accompanied by modified flags (set_persistent_task_type calls
230229 // track_modification), so any_modified() is sufficient.
231230 if flags. any_modified ( ) {
232- debug_assert ! (
233- !key. is_transient( ) ,
234- "found a modified transient task: {:?}" ,
235- shared_value. get( ) . get_persistent_task_type( )
236- ) ;
237-
238- if flags. any_modified_during_snapshot ( ) {
239- // Task was modified during snapshot mode, so a snapshot
240- // copy must exist in the snapshots map (created by the
241- // (true, true) case in track_modification_internal).
242- // Remove the entry entirely so end_snapshot doesn't
243- // double-process this task. When iterating in `next` we will
244- // re-synchronize the task flags.
245- let ( _, snapshot) = self . snapshots . remove ( key) . expect (
246- "task with modified_during_snapshot must have a snapshots entry" ,
231+ if key. is_transient ( ) {
232+ debug_assert ! (
233+ false ,
234+ "found a modified transient task: {:?}" ,
235+ shared_value. get( ) . get_persistent_task_type( )
247236 ) ;
248- let snapshot = snapshot. expect (
249- "snapshot entry for modified_during_snapshot task must contain a \
250- value",
251- ) ;
252- direct_snapshots. push ( ( * key, snapshot) ) ;
253- } else {
254- modified. push ( * key) ;
237+ continue ;
255238 }
239+
240+ modified. push ( * key) ;
256241 }
257242 }
258243 // Safety: shard_guard must outlive the iterator.
259244 drop ( shard_guard) ;
260245 }
261246
262- // Early return for shards with no entries at all
263- if direct_snapshots. is_empty ( ) && modified. is_empty ( ) {
264- return None ;
265- }
247+ debug_assert ! ( !modified. is_empty( ) ) ;
266248
267249 Some ( SnapshotShard {
268250 shard_idx,
269- direct_snapshots,
270251 modified,
271252 storage : self ,
272253 process,
@@ -568,7 +549,6 @@ impl Drop for SnapshotGuard<'_> {
568549
569550pub struct SnapshotShard < ' l , P > {
570551 shard_idx : usize ,
571- direct_snapshots : Vec < ( TaskId , Box < TaskStorage > ) > ,
572552 modified : Vec < TaskId > ,
573553 storage : & ' l Storage ,
574554 process : & ' l P ,
@@ -606,16 +586,27 @@ where
606586 type Item = SnapshotItem ;
607587
608588 fn next ( & mut self ) -> Option < Self :: Item > {
609- // direct_snapshots: these tasks had a snapshot copy created by
610- // track_modification. We encode from the owned snapshot copy,
611- // clear the stale modified flags, and promote any _during_snapshot
612- // flags so the task stays dirty for the next cycle.
613- if let Some ( ( task_id, snapshot) ) = self . shard . direct_snapshots . pop ( ) {
614- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
615- // Clear pre-snapshot flags. Since we removed this task's entry from the
616- // snapshots map in take_snapshot, end_snapshot won't see it, so we must
617- // promote here.
589+ if let Some ( task_id) = self . shard . modified . pop ( ) {
618590 let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
591+ // If the task was re-modified during snapshot, the snapshots map may
592+ // hold a pre-modification copy we must serialize instead of the live
593+ // data. Remove the entry so end_snapshot doesn't double-promote it;
594+ // we promote manually below.
595+ let item = if inner. flags . any_modified_during_snapshot ( ) {
596+ match self . shard . storage . snapshots . remove ( & task_id) {
597+ Some ( ( _, Some ( snapshot) ) ) => {
598+ ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer )
599+ }
600+ Some ( ( _, None ) ) | None => {
601+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
602+ }
603+ }
604+ } else {
605+ ( self . shard . process ) ( task_id, & inner, & mut self . buffer )
606+ } ;
607+ // Clear the modified flags that were captured into the snapshot copy,
608+ // then promote modified_during_snapshot → modified so the task stays
609+ // dirty for the next snapshot cycle.
619610 inner. flags . set_data_modified ( false ) ;
620611 inner. flags . set_meta_modified ( false ) ;
621612 inner. flags . set_new_task ( false ) ;
@@ -624,45 +615,6 @@ where
624615 . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
625616 return Some ( item) ;
626617 }
627- // modified tasks: acquire a write lock to encode and clear flags in one pass.
628- if let Some ( task_id) = self . shard . modified . pop ( ) {
629- let mut inner = self . shard . storage . map . get_mut ( & task_id) . unwrap ( ) ;
630- if !inner. flags . any_modified_during_snapshot ( ) {
631- let item = ( self . shard . process ) ( task_id, & inner, & mut self . buffer ) ;
632- inner. flags . set_data_modified ( false ) ;
633- inner. flags . set_meta_modified ( false ) ;
634- inner. flags . set_new_task ( false ) ;
635- return Some ( item) ;
636- } else {
637- // Task was modified again during snapshot mode. A snapshot copy was
638- // created in track_modification_internal. Remove it and encode it.
639- // end_snapshot must not also process it, so we take it out of the map.
640- // snapshots is a separate DashMap from map, so holding `inner` across
641- // the remove and encode is safe — no lock ordering issue.
642- let snapshot = self
643- . shard
644- . storage
645- . snapshots
646- . remove ( & task_id)
647- . expect ( "The snapshot bit was set, so it must be in Snapshot state" )
648- . 1
649- . expect (
650- "snapshot entry for modified_during_snapshot task must contain a value" ,
651- ) ;
652-
653- let item = ( self . shard . process ) ( task_id, & snapshot, & mut self . buffer ) ;
654- // Clear the modified flags that were captured into the snapshot copy,
655- // then promote modified_during_snapshot → modified so the task stays
656- // dirty for the next snapshot cycle.
657- inner. flags . set_data_modified ( false ) ;
658- inner. flags . set_meta_modified ( false ) ;
659- inner. flags . set_new_task ( false ) ;
660- self . shard
661- . storage
662- . promote_during_snapshot_flags ( & mut inner, self . shard . shard_idx ) ;
663- return Some ( item) ;
664- }
665- }
666618 None
667619 }
668620}
@@ -704,20 +656,22 @@ mod tests {
704656 }
705657
706658 /// Regression test: a task modified before a snapshot and then modified *again* during
707- /// snapshot iteration must not trigger `debug_assert!(!inner.flags.any_modified())` in
708- /// `SnapshotShardIter::next`.
659+ /// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
660+ /// modification forward to the next cycle.
709661 ///
710662 /// Sequence of events:
711663 /// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
712664 /// 2. `start_snapshot` puts us in snapshot mode.
713- /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` and
714- /// `any_modified_during_snapshot()=false` → task goes into the `modified` list.
715- /// 4. **Between scan and iteration**: `track_modification` is called on the task again. This is
716- /// the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of the
717- /// pre-snapshot state is created (carrying the modified bits) and stored in `snapshots`.
718- /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, finds
719- /// `any_modified_during_snapshot()=true`, clears the live modified flags (which were
720- /// captured into the snapshot), then asserts `!any_modified()` before promoting.
665+ /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
666+ /// `modified` list.
667+ /// 4. **Between scan and iteration**: `track_modification` is called on the same category. This
668+ /// is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of
669+ /// the pre-second-modification state is stored in `snapshots` as `Some(copy)`, and
670+ /// `data_modified_during_snapshot` is set.
671+ /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, detects
672+ /// `any_modified_during_snapshot()=true`, finds the `Some(copy)` in `snapshots`, encodes the
673+ /// pre-snapshot copy, clears the live modified flags, removes the snapshots entry, and
674+ /// promotes `data_modified_during_snapshot → data_modified` for the next cycle.
721675 // `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
722676 // requiring a multi-threaded Tokio runtime.
723677 #[ tokio:: test( flavor = "multi_thread" ) ]
@@ -751,8 +705,8 @@ mod tests {
751705 assert ! ( guard. flags. data_modified_during_snapshot( ) )
752706 }
753707
754- // Step 5: consume the iterator. The iterator clears the live modified flags
755- // before the assert, encodes the snapshot copy, and promotes
708+ // Step 5: consume the iterator. The iterator encodes from the pre-snapshot copy,
709+ // clears the live modified flags, removes the snapshots entry, and promotes
756710 // `data_modified_during_snapshot → data_modified` for the next cycle.
757711 let items: Vec < _ > = shards
758712 . into_iter ( )
@@ -765,7 +719,7 @@ mod tests {
765719
766720 {
767721 let guard = storage. access_mut ( task_id) ;
768- // Ending the snapshot should have promoted modified_during_snapshot → modified.
722+ // The iterator should have promoted modified_during_snapshot → modified.
769723 assert ! ( guard. flags. data_modified( ) ) ;
770724 }
771725
@@ -777,4 +731,73 @@ mod tests {
777731 "shard_modified_counts must be non-zero after promoting modified_during_snapshot"
778732 ) ;
779733 }
734+
735+ /// Regression test for the `(true, false)` during-snapshot case: a task modified in one
736+ /// category before a snapshot, then modified in a *different* category during snapshot
737+ /// iteration, must not panic and must carry both modifications forward correctly.
738+ ///
739+ /// Sequence of events:
740+ /// 1. Task meta is modified (meta_modified = true).
741+ /// 2. `start_snapshot` puts us in snapshot mode.
742+ /// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
743+ /// 4. Task data is modified during snapshot → `(true, false)` branch: data was not previously
744+ /// modified, so `snapshots` gets a `None` entry and `data_modified_during_snapshot` is set.
745+ /// 5. `SnapshotShardIter::next` processes the task: finds `any_modified_during_snapshot()`,
746+ /// sees `None` in snapshots, encodes from live data (correct — live data for the
747+ /// unmodified-before-snapshot category is still the pre-snapshot state), clears pre-snapshot
748+ /// flags, and promotes `data_modified_during_snapshot → data_modified`.
749+ #[ tokio:: test( flavor = "multi_thread" ) ]
750+ async fn modify_different_category_during_snapshot ( ) {
751+ let storage = Storage :: new ( 2 , true ) ;
752+ let task_id = non_transient_task ( 1 ) ;
753+
754+ // Step 1: modify meta only, outside snapshot mode.
755+ {
756+ let mut guard = storage. access_mut ( task_id) ;
757+ guard. track_modification ( SpecificTaskDataCategory :: Meta , "test" ) ;
758+ assert ! ( guard. flags. meta_modified( ) ) ;
759+ assert ! ( !guard. flags. data_modified( ) ) ;
760+ }
761+
762+ // Step 2: enter snapshot mode.
763+ let ( snapshot_guard, has_modifications) = storage. start_snapshot ( ) ;
764+ assert ! ( has_modifications) ;
765+
766+ // Step 3: take_snapshot — task goes into modified list (meta_modified = true).
767+ let shards = storage. take_snapshot ( snapshot_guard, & dummy_process) ;
768+
769+ // Step 4: modify data during snapshot. The `(true, false)` branch fires:
770+ // data was not previously modified, so snapshots gets a None entry.
771+ {
772+ let mut guard = storage. access_mut ( task_id) ;
773+ guard. track_modification ( SpecificTaskDataCategory :: Data , "test" ) ;
774+ assert ! ( guard. flags. data_modified_during_snapshot( ) ) ;
775+ assert ! ( !guard. flags. meta_modified_during_snapshot( ) ) ;
776+ }
777+
778+ // Step 5: consume the iterator — must not panic.
779+ let items: Vec < _ > = shards
780+ . into_iter ( )
781+ . flat_map ( |shard| shard. into_iter ( ) )
782+ . collect ( ) ;
783+
784+ assert_eq ! ( items. len( ) , 1 ) ;
785+ assert_eq ! ( items[ 0 ] . task_id, task_id) ;
786+
787+ {
788+ let guard = storage. access_mut ( task_id) ;
789+ // meta_modified was cleared by the iterator (it was the pre-snapshot flag).
790+ assert ! ( !guard. flags. meta_modified( ) ) ;
791+ // data_modified_during_snapshot was promoted to data_modified.
792+ assert ! ( guard. flags. data_modified( ) ) ;
793+ assert ! ( !guard. flags. data_modified_during_snapshot( ) ) ;
794+ }
795+
796+ // Next snapshot cycle must pick up the promoted data_modified.
797+ let ( _guard2, has_modifications) = storage. start_snapshot ( ) ;
798+ assert ! (
799+ has_modifications,
800+ "shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
801+ ) ;
802+ }
780803}
0 commit comments