Fix SingleMemoryStorageSchedule checkpoint clearing during forward replay#248
Conversation
436936d to
a1e40b7
Compare
|
I've rebased this onto master now that #257 is in, since both PRs touched the same line. The original version of this branch predated #257 and still wrote I've verified both behaviours locally against current firedrake main: the #211 MFE now gives an identical gradient with and without |
…encies A block variable reused only every third timestep is absent from the immediately preceding step's adjoint_dependencies, so the checkpoint clearing for SingleMemoryStorageSchedule discarded it during forward replay and the reverse pass produced a wrong gradient (73765 instead of 3205 in this test). Companion to dolfin-adjoint/pyadjoint#248, which fixes the clearing condition; see dolfin-adjoint/pyadjoint#211 for the original report.
Drop this commit (restoring the pyadjoint-ad version pin) once dolfin-adjoint/pyadjoint#248 is merged and released.
|
Companion test on the firedrake side is now up as firedrakeproject/firedrake#5168: a regression test for the #211 staggered-dependency pattern that fails on current master and passes with this branch, with CI temporarily pointed here. |
…dencies are revised A variable visited by the clear-down loop at its last forward use is by construction a dependency of a block in that step, so the adjoint of the step may still need it as a linearisation point. Clearing it and relying on saved_output falling back to the live function returns the stale taped value whenever the functional has been re-evaluated at a new control. Keep everything until a reverse pass has revised the step's adjoint dependencies, then clear only what is provably not needed. Fixes dolfin-adjoint#260.
A variable that is never redefined keeps the live block variable of its Function, and the clearing for SingleMemoryStorageSchedule discarded its checkpoint at its last forward use even though the adjoint of that step still needs it as a linearisation point. The saved_output fallback to the live Function then supplies the taping-time value, corrupting the gradient whenever the functional is re-evaluated at a new control first (140.5536 instead of 187.4048 in this test). Companion to dolfin-adjoint/pyadjoint#248; see dolfin-adjoint/pyadjoint#260 for the report.
Ig-dolci
left a comment
There was a problem hiding this comment.
The fix looks correct to me. Gating clearing on _revised_adj_deps is the right correctness-first approach, and using the public checkpoint setter preserves the is_control guard from #257.
The code comment already clearly explains that checkpoints are retained until a reverse pass revises the adjoint dependencies. Perhaps one minor documentation suggestion is to mention the resulting memory trade-off: forward-only recomputations retain the conservative dependency set because more precise clearing is enabled only after the first reverse.
Summary
Fixes #211. Fixes #260.
Both issues come from the same spot: during forward replay, the
SingleMemoryStorageSchedulebranch of the checkpoint clearing decides at a variable's last forward use whether its checkpoint can be discarded, and got this wrong in two ways.For #211, the condition consulted the previous step's
adjoint_dependencies, which are only fully populated during the reverse pass, so a long-range dependency (a variable reused, say, every third step) was cleared during replay and the reverse pass produced a wrong gradient. The first commit here replaced that with an identity check keeping superseded block variables alive. But that still cleared variables that remain the live block variable of their Function, relying onsaved_outputfalling back to the live object — which holds the value from taping, not from the replay. As #260 shows, the adjoint of the last-use step still needs the variable as a linearisation point, so re-evaluating the functional at a new control before callingderivative()corrupts the gradient, which is precisely what every optimisation loop does.The end state is a single condition covering both: a variable visited here is by construction a dependency of a block in this step, so its checkpoint is only cleared once the step's adjoint dependencies have been revised by a reverse pass and the variable is provably not among them. Before that everything is kept, which is exactly what "all adjoint dependencies are stored in memory" promises. The clear still goes through the public
checkpointsetter, so theis_controlguard from #257 protects control values on this path.Test
Against the #211 MFE the gradient matches the unscheduled tape (3205.30164... in both cases, where master gives 73765.11...). Against the #260 MFE the gradient at a new control is 187.4048 with and without the schedule, where master gives 140.5536; the step-1-last-use variant and a second evaluate/derivative cycle also agree. A control's value still survives clearing under this schedule (four timesteps of
J = sum_k m**2atm = 2giveJ = dJ/dm = 16), and all 218 pyadjoint unit tests pass. Regression tests for both issues live in firedrakeproject/firedrake#5168.