Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
add29d3
Add support for objective function in SNES
stefanozampini Jun 7, 2026
d55a3bb
Apply suggestions from code review
stefanozampini Jun 7, 2026
d41cc7e
add FAS support
pbrubeck Jun 7, 2026
bf6f6c7
fix tolerances
stefanozampini Jun 8, 2026
5835420
Apply suggestions from code review
pbrubeck Jun 8, 2026
a4c3dc3
solve interface
pbrubeck Jun 9, 2026
6654120
WIP
stefanozampini Jun 10, 2026
0ea74a0
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 10, 2026
ef82c67
wip
stefanozampini Jun 10, 2026
1cb9b75
Apply suggestion from @stefanozampini
stefanozampini Jun 10, 2026
c3d50f2
Remove objective from linear problem
pbrubeck Jun 10, 2026
b899636
test passes with HYPRE
stefanozampini Jun 10, 2026
78d169f
wip
stefanozampini Jun 10, 2026
dd0e73b
Move bcdofs to SNESContext, add temporary postsolve to zero bc compon…
stefanozampini Jun 11, 2026
fe5dce6
WIP TEST
stefanozampini Jun 11, 2026
e49c317
add allen-cahn test
stefanozampini Jun 12, 2026
96b63ef
tests also newtonls
stefanozampini Jun 12, 2026
8105f56
Move bcdofs to bcs.py
pbrubeck Jun 12, 2026
f94cb29
Update docs/source/solving-interface.rst
stefanozampini Jun 13, 2026
66483db
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 13, 2026
64b9690
wip test
stefanozampini Jun 13, 2026
13895de
Merge branch 'stefanozampini/snes-objective' of github.com:firedrakep…
stefanozampini Jun 13, 2026
20adc3e
finalize tests
stefanozampini Jun 13, 2026
f0767f1
wip : add restrict failing tests
stefanozampini Jun 13, 2026
913190a
Fix restrict=True
pbrubeck Jun 13, 2026
ba19351
finalize tests
stefanozampini Jun 13, 2026
384e096
add restrict tests to allen cahn too
stefanozampini Jun 13, 2026
b98076d
wip
stefanozampini Jun 13, 2026
c03ec10
Cleanup
pbrubeck Jun 13, 2026
fc6c220
fixes
pbrubeck Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/solving-interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,29 @@ solving with.
# Use the approximate inverse of Jp to precondition solves
solve(a == L, ..., Jp=Jp)

Providing an objective functional
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

An objective functional, usually the energy,
can be provided to PETSc's `SNES`_ linesearch and trust-region methods.
For example, the code block below minimizes the Dirichlet energy functional
using a trust-region method with the Steihaugh-Toint method for the
trust-region subproblem and an initial trust region size of ``1e-4``.
The solver monitor will report residual norm, norm of the update, and
value of the objective functional.

.. code-block:: python3

E = 0.5 * inner(grad(u), grad(u))*dx + kappa**2 * cosh(u)*dx - inner(f, u)*dx
F = inner(grad(u), grad(v))*dx + kappa**2 * inner(sinh(u), v)*dx - inner(f, v)*dx
Comment thread
stefanozampini marked this conversation as resolved.
Outdated
# Solve an optimisation problem
sp = {'snes_type': 'newtontr',
'snes_tr_delta0': 1e-4,
"snes_monitor": "::ascii_info_detail",
'ksp_type': 'cg',
'pc_type': 'gamg'}
solve(F == 0, u, ..., objective=E, solver_parameters=sp, pre_apply_bcs=False)

Default solver options
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, A, *, P=None, **kwargs):
self.b = Cofunction(test.function_space().dual())

problem = LinearVariationalProblem(A, self.b, self.x, aP=P,
objective=kwargs.pop("objective", None),
form_compiler_parameters=A.form_compiler_parameters,
constant_jacobian=True)
super().__init__(problem, **kwargs)
Expand Down
5 changes: 4 additions & 1 deletion firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,16 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
coefficient_mapping = {}

bcs = [self(bc, self, coefficient_mapping=coefficient_mapping) for bc in problem.bcs]
E = self(problem.E, self, coefficient_mapping=coefficient_mapping)
F = self(problem.F, self, coefficient_mapping=coefficient_mapping)
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping)
u = coefficient_mapping[problem.u_restrict]

fine = problem
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp,
objective=E,
is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
fine._coarse = problem
return problem
Expand Down
25 changes: 15 additions & 10 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _solve_varproblem(*args, **kwargs):
"Solve variational problem a == L or F == 0"

# Extract arguments
eq, u, bcs, J, Jp, M, form_compiler_parameters, \
eq, u, bcs, J, Jp, objective, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, \
near_nullspace, \
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)
Expand All @@ -169,6 +169,7 @@ def _solve_varproblem(*args, **kwargs):
if len(eq.lhs.arguments()) == 2:
# Create linear variational problem
problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp,
objective=objective,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
create_solver = vs.LinearVariationalSolver
Expand All @@ -177,6 +178,7 @@ def _solve_varproblem(*args, **kwargs):
if eq.rhs != 0:
raise ValueError(f"RHS of nonlinear Equation must be `0`, not {eq.rhs}")
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
objective=objective,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
create_solver = vs.NonlinearVariationalSolver
Expand All @@ -201,6 +203,7 @@ def _la_solve(A, x, b, **kwargs):
:kwarg P: an optional :class:`~.MatrixBase` to construct any
preconditioner from; if none is supplied ``A`` is
used to construct the preconditioner.
:kwarg objective: a form used for line-search.
:kwarg solver_parameters: optional solver parameters.
:kwarg nullspace: an optional :class:`.VectorSpaceBasis` (or
:class:`.MixedVectorSpaceBasis`) spanning the null space of
Expand Down Expand Up @@ -233,15 +236,18 @@ def _la_solve(A, x, b, **kwargs):

_la_solve(A, x, b, solver_parameters=parameters_dict)."""

(P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace,
(P, bcs, objective, solver_parameters,
nullspace, nullspace_T, near_nullspace,
options_prefix, pre_apply_bcs,
) = _extract_linear_solver_args(A, x, b, **kwargs)

if bcs is not None:
raise RuntimeError("It is no longer possible to apply or change boundary conditions after assembling the matrix `A`; pass any necessary boundary conditions to `assemble` when assembling `A`.")

appctx = solver_parameters.get("appctx", {})
solver = ls.LinearSolver(A=A, P=P, solver_parameters=solver_parameters,
solver = ls.LinearSolver(A=A, P=P,
objective=objective,
solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
Expand All @@ -252,7 +258,7 @@ def _la_solve(A, x, b, **kwargs):


def _extract_linear_solver_args(*args, **kwargs):
valid_kwargs = ["P", "bcs", "solver_parameters", "nullspace",
valid_kwargs = ["P", "bcs", "objective", "solver_parameters", "nullspace",
"transpose_nullspace", "near_nullspace", "options_prefix",
"pre_apply_bcs"]
if len(args) != 3:
Expand All @@ -265,21 +271,22 @@ def _extract_linear_solver_args(*args, **kwargs):

P = kwargs.get("P", None)
bcs = kwargs.get("bcs", None)
objective = kwargs.get("objective", None)
solver_parameters = kwargs.get("solver_parameters", {}) or {}
nullspace = kwargs.get("nullspace", None)
nullspace_T = kwargs.get("transpose_nullspace", None)
near_nullspace = kwargs.get("near_nullspace", None)
options_prefix = kwargs.get("options_prefix", None)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return P, bcs, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix, pre_apply_bcs
return P, bcs, objective, solver_parameters, nullspace, nullspace_T, near_nullspace, options_prefix, pre_apply_bcs


def _extract_args(*args, **kwargs):
Comment thread
pbrubeck marked this conversation as resolved.
"Extraction of arguments for _solve_varproblem"

# Check for use of valid kwargs
valid_kwargs = ["bcs", "J", "Jp", "M",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M had been used before as the objective, however we were completely ignoring it

valid_kwargs = ["bcs", "J", "Jp", "objective",
"form_compiler_parameters", "solver_parameters",
"nullspace", "transpose_nullspace", "near_nullspace",
"options_prefix", "appctx", "restrict", "pre_apply_bcs"]
Expand Down Expand Up @@ -314,9 +321,7 @@ def _extract_args(*args, **kwargs):
Jp = kwargs.get("Jp", None)

# Extract functional
M = kwargs.get("M", None)
if M is not None and not isinstance(M, ufl.Form):
raise RuntimeError("Expecting goal functional M to be a UFL Form")
objective = kwargs.get("objective", None)

nullspace = kwargs.get("nullspace", None)
nullspace_T = kwargs.get("transpose_nullspace", None)
Expand All @@ -328,7 +333,7 @@ def _extract_args(*args, **kwargs):
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
return eq, u, bcs, J, Jp, objective, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, restrict, pre_apply_bcs

Expand Down
34 changes: 34 additions & 0 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(self, problem,
self.pmatfree = pmatfree
self.F = problem.F
self.J = problem.J
self.E = problem.E

# For Jp to equal J, bc.Jp must equal bc.J for all EquationBC objects.
Jp_eq_J = problem.Jp is None and all(bc.Jp_eq_J for bc in problem.bcs)
Expand Down Expand Up @@ -261,6 +262,12 @@ def __init__(self, problem,

self.F -= problem.compute_bc_lifting(self.J, self._bc_residual)

self._assemble_objective = lambda *args, **kwargs: args
if self.E:
self._assemble_objective = get_assembler(self.E,
form_compiler_parameters=self.fcp,
).assemble

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=pre_apply_bcs,
Expand Down Expand Up @@ -342,6 +349,12 @@ def transfer_manager(self, manager):
raise ValueError("Must set transfer manager before first use.")
self._transfer_manager = manager

def set_objective(self, snes):
if self._problem.E:
snes.setObjective(self.form_objective)
else:
snes.setObjective(None)
Comment thread
stefanozampini marked this conversation as resolved.

def set_function(self, snes):
r"""Set the residual evaluation function"""
with self._F.dat.vec_wo as v:
Expand Down Expand Up @@ -449,6 +462,27 @@ def split(self, fields):
splits.append(self.reconstruct(new_problem, options_prefix=options_prefix))
return self._splits.setdefault(tuple(fields), splits)

@staticmethod
def form_objective(snes, X):
r"""Form the objective for this problem

:arg snes: a PETSc SNES object
:arg X: the current guess (a Vec)
"""
dm = snes.getDM()
ctx = dmhooks.get_appctx(dm)
# X may not be the same vector as the vec behind self._x, so
# copy guess in from X.
with ctx._x.dat.vec_wo as v:
X.copy(v)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? ctx._x is the solution according to Firedrake (but SNES has a different one) can we treat ctx._x as a buffer?

@stefanozampini stefanozampini Jun 13, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I copied that from the residual evaluation routine. At the end of solve, you have work.copy(u), with u = problem.u_restrict.dat.vec and work the Vec SNES is using for the solution. From the point of view of the solve, this is ok. But if your ctx._x is supposed to be immutable for a given Newton step, then overwriting it can be a problem, yes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can form_objective be called within a single Newton step? i.e. by the KSP? If so then I do not think that this is safe because for anything matrix-free we assume that ctx._x holds the current state.

I think you just need to save the state in ctx._x to a temporary buffer at entry to this function and restore it at exit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle it should not happen, but I agree with you better be safe. Make a suggestion in the diff and I will merge it in


if not ctx.pre_apply_bcs:
Comment thread
stefanozampini marked this conversation as resolved.
Outdated
# Apply DirichletBC on the solution
for bc in ctx._problem.dirichlet_bcs():
bc.apply(ctx._x)

return ctx._assemble_objective()

@staticmethod
def form_function(snes, X, F):
r"""Form the residual for this problem
Expand Down
19 changes: 16 additions & 3 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"NonlinearVariationalSolver"]


def check_pde_args(F, J, Jp):
def check_pde_args(F, J, Jp, E=None):
if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
if len(F.arguments()) != 1:
Expand All @@ -36,6 +36,11 @@ def check_pde_args(F, J, Jp):
raise TypeError("Provided preconditioner is a '%s', not a BaseForm or Slate Tensor" % type(Jp).__name__)
if Jp is not None and len(Jp.arguments()) != 2:
raise ValueError("Provided preconditioner is not a bilinear form")
if E is not None:
if not isinstance(E, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided objective is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
if len(E.arguments()) != 0:
raise ValueError("Provided objective is not a 0-form")


def is_form_consistent(is_linear, bcs):
Expand All @@ -52,6 +57,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
@NonlinearVariationalProblemMixin._ad_annotate_init
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
objective=None,
form_compiler_parameters=None,
is_linear=False, restrict=False):
r"""
Expand All @@ -62,6 +68,7 @@ def __init__(self, F, u, bcs=None, J=None,
:param Jp: a form used for preconditioning the linear system,
optional, if not supplied then the Jacobian itself
will be used.
:param objective: a form used for line-search or trust-region methods, optional
:param dict form_compiler_parameters: parameters to pass to the form
compiler (optional)
:is_linear: internally used to check if all domain/bc forms
Expand All @@ -82,6 +89,7 @@ def __init__(self, F, u, bcs=None, J=None,
self.J = J or ufl_expr.derivative(F, u)
self.F = F
self.Jp = Jp
Comment thread
pbrubeck marked this conversation as resolved.
self.E = objective
if isinstance(J, MatrixBase):
if bcs:
raise RuntimeError("It is not possible to apply or change boundary conditions to an already assembled Jacobian; pass any necessary boundary conditions to `assemble` when assembling the Jacobian.")
Expand Down Expand Up @@ -119,7 +127,7 @@ def __init__(self, F, u, bcs=None, J=None,
self.Jp_eq_J = Jp is None

# Argument checking
check_pde_args(self.F, self.J, self.Jp)
check_pde_args(self.F, self.J, self.Jp, E=self.E)

# Store form compiler parameters
self.form_compiler_parameters = form_compiler_parameters
Expand Down Expand Up @@ -304,6 +312,7 @@ def update_diffusivity(current_solution):
self._work = problem.u_restrict.dof_dset.layout_vec.duplicate()
self.snes.setDM(problem.dm)

ctx.set_objective(self.snes)
ctx.set_function(self.snes)
ctx.set_jacobian(self.snes)
ctx.set_nullspace(nullspace, problem.J.arguments()[0].function_space()._ises,
Expand Down Expand Up @@ -353,12 +362,13 @@ def solve(self, bounds=None):
``vinewtonssls`` or ``vinewtonrsls``.
"""
# Make sure the DM has this solver's callback functions
self._ctx.set_objective(self.snes)
self._ctx.set_function(self.snes)
self._ctx.set_jacobian(self.snes)

# Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve.
problem = self._problem
forms = (problem.F, problem.J, problem.Jp)
forms = (problem.F, problem.J, problem.Jp, problem.E)
coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None))
solution_dm = self.snes.getDM()
# Grab the unique DMs for this problem
Expand Down Expand Up @@ -412,6 +422,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem):

@PETSc.Log.EventDecorator()
def __init__(self, a, L, u, bcs=None, aP=None,
objective=None,
form_compiler_parameters=None,
constant_jacobian=False, restrict=False):
r"""
Expand All @@ -422,6 +433,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,
:param aP: an optional operator to assemble to precondition
the system (if not provided a preconditioner may be
computed from ``a``)
:param objective: a form used for line-search, optional
:param dict form_compiler_parameters: parameters to pass to the form
compiler (optional)
:param constant_jacobian: (optional) flag indicating that the
Expand All @@ -442,6 +454,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,
F = self.compute_bc_lifting(a, u, L=L)

super(LinearVariationalProblem, self).__init__(F, u, bcs=bcs, J=a, Jp=aP,
objective=objective,
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
form_compiler_parameters=form_compiler_parameters,
is_linear=True, restrict=restrict)
self._constant_jacobian = constant_jacobian
Expand Down
Loading
Loading