Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
add29d3
Add support for objective function in SNES
stefanozampini Jun 7, 2026
d55a3bb
Apply suggestions from code review
stefanozampini Jun 7, 2026
d41cc7e
add FAS support
pbrubeck Jun 7, 2026
bf6f6c7
fix tolerances
stefanozampini Jun 8, 2026
5835420
Apply suggestions from code review
pbrubeck Jun 8, 2026
a4c3dc3
solve interface
pbrubeck Jun 9, 2026
6654120
WIP
stefanozampini Jun 10, 2026
0ea74a0
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 10, 2026
ef82c67
wip
stefanozampini Jun 10, 2026
1cb9b75
Apply suggestion from @stefanozampini
stefanozampini Jun 10, 2026
c3d50f2
Remove objective from linear problem
pbrubeck Jun 10, 2026
b899636
test passes with HYPRE
stefanozampini Jun 10, 2026
78d169f
wip
stefanozampini Jun 10, 2026
dd0e73b
Move bcdofs to SNESContext, add temporary postsolve to zero bc compon…
stefanozampini Jun 11, 2026
fe5dce6
WIP TEST
stefanozampini Jun 11, 2026
e49c317
add allen-cahn test
stefanozampini Jun 12, 2026
96b63ef
tests also newtonls
stefanozampini Jun 12, 2026
8105f56
Move bcdofs to bcs.py
pbrubeck Jun 12, 2026
f94cb29
Update docs/source/solving-interface.rst
stefanozampini Jun 13, 2026
66483db
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 13, 2026
64b9690
wip test
stefanozampini Jun 13, 2026
13895de
Merge branch 'stefanozampini/snes-objective' of github.com:firedrakep…
stefanozampini Jun 13, 2026
20adc3e
finalize tests
stefanozampini Jun 13, 2026
f0767f1
wip : add restrict failing tests
stefanozampini Jun 13, 2026
913190a
Fix restrict=True
pbrubeck Jun 13, 2026
ba19351
finalize tests
stefanozampini Jun 13, 2026
384e096
add restrict tests to allen cahn too
stefanozampini Jun 13, 2026
b98076d
wip
stefanozampini Jun 13, 2026
c03ec10
Cleanup
pbrubeck Jun 13, 2026
fc6c220
fixes
pbrubeck Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/solving-interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,29 @@ solving with.
# Use the approximate inverse of Jp to precondition solves
solve(a == L, ..., Jp=Jp)

Providing an objective functional
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

An objective functional, usually the energy,
can be provided to PETSc's `SNES`_ linesearch and trust-region methods.
For example, the code block below minimizes the Dirichlet energy functional
using a trust-region method with the Steihaugh-Toint method for the
trust-region subproblem and an initial trust region size of ``1e-4``.
The solver monitor will report residual norm, norm of the update, and
value of the objective functional.

.. code-block:: python3

E = 0.5 * inner(grad(u), grad(u))*dx + kappa**2 * cosh(u)*dx - inner(f, u)*dx
F = inner(grad(u), grad(v))*dx + kappa**2 * inner(sinh(u), v)*dx - inner(f, v)*dx
Comment thread
stefanozampini marked this conversation as resolved.
Outdated
# Solve an optimisation problem
sp = {'snes_type': 'newtontr',
'snes_tr_delta0': 1e-4,
"snes_monitor": "::ascii_info_detail",
'ksp_type': 'cg',
'pc_type': 'hypre'}
solve(F == 0, u, ..., objective=E, solver_parameters=sp, pre_apply_bcs=False)

Default solver options
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
43 changes: 42 additions & 1 deletion firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ufl
from ufl import as_ufl, as_tensor
from finat.ufl import VectorElement
from finat.ufl import MixedElement, VectorElement
import finat

import pyop2 as op2
Expand Down Expand Up @@ -746,3 +746,44 @@ def restricted_function_space(V, ids):
return spaces[0]
else:
return firedrake.MixedFunctionSpace(spaces, name=V.name)


def bcdofs(bc, ghost=True):
# Return the global dofs fixed by a DirichletBC
# in the numbering given by concatenation of all the
# subspaces of a mixed function space
Z = bc.function_space()
while Z.parent is not None:
Z = Z.parent

indices = bc._indices
offset = 0

for (i, idx) in enumerate(indices):
if isinstance(Z.ufl_element(), VectorElement):
offset += idx
assert i == len(indices)-1 # assert we're at the end of the chain
assert Z.sub(idx).block_size == 1
elif isinstance(Z.ufl_element(), MixedElement):
if ghost:
offset += sum(Z.sub(j).dof_count for j in range(idx))
else:
offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).block_size for j in range(idx))
else:
raise NotImplementedError("How are you taking a .sub?")

Z = Z.sub(idx)

if Z.parent is not None and isinstance(Z.parent.ufl_element(), VectorElement):
bs = Z.parent.block_size
start = 0
stop = 1
else:
bs = Z.block_size
start = 0
stop = bs
nodes = bc.nodes
if not ghost:
nodes = nodes[nodes < Z.dof_dset.size]

return np.concatenate([nodes*bs + j for j in range(start, stop)]) + offset
5 changes: 4 additions & 1 deletion firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,16 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
coefficient_mapping = {}

bcs = [self(bc, self, coefficient_mapping=coefficient_mapping) for bc in problem.bcs]
E = self(problem.E, self, coefficient_mapping=coefficient_mapping)
F = self(problem.F, self, coefficient_mapping=coefficient_mapping)
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping)
u = coefficient_mapping[problem.u_restrict]

fine = problem
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp,
objective=E,
is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
fine._coarse = problem
return problem
Expand Down
3 changes: 1 addition & 2 deletions firedrake/preconditioners/bddc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from itertools import repeat

from firedrake.preconditioners.base import PCBase
from firedrake.preconditioners.patch import bcdofs
from firedrake.preconditioners.facet_split import get_restriction_indices
from firedrake.petsc import PETSc
from firedrake.dmhooks import get_function_space, get_appctx
Expand All @@ -13,7 +12,7 @@
from functools import cached_property

from firedrake.parloops import par_loop, INC, READ
from firedrake.bcs import DirichletBC
from firedrake.bcs import bcdofs, DirichletBC
from firedrake.mesh import Submesh
from ufl import Form, H1, H2, JacobianDeterminant, dx, inner, replace
from finat.ufl import BrokenElement
Expand Down
2 changes: 1 addition & 1 deletion firedrake/preconditioners/fdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from itertools import chain, product
from firedrake.petsc import PETSc
from firedrake.preconditioners.base import PCBase
from firedrake.preconditioners.patch import bcdofs
from firedrake.preconditioners.pmg import (prolongation_matrix_matfree,
evaluate_dual,
get_permutation_to_nodal_elements,
cache_generate_code)
from firedrake.preconditioners.facet_split import restricted_dofs, split_dofs
from firedrake.bcs import bcdofs
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.function import Function
Expand Down
43 changes: 1 addition & 42 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from firedrake.preconditioners.asm import validate_overlap
from firedrake.petsc import PETSc
import firedrake.cython.patchimpl
from firedrake.bcs import bcdofs
from firedrake.solving_utils import _SNESContext
from firedrake.utils import complex_mode
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
Expand All @@ -22,7 +23,6 @@
import operator
from functools import cached_property, partial
import numpy
from finat.ufl import VectorElement, MixedElement
from tsfc.ufl_utils import extract_firedrake_constants
import weakref
import petsctools
Expand Down Expand Up @@ -587,47 +587,6 @@ def make_patch_callables(form: ufl.Form, state: Function | None) -> tuple[
return cell_callable, interior_facet_callable, exterior_facet_callable


def bcdofs(bc, ghost=True):
# Return the global dofs fixed by a DirichletBC
# in the numbering given by concatenation of all the
# subspaces of a mixed function space
Z = bc.function_space()
while Z.parent is not None:
Z = Z.parent

indices = bc._indices
offset = 0

for (i, idx) in enumerate(indices):
if isinstance(Z.ufl_element(), VectorElement):
offset += idx
assert i == len(indices)-1 # assert we're at the end of the chain
assert Z.sub(idx).block_size == 1
elif isinstance(Z.ufl_element(), MixedElement):
if ghost:
offset += sum(Z.sub(j).dof_count for j in range(idx))
else:
offset += sum(Z.sub(j).dof_dset.size * Z.sub(j).block_size for j in range(idx))
else:
raise NotImplementedError("How are you taking a .sub?")

Z = Z.sub(idx)

if Z.parent is not None and isinstance(Z.parent.ufl_element(), VectorElement):
bs = Z.parent.block_size
start = 0
stop = 1
else:
bs = Z.block_size
start = 0
stop = bs
nodes = bc.nodes
if not ghost:
nodes = nodes[nodes < Z.dof_dset.size]

return numpy.concatenate([nodes*bs + j for j in range(start, stop)]) + offset


def select_entity(p, dm=None, exclude=None):
"""Filter entities based on some label.

Expand Down
13 changes: 7 additions & 6 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _solve_varproblem(*args, **kwargs):
"Solve variational problem a == L or F == 0"

# Extract arguments
eq, u, bcs, J, Jp, M, form_compiler_parameters, \
eq, u, bcs, J, Jp, objective, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, \
near_nullspace, \
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)
Expand All @@ -167,6 +167,8 @@ def _solve_varproblem(*args, **kwargs):
raise TypeError(f"Equation LHS must be a ufl.BaseForm, not a {type(eq.lhs).__name__}")

if len(eq.lhs.arguments()) == 2:
if objective is not None:
raise ValueError("The objective functional only makes sense for nonlinear problems.")
# Create linear variational problem
problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp,
form_compiler_parameters=form_compiler_parameters,
Expand All @@ -177,6 +179,7 @@ def _solve_varproblem(*args, **kwargs):
if eq.rhs != 0:
raise ValueError(f"RHS of nonlinear Equation must be `0`, not {eq.rhs}")
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
objective=objective,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
create_solver = vs.NonlinearVariationalSolver
Expand Down Expand Up @@ -279,7 +282,7 @@ def _extract_args(*args, **kwargs):
"Extraction of arguments for _solve_varproblem"

# Check for use of valid kwargs
valid_kwargs = ["bcs", "J", "Jp", "M",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M had been used before as the objective, however we were completely ignoring it

valid_kwargs = ["bcs", "J", "Jp", "objective",
"form_compiler_parameters", "solver_parameters",
"nullspace", "transpose_nullspace", "near_nullspace",
"options_prefix", "appctx", "restrict", "pre_apply_bcs"]
Expand Down Expand Up @@ -314,9 +317,7 @@ def _extract_args(*args, **kwargs):
Jp = kwargs.get("Jp", None)

# Extract functional
M = kwargs.get("M", None)
if M is not None and not isinstance(M, ufl.Form):
raise RuntimeError("Expecting goal functional M to be a UFL Form")
objective = kwargs.get("objective", None)

nullspace = kwargs.get("nullspace", None)
nullspace_T = kwargs.get("transpose_nullspace", None)
Expand All @@ -328,7 +329,7 @@ def _extract_args(*args, **kwargs):
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
return eq, u, bcs, J, Jp, objective, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, restrict, pre_apply_bcs

Expand Down
56 changes: 56 additions & 0 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pyop2 import op2
from firedrake import dmhooks
from firedrake.bcs import bcdofs
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.matrix import MatrixBase
Expand Down Expand Up @@ -231,6 +232,7 @@ def __init__(self, problem,
self.pmatfree = pmatfree
self.F = problem.F
self.J = problem.J
self.E = problem.E

# For Jp to equal J, bc.Jp must equal bc.J for all EquationBC objects.
Jp_eq_J = problem.Jp is None and all(bc.Jp_eq_J for bc in problem.bcs)
Expand Down Expand Up @@ -261,6 +263,12 @@ def __init__(self, problem,

self.F -= problem.compute_bc_lifting(self.J, self._bc_residual)

self._assemble_objective = lambda *args, **kwargs: args
if self.E:
self._assemble_objective = get_assembler(self.E,
form_compiler_parameters=self.fcp,
).assemble

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=pre_apply_bcs,
Expand All @@ -277,6 +285,18 @@ def __init__(self, problem,
self._coefficient_mapping = None
self._transfer_manager = transfer_manager

@cached_property
def bc_iset(self):
if self._problem.restrict:
return None
bcs = self._problem.dirichlet_bcs()
V = self._x.function_space()
bc_nodes = numpy.unique(numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs], dtype=PETSc.IntType))
bc_nodes = V.dof_dset.lgmap.apply(bc_nodes)
bc_is = PETSc.IS().createGeneral(bc_nodes, comm=V.comm)
bc_is.sort()
return bc_is

def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs):
"""Reconstruct this _SNESContext instance with new arguments."""
problem = problem or self._problem
Expand Down Expand Up @@ -342,6 +362,12 @@ def transfer_manager(self, manager):
raise ValueError("Must set transfer manager before first use.")
self._transfer_manager = manager

def set_objective(self, snes):
if self._problem.E:
snes.setObjective(self.form_objective)
else:
snes.setObjective(None)
Comment thread
stefanozampini marked this conversation as resolved.

def set_function(self, snes):
r"""Set the residual evaluation function"""
with self._F.dat.vec_wo as v:
Expand All @@ -360,6 +386,9 @@ def set_nullspace(self, nullspace, ises=None, transpose=False, near=False):
if ises is not None:
nullspace._apply(ises, transpose=transpose, near=near)

def set_ksp_postsolve(self, snes):
snes.getKSP().setPostSolve(self.ksp_postsolve)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this overwrite an existing user post-solve callback? Or will they both be called in that case?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a different one. The existing ones are SNES related (pre/post Jacobian/Function)

@JHopeCollins JHopeCollins Jun 13, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the ones passed to the nlvs, the ones on the actual ksp. A user might have set one, or a particular ksp type might define one itself

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a temporary fix to firedrake's model of pre_apply_bcs=True, which I do not fully understand (or agree with :-) ). May be removed if we understand what's going wrong with the bc model and inexact solves

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the ones passed to the nlvs the ones on the actual ksp
A user might have set one, or a particular ksp type might define one itself

This is set as the default when firedrake creates the SNES. Users can do whatever they want afterward. Add their own, or change the KSP. And no, KSPs don't have their own since each KSP implements its own Solve; these are user-specific callbacks

As I said, this is temporary and can be removed once we understand what is going wrong with pre_apply_bcs=True

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, the pre_apply_bcs=True is flawed in the SNES sense (does not let you linearize around a state that does not satisfy the BCs), but also in the KSP sense (there is no way to make sure that the PC would not introduce coupling between the BC and the interior, i.e. the PC needs to invert the identity block exactly). The short-circuit pre_apply_bcs=restrict, fixes the KSP issue, but not the SNES one.

What do you mean. With restrict you do not have BC dofs at all in the system, so it should not cause problems either to KSP or SNES, since you will just solve the proper linear system only approximatively

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For option 2, we would change the default to pre_apply_bcs=restrict=False and deprecate the argument pre_apply_bcs from the next release onwards, but still accept two separate arguments for restrict and pre_apply_bcs and make sure that they are equal. We then refactor the internal code to have a unified restrict/pre_apply_bcs flag. Finally, remove the pre_apply_bcs argument on the following release.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean. With restrict you do not have BC dofs at all in the system, so it should not cause problems either to KSP or SNES, since you will just solve the proper linear system only approximatively

The SNES issue is unavoidable and inherent to the pre_apply_bcs=True model. It might result in NaN at the first residual on some nonlinear problems when the initial guess is not consistent with the BCs. Imposing the BCs for such guesses introduces very sharp gradients. See this PR description, and the test that we added for pre_apply_bcs=False.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened #5177 to continue with this discussion


@PETSc.Log.EventDecorator()
def split(self, fields):
from firedrake import replace, as_vector, split, zero
Expand Down Expand Up @@ -449,6 +478,26 @@ def split(self, fields):
splits.append(self.reconstruct(new_problem, options_prefix=options_prefix))
return self._splits.setdefault(tuple(fields), splits)

@staticmethod
def form_objective(snes, X):
r"""Form the objective for this problem

:arg snes: a PETSc SNES object
:arg X: the current guess (a Vec)
"""
dm = snes.getDM()
ctx = dmhooks.get_appctx(dm)
# X may not be the same vector as the vec behind self._x, so
# copy guess in from X.
with ctx._x.dat.vec_wo as v:
X.copy(v)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? ctx._x is the solution according to Firedrake (but SNES has a different one) can we treat ctx._x as a buffer?

@stefanozampini stefanozampini Jun 13, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I copied that from the residual evaluation routine. At the end of solve, you have work.copy(u), with u = problem.u_restrict.dat.vec and work the Vec SNES is using for the solution. From the point of view of the solve, this is ok. But if your ctx._x is supposed to be immutable for a given Newton step, then overwriting it can be a problem, yes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can form_objective be called within a single Newton step? i.e. by the KSP? If so then I do not think that this is safe because for anything matrix-free we assume that ctx._x holds the current state.

I think you just need to save the state in ctx._x to a temporary buffer at entry to this function and restore it at exit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle it should not happen, but I agree with you better be safe. Make a suggestion in the diff and I will merge it in


# Apply DirichletBC on the solution
for bc in ctx._problem.dirichlet_bcs():
bc.apply(ctx._x)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Apply DirichletBC on the solution
for bc in ctx._problem.dirichlet_bcs():
bc.apply(ctx._x)
# Apply DirichletBC on the solution
if ctx.pre_apply_bcs:
for bc in ctx._problem.dirichlet_bcs():
bc.apply(ctx._x)

@stefanozampini stefanozampini Jun 13, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think what you are proposing is correct, but I also don't think what I have in the current PR is 100% correct. In all my experience with nonlinear solvers, you always inject the correct BCs whenever you evaluate objective functions or residuals. Then, in case of residuals, you modify afterward the part of the vector that is associated with bc, because there you solve a different set of equations, i.e.

F(u) = 0 on \Omega
u - g = 0 on \partial\Omega

As a byproduct, if you do Newton, the boundary conditions will be exact after the first full step (that may not be the first Newton step if you use linesearch). See e.g. how I do it in MFEM https://github.com/mfem/mfem/blob/master/linalg/petsc.cpp#L4996

If ctx._x is supposed to be immutable during each phases of a given SNES step, then you need to rethink the firedrake logic a bit


return ctx._assemble_objective()

@staticmethod
def form_function(snes, X, F):
r"""Form the residual for this problem
Expand Down Expand Up @@ -547,6 +596,13 @@ def compute_operators(ksp, J, P):
assert P.handle == ctx._pjac.petscmat.handle
ctx._assemble_pjac(ctx._pjac)

@staticmethod
def ksp_postsolve(ksp, rhs, sol):
dm = ksp.getDM()
ctx = dmhooks.get_appctx(dm)
if ctx.pre_apply_bcs and not ctx._problem.restrict:
sol.isset(ctx.bc_iset, 0)

@cached_property
def _assembler_jac(self):
from firedrake.assemble import get_assembler
Expand Down
Loading
Loading