Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
add29d3
Add support for objective function in SNES
stefanozampini Jun 7, 2026
d55a3bb
Apply suggestions from code review
stefanozampini Jun 7, 2026
d41cc7e
add FAS support
pbrubeck Jun 7, 2026
bf6f6c7
fix tolerances
stefanozampini Jun 8, 2026
5835420
Apply suggestions from code review
pbrubeck Jun 8, 2026
a4c3dc3
solve interface
pbrubeck Jun 9, 2026
6654120
WIP
stefanozampini Jun 10, 2026
0ea74a0
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 10, 2026
ef82c67
wip
stefanozampini Jun 10, 2026
1cb9b75
Apply suggestion from @stefanozampini
stefanozampini Jun 10, 2026
c3d50f2
Remove objective from linear problem
pbrubeck Jun 10, 2026
b899636
test passes with HYPRE
stefanozampini Jun 10, 2026
78d169f
wip
stefanozampini Jun 10, 2026
dd0e73b
Move bcdofs to SNESContext, add temporary postsolve to zero bc compon…
stefanozampini Jun 11, 2026
fe5dce6
WIP TEST
stefanozampini Jun 11, 2026
e49c317
add allen-cahn test
stefanozampini Jun 12, 2026
96b63ef
tests also newtonls
stefanozampini Jun 12, 2026
8105f56
Move bcdofs to bcs.py
pbrubeck Jun 12, 2026
f94cb29
Update docs/source/solving-interface.rst
stefanozampini Jun 13, 2026
66483db
Update tests/firedrake/regression/test_snes_objective.py
stefanozampini Jun 13, 2026
64b9690
wip test
stefanozampini Jun 13, 2026
13895de
Merge branch 'stefanozampini/snes-objective' of github.com:firedrakep…
stefanozampini Jun 13, 2026
20adc3e
finalize tests
stefanozampini Jun 13, 2026
f0767f1
wip : add restrict failing tests
stefanozampini Jun 13, 2026
913190a
Fix restrict=True
pbrubeck Jun 13, 2026
ba19351
finalize tests
stefanozampini Jun 13, 2026
384e096
add restrict tests to allen cahn too
stefanozampini Jun 13, 2026
b98076d
wip
stefanozampini Jun 13, 2026
c03ec10
Cleanup
pbrubeck Jun 13, 2026
fc6c220
fixes
pbrubeck Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,16 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
coefficient_mapping = {}

bcs = [self(bc, self, coefficient_mapping=coefficient_mapping) for bc in problem.bcs]
E = self(problem.E, self, coefficient_mapping=coefficient_mapping)
F = self(problem.F, self, coefficient_mapping=coefficient_mapping)
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping)
u = coefficient_mapping[problem.u_restrict]

fine = problem
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp,
objective=E,
is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
fine._coarse = problem
return problem
Expand Down
34 changes: 34 additions & 0 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(self, problem,
self.pmatfree = pmatfree
self.F = problem.F
self.J = problem.J
self.E = problem.E

# For Jp to equal J, bc.Jp must equal bc.J for all EquationBC objects.
Jp_eq_J = problem.Jp is None and all(bc.Jp_eq_J for bc in problem.bcs)
Expand Down Expand Up @@ -261,6 +262,12 @@ def __init__(self, problem,

self.F -= problem.compute_bc_lifting(self.J, self._bc_residual)

self._assemble_objective = lambda *args, **kwargs: args
if self.E:
self._assemble_objective = get_assembler(self.E,
form_compiler_parameters=self.fcp,
).assemble

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=pre_apply_bcs,
Expand Down Expand Up @@ -342,6 +349,12 @@ def transfer_manager(self, manager):
raise ValueError("Must set transfer manager before first use.")
self._transfer_manager = manager

def set_objective(self, snes):
if self._problem.E:
snes.setObjective(self.form_objective)
else:
snes.setObjective(None)
Comment thread
stefanozampini marked this conversation as resolved.

def set_function(self, snes):
r"""Set the residual evaluation function"""
with self._F.dat.vec_wo as v:
Expand Down Expand Up @@ -449,6 +462,27 @@ def split(self, fields):
splits.append(self.reconstruct(new_problem, options_prefix=options_prefix))
return self._splits.setdefault(tuple(fields), splits)

@staticmethod
def form_objective(snes, X):
r"""Form the objective for this problem

:arg snes: a PETSc SNES object
:arg X: the current guess (a Vec)
"""
dm = snes.getDM()
ctx = dmhooks.get_appctx(dm)
# X may not be the same vector as the vec behind self._x, so
# copy guess in from X.
with ctx._x.dat.vec_wo as v:
X.copy(v)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? ctx._x is the solution according to Firedrake (but SNES has a different one) can we treat ctx._x as a buffer?

@stefanozampini stefanozampini Jun 13, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I copied that from the residual evaluation routine. At the end of solve, you have work.copy(u), with u = problem.u_restrict.dat.vec and work the Vec SNES is using for the solution. From the point of view of the solve, this is ok. But if your ctx._x is supposed to be immutable for a given Newton step, then overwriting it can be a problem, yes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can form_objective be called within a single Newton step? i.e. by the KSP? If so then I do not think that this is safe because for anything matrix-free we assume that ctx._x holds the current state.

I think you just need to save the state in ctx._x to a temporary buffer at entry to this function and restore it at exit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle it should not happen, but I agree with you better be safe. Make a suggestion in the diff and I will merge it in


if ctx.pre_apply_bcs:
# Apply DirichletBC on the solution
for bc in ctx._problem.dirichlet_bcs():
bc.apply(ctx._x)

Comment thread
pbrubeck marked this conversation as resolved.
Outdated
return ctx._assemble_objective()

@staticmethod
def form_function(snes, X, F):
r"""Form the residual for this problem
Expand Down
16 changes: 13 additions & 3 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"NonlinearVariationalSolver"]


def check_pde_args(F, J, Jp):
def check_pde_args(F, J, Jp, E=None):
if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
if len(F.arguments()) != 1:
Expand All @@ -36,6 +36,11 @@ def check_pde_args(F, J, Jp):
raise TypeError("Provided preconditioner is a '%s', not a BaseForm or Slate Tensor" % type(Jp).__name__)
if Jp is not None and len(Jp.arguments()) != 2:
raise ValueError("Provided preconditioner is not a bilinear form")
if E is not None:
if not isinstance(E, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided objective is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
if len(E.arguments()) != 0:
raise ValueError("Provided objective is not a 0-form")


def is_form_consistent(is_linear, bcs):
Expand All @@ -52,6 +57,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
@NonlinearVariationalProblemMixin._ad_annotate_init
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
objective=None,
form_compiler_parameters=None,
is_linear=False, restrict=False):
r"""
Expand All @@ -62,6 +68,7 @@ def __init__(self, F, u, bcs=None, J=None,
:param Jp: a form used for preconditioning the linear system,
optional, if not supplied then the Jacobian itself
will be used.
:param objective: a form used for line-search, optional
Comment thread
stefanozampini marked this conversation as resolved.
Outdated
:param dict form_compiler_parameters: parameters to pass to the form
compiler (optional)
:is_linear: internally used to check if all domain/bc forms
Expand All @@ -82,6 +89,7 @@ def __init__(self, F, u, bcs=None, J=None,
self.J = J or ufl_expr.derivative(F, u)
self.F = F
self.Jp = Jp
Comment thread
pbrubeck marked this conversation as resolved.
self.E = objective
if isinstance(J, MatrixBase):
if bcs:
raise RuntimeError("It is not possible to apply or change boundary conditions to an already assembled Jacobian; pass any necessary boundary conditions to `assemble` when assembling the Jacobian.")
Expand Down Expand Up @@ -119,7 +127,7 @@ def __init__(self, F, u, bcs=None, J=None,
self.Jp_eq_J = Jp is None

# Argument checking
check_pde_args(self.F, self.J, self.Jp)
check_pde_args(self.F, self.J, self.Jp, E=self.E)

# Store form compiler parameters
self.form_compiler_parameters = form_compiler_parameters
Expand Down Expand Up @@ -304,6 +312,7 @@ def update_diffusivity(current_solution):
self._work = problem.u_restrict.dof_dset.layout_vec.duplicate()
self.snes.setDM(problem.dm)

ctx.set_objective(self.snes)
ctx.set_function(self.snes)
ctx.set_jacobian(self.snes)
ctx.set_nullspace(nullspace, problem.J.arguments()[0].function_space()._ises,
Expand Down Expand Up @@ -353,12 +362,13 @@ def solve(self, bounds=None):
``vinewtonssls`` or ``vinewtonrsls``.
"""
# Make sure the DM has this solver's callback functions
self._ctx.set_objective(self.snes)
self._ctx.set_function(self.snes)
self._ctx.set_jacobian(self.snes)

# Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve.
problem = self._problem
forms = (problem.F, problem.J, problem.Jp)
forms = (problem.F, problem.J, problem.Jp, problem.E)
coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None))
solution_dm = self.snes.getDM()
# Grab the unique DMs for this problem
Expand Down
56 changes: 56 additions & 0 deletions tests/firedrake/regression/test_snes_objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from firedrake import *
import pytest


newtontr_params = {
"snes_atol": 1E-8,
"snes_rtol": 1E-8,
"snes_monitor": "::ascii_info_detail",
"snes_type": "newtontr",
"ksp_type": "cg",
"pc_type": "none",
}


fas_newtontr_params = {
"snes_monitor": "::ascii_info_detail",
"snes_max_it": 1,
"snes_type": "fas",
"snes_fas_type": "kaskade",
"fas_levels": newtontr_params,
"fas_coarse": newtontr_params,
}


@pytest.mark.parametrize("refine", (0, 1))
def test_bratu_energy(refine):
base = UnitIntervalMesh(10)
mh = MeshHierarchy(base, refine)
mesh = mh[-1]
V = FunctionSpace(mesh, "CG", 3)

u = Function(V)
v = TestFunction(V)
sol1 = Function(V)
sol2 = Function(V)

lmbda = Constant(2)

E = 0.5 * inner(grad(u), grad(u))*dx + exp(lmbda*u)*dx
F = inner(grad(u), grad(v))*dx + lmbda*inner(exp(lmbda*u), v)*dx
bcs = DirichletBC(V, 0, "on_boundary")

sp = newtontr_params if refine == 0 else fas_newtontr_params
problem = NonlinearVariationalProblem(F, u, bcs, objective=E)
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)
solver.solve()
sol1.assign(u)

u.assign(0)
sp = {"snes_monitor": "::ascii_info_detail"}
problem = NonlinearVariationalProblem(F, u, bcs)
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)
solver.solve()
sol2.assign(u)

assert norm(sol1 - sol2) < 1.e-8
Loading