Skip to content

Fix SingleMemoryStorageSchedule checkpoint clearing during forward replay#248

Open
sghelichkhani wants to merge 2 commits into
dolfin-adjoint:masterfrom
sghelichkhani:sghelichkhani/singlemem-checkpoint-clearing
Open

Fix SingleMemoryStorageSchedule checkpoint clearing during forward replay#248
sghelichkhani wants to merge 2 commits into
dolfin-adjoint:masterfrom
sghelichkhani:sghelichkhani/singlemem-checkpoint-clearing

Conversation

@sghelichkhani

@sghelichkhani sghelichkhani commented Feb 15, 2026

Copy link
Copy Markdown
Contributor

Summary

Fixes #211. Fixes #260.

Both issues come from the same spot: during forward replay, the SingleMemoryStorageSchedule branch of the checkpoint clearing decides at a variable's last forward use whether its checkpoint can be discarded, and got this wrong in two ways.

For #211, the condition consulted the previous step's adjoint_dependencies, which are only fully populated during the reverse pass, so a long-range dependency (a variable reused, say, every third step) was cleared during replay and the reverse pass produced a wrong gradient. The first commit here replaced that with an identity check keeping superseded block variables alive. But that still cleared variables that remain the live block variable of their Function, relying on saved_output falling back to the live object — which holds the value from taping, not from the replay. As #260 shows, the adjoint of the last-use step still needs the variable as a linearisation point, so re-evaluating the functional at a new control before calling derivative() corrupts the gradient, which is precisely what every optimisation loop does.

The end state is a single condition covering both: a variable visited here is by construction a dependency of a block in this step, so its checkpoint is only cleared once the step's adjoint dependencies have been revised by a reverse pass and the variable is provably not among them. Before that everything is kept, which is exactly what "all adjoint dependencies are stored in memory" promises. The clear still goes through the public checkpoint setter, so the is_control guard from #257 protects control values on this path.

Test

Against the #211 MFE the gradient matches the unscheduled tape (3205.30164... in both cases, where master gives 73765.11...). Against the #260 MFE the gradient at a new control is 187.4048 with and without the schedule, where master gives 140.5536; the step-1-last-use variant and a second evaluate/derivative cycle also agree. A control's value still survives clearing under this schedule (four timesteps of J = sum_k m**2 at m = 2 give J = dJ/dm = 16), and all 218 pyadjoint unit tests pass. Regression tests for both issues live in firedrakeproject/firedrake#5168.

@sghelichkhani sghelichkhani force-pushed the sghelichkhani/singlemem-checkpoint-clearing branch from 436936d to a1e40b7 Compare June 12, 2026 00:16
@sghelichkhani

Copy link
Copy Markdown
Contributor Author

I've rebased this onto master now that #257 is in, since both PRs touched the same line. The original version of this branch predated #257 and still wrote var._checkpoint = None directly, which would have bypassed the is_control guard and quietly undone #257 on the SingleMemory path. That was exactly the concern Josh raised over on firedrakeproject/firedrake#5093. The rebased commit keeps the new identity condition from this PR but clears through the public setter, so the line now reads var.checkpoint = None.

I've verified both behaviours locally against current firedrake main: the #211 MFE now gives an identical gradient with and without SingleMemoryStorageSchedule (3205.30 versus the 73765.11 you get on master), and a control's value survives clearing under the SingleMemory schedule, i.e. the firedrake #5082 scenario stays fixed on this code path too. The firedrake checkpointing tests show no change in behaviour relative to master.

sghelichkhani added a commit to sghelichkhani/firedrake that referenced this pull request Jun 12, 2026
…encies

A block variable reused only every third timestep is absent from the
immediately preceding step's adjoint_dependencies, so the checkpoint
clearing for SingleMemoryStorageSchedule discarded it during forward
replay and the reverse pass produced a wrong gradient (73765 instead of
3205 in this test). Companion to dolfin-adjoint/pyadjoint#248, which
fixes the clearing condition; see dolfin-adjoint/pyadjoint#211 for the
original report.
sghelichkhani added a commit to sghelichkhani/firedrake that referenced this pull request Jun 12, 2026
Drop this commit (restoring the pyadjoint-ad version pin) once
dolfin-adjoint/pyadjoint#248 is merged and released.
@sghelichkhani

Copy link
Copy Markdown
Contributor Author

Companion test on the firedrake side is now up as firedrakeproject/firedrake#5168: a regression test for the #211 staggered-dependency pattern that fails on current master and passes with this branch, with CI temporarily pointed here.

…dencies are revised

A variable visited by the clear-down loop at its last forward use is by
construction a dependency of a block in that step, so the adjoint of the
step may still need it as a linearisation point. Clearing it and relying
on saved_output falling back to the live function returns the stale
taped value whenever the functional has been re-evaluated at a new
control. Keep everything until a reverse pass has revised the step's
adjoint dependencies, then clear only what is provably not needed.

Fixes dolfin-adjoint#260.
sghelichkhani added a commit to sghelichkhani/firedrake that referenced this pull request Jun 12, 2026
A variable that is never redefined keeps the live block variable of its
Function, and the clearing for SingleMemoryStorageSchedule discarded its
checkpoint at its last forward use even though the adjoint of that step
still needs it as a linearisation point. The saved_output fallback to
the live Function then supplies the taping-time value, corrupting the
gradient whenever the functional is re-evaluated at a new control first
(140.5536 instead of 187.4048 in this test). Companion to
dolfin-adjoint/pyadjoint#248; see dolfin-adjoint/pyadjoint#260 for the
report.

@Ig-dolci Ig-dolci left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix looks correct to me. Gating clearing on _revised_adj_deps is the right correctness-first approach, and using the public checkpoint setter preserves the is_control guard from #257.

The code comment already clearly explains that checkpoints are retained until a reverse pass revises the adjoint dependencies. Perhaps one minor documentation suggestion is to mention the resulting memory trade-off: forward-only recomputations retain the conservative dependency set because more precise clearing is enabled only after the first reverse.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

2 participants