Try to handle solver caching for linear forms (-> #4638) #5158

Merged
JHopeCollins merged 5 commits into JHopeCollins/nlvs-hessian-fix from angus-g/4638-linear-form
Jun 16, 2026

@angus-g angus-g commented Jun 9, 2026

Contributor

When we try to cache the Hessian and TLM solvers for a linear form (the L2 Riesz representation), there are a few assumptions about the resulting derivative forms that fail:

  • expand_derivatives(d2Fdu2) breaks when it tries to expand a ZeroBaseForm, so we just skip that step (with a try/except that needs to be fixed!)
  • taking derivative(action(F, v), L) doesn't work, so we switch to action(derivative(F, L), v)
  • for a linear problem, d2Fdudm is zero, but derivative(dFdm_adj) fails, so we add explicit zero forms to the Hessian forms in that case
  • the TLM solver fails unless we explicitly expand derivative(F, u) from the linear form
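
For reference, a sketch of why the second derivatives vanish here. The notation is an assumption of this note, not something taken from the code: u is the solution, v the test function, and L the linear form whose L2 Riesz representation is being computed.

```latex
% Assumed notation: residual of the L2 Riesz representation problem,
% find u such that (u, v) = L(v) for all test functions v.
\begin{align*}
F(u, L; v) &= \int_\Omega u\, v \,\mathrm{d}x - L(v), \\
\frac{\partial F}{\partial u}[\delta u] &= \int_\Omega \delta u\, v \,\mathrm{d}x, \\
\frac{\partial F}{\partial L}[\delta L] &= -\,\delta L(v), \\
\frac{\partial^2 F}{\partial u^2}[\delta u, \delta u'] &= 0, \\
\frac{\partial^2 F}{\partial u\, \partial L}[\delta u, \delta L] &= 0.
\end{align*}
```

Because F is linear in v, contracting with an adjoint variable before or after differentiating with respect to L should give the same form, which is consistent with swapping derivative(action(F, v), L) for action(derivative(F, L), v).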

@angus-g angus-g added the base:main Run this PR using a main (dev) build label Jun 9, 2026
@angus-g angus-g force-pushed the angus-g/4638-linear-form branch 2 times, most recently from b293800 to 0fe0ac4 Compare June 11, 2026 12:23
@pbrubeck

Contributor

Some of these workarounds indicate that there's work to be done in UFL. Could you raise an issue in UFL for these failing cases, providing an MFE?

Comment thread firedrake/adjoint_utils/variational_solver.py Outdated
Comment thread firedrake/adjoint_utils/variational_solver.py Outdated
@angus-g angus-g force-pushed the angus-g/4638-linear-form branch 3 times, most recently from 7b31a77 to 8a85c16 Compare June 16, 2026 01:15
@angus-g angus-g force-pushed the JHopeCollins/nlvs-hessian-fix branch from 6f6fd15 to 440f386 Compare June 16, 2026 02:01
angus-g added 2 commits June 16, 2026 12:01
These seem to be a somewhat historical relic from when both NLVS
and LVS code paths could end up in the same solver block. The reason
for excluding the linear case is not clear, perhaps because the RHS
of the problem is involved in some way?

Regardless, we always enter the solver block normalised into a
NLVP which is always "nonlinear" according to the original
condition:

linear = (
    isinstance(lhs, ufl.Form) and
    isinstance(rhs, (ufl.Form, ufl.Cofunction))
)

Without this change, we can erroneously evaluate the adjoint component
corresponding to the initial guess.

I think this approach is more consistent with the underlying
maths, but my UFL is not particularly strong. This way allows
expand_derivatives to simplify expressions as required, and
the slot treatment of arguments is consistent across the
cached forms.
@angus-g angus-g force-pushed the angus-g/4638-linear-form branch from 8a85c16 to fa1af48 Compare June 16, 2026 02:02
@angus-g

angus-g commented Jun 16, 2026

Contributor Author

Alright, I went back through the math and rewrote the expressions to step around the hacks. Seems much cleaner and everything seems to pass now. Next hurdle is the forward cache on the full waveform inversion demo.

Traceback
/opt/firedrake/firedrake/adjoint_utils/variational_solver.py:425: in wrapper
    self._ad_forward_cache,
    ^^^^^^^^^^^^^^^^^^^^^^
/usr/lib/python3.12/functools.py:995: in __get__
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
/usr/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
/opt/firedrake/firedrake/adjoint_utils/variational_solver.py:169: in _ad_forward_cache
    nlvs = NonlinearVariationalSolver(
petsc4py/PETSc/Log.pyx:250: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
petsc4py/PETSc/Log.pyx:251: in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
    ???
/usr/lib/python3.12/contextlib.py:81: in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
/opt/firedrake/firedrake/adjoint_utils/variational_solver.py:101: in wrapper
    init(self, problem, *args, **kwargs)
/opt/firedrake/firedrake/variational_solver.py:308: in __init__
    ctx.set_jacobian(self.snes)
/opt/firedrake/firedrake/solving_utils.py:351: in set_jacobian
    snes.setJacobian(self.form_jacobian, J=self._jac.petscmat,
                                           ^^^^^^^^^
/usr/lib/python3.12/functools.py:995: in __get__
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
/opt/firedrake/firedrake/solving_utils.py:559: in _jac
    return self._assembler_jac.allocate()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/opt/firedrake/firedrake/assemble.py:364: in allocate
    return MatrixFreeAssembler(self._form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class 'firedrake.assemble.MatrixFreeAssembler'>
args = (FormSum([1*Form([Integral(CoefficientDerivative(Product(Argument(WithGeometry(FunctionSpace(<firedrake.mesh.MeshTopol... name=None), Mesh(VectorElement(FiniteElement('Lagrange', triangle, 1), dim=2), 1)), 1, None),)), ExprMapping(*()))]),)
kwargs = {'appctx': {'form_compiler_parameters': None, 'state': Coefficient(WithGeometry(FunctionSpace(<firedrake.mesh.MeshTopo...range', triangle, 1), dim=2), 1)), 1578)}, 'bcs': (), 'form_compiler_parameters': {}, 'options_prefix': 'firedrake_3_'}
form = FormSum([1*Form([Integral(CoefficientDerivative(Product(Argument(WithGeometry(FunctionSpace(<firedrake.mesh.MeshTopolo...), name=None), Mesh(VectorElement(FiniteElement('Lagrange', triangle, 1), dim=2), 1)), 1, None),)), ExprMapping(*()))])

    def __new__(cls, *args, **kwargs):
        form = args[0]
        if not isinstance(form, (ufl.Form, slate.TensorBase)):
>           raise TypeError(f"The first positional argument must be of ufl.Form or slate.TensorBase: got {type(form)} ({form})")
E           TypeError: The first positional argument must be of ufl.Form or slate.TensorBase: got <class 'ufl.form.FormSum'> (1*d/dfj { v_0 * w₁₅₇₈ * 1 / w₁₅₇₅ * w₁₅₇₅ / [4.e-06] }, with fh=ExprList(*(w₁₅₇₈,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * dx(<Mesh #1>[everywhere], {'quadrature_rule': QuadratureRule(KMVPointSet(array([[0., 0.],
E                  [1., 0.],
E                  [0., 1.]])), ndarray([0.1666666666667, 0.1666666666667, 0.1666666666667]), UFCTriangle(2, ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0)), {0: {0: (0,), 1: (1,), 2: (2,)}, 1: {0: (1, 2), 1: (0, 2), 2: (0, 1)}, 2: {0: (0, 1, 2)}}), (None,))}, {})
E             +  d/dfj { -1 * -1 * v_0 * (w₁₅₈₂ + -1 * 2.0 * w₁₅₈₀) * 1 / w₁₅₇₅ * w₁₅₇₅ / [4.e-06] }, with fh=ExprList(*(w₁₅₇₈,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * dx(<Mesh #1>[everywhere], {'quadrature_rule': QuadratureRule(KMVPointSet(array([[0., 0.],
E                  [1., 0.],
E                  [0., 1.]])), ndarray([0.1666666666667, 0.1666666666667, 0.1666666666667]), UFCTriangle(2, ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0)), {0: {0: (0,), 1: (1,), 2: (2,)}, 1: {0: (1, 2), 1: (0, 2), 2: (0, 1)}, 2: {0: (0, 1, 2)}}), (None,))}, {})
E             +  d/dfj { -1 * -1 * (sum_{i_{16}} (grad(v_0))[i_{16}] * (grad(w₁₅₈₀))[i_{16}] ) }, with fh=ExprList(*(w₁₅₇₈,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * dx(<Mesh #1>[everywhere], {'quadrature_rule': QuadratureRule(KMVPointSet(array([[0., 0.],
E                  [1., 0.],
E                  [0., 1.]])), ndarray([0.1666666666667, 0.1666666666667, 0.1666666666667]), UFCTriangle(2, ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0)), {0: {0: (0,), 1: (1,), 2: (2,)}, 1: {0: (1, 2), 1: (0, 2), 2: (0, 1)}, 2: {0: (0, 1, 2)}}), (None,))}, {})
E             +  d/dfj { -1 * -1 * v_0 * (w₁₅₈₀ + -1 * w₁₅₈₂) / 0.002 * 1 / w₁₅₇₅ }, with fh=ExprList(*(w₁₅₇₈,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * ds(<Mesh #1>[everywhere], {}, {})
E             +  -1*d/dfj { cofunction_1031 }, with fh=ExprList(*(w₁₅₇₈,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()))

/opt/firedrake/firedrake/assemble.py:953: TypeError

Looks like explicitly providing the Jacobian from a linear problem (with appropriate substitutions) works, but it does highlight that we need to be a bit clever setting up the cached solvers!

The full waveform inversion demo fails in the matrix free assembler
when trying to deal with the automatically differentiated form
of the NLVP. However, it's actually a linear problem, so we know
the Jacobian. Indeed, the Jacobian was originally passed to the NLVP
construction from the LVP in the first place, so we should mirror this.
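
To illustrate the point about linear problems, here is a dependency-free toy rather than Firedrake code. It assumes a residual of the form F(u) = A u - b, for which the Jacobian is simply A, so a single Newton step with the supplied Jacobian lands on the solution from any initial guess. All names (matvec, residual, solve_2x2, newton_step) are hypothetical helpers for this sketch.

```python
# Toy stand-in for a linear variational problem: residual F(u) = A u - b.
# Plain Python lists keep the example free of external dependencies.

def matvec(A, x):
    """Dense matrix-vector product."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def residual(A, b, u):
    """Residual F(u) = A u - b of the linear problem."""
    return [r - bi for r, bi in zip(matvec(A, u), b)]


def solve_2x2(J, rhs):
    """Solve a 2x2 system J x = rhs by Cramer's rule."""
    (a, b), (c, d) = J
    det = a * d - b * c
    return [(d * rhs[0] - b * rhs[1]) / det,
            (a * rhs[1] - c * rhs[0]) / det]


def newton_step(A, b, u, jacobian):
    """One Newton update u + du, where jacobian * du = -F(u)."""
    F = residual(A, b, u)
    du = solve_2x2(jacobian, [-f for f in F])
    return [ui + dui for ui, dui in zip(u, du)]


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
u0 = [10.0, -7.0]  # arbitrary initial guess

# For a linear residual the Jacobian is known up front: it is A itself,
# so there is no need to differentiate F automatically.
u1 = newton_step(A, b, u0, jacobian=A)
```

The analogy with the cached solver is loose: the point is only that when the operator of the linear problem is already in hand, passing it through as the Jacobian avoids differentiating the residual.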
@angus-g angus-g marked this pull request as ready for review June 16, 2026 04:57
Comment thread firedrake/adjoint_utils/variational_solver.py Outdated
Comment thread firedrake/adjoint_utils/variational_solver.py Outdated
angus-g and others added 2 commits June 16, 2026 11:09
In addition to caching Jp if it exists, we unconditionally
cache the Jacobian from the NLVP.

Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
@JHopeCollins JHopeCollins merged commit 442f29e into JHopeCollins/nlvs-hessian-fix Jun 16, 2026
7 checks passed
@JHopeCollins JHopeCollins deleted the angus-g/4638-linear-form branch June 16, 2026 17:25
Labels

base:main Run this PR using a main (dev) build

3 participants