-
Notifications
You must be signed in to change notification settings - Fork 191
Add support for objective function in SNES #5155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 18 commits
add29d3
d55a3bb
d41cc7e
bf6f6c7
5835420
a4c3dc3
6654120
0ea74a0
ef82c67
1cb9b75
c3d50f2
b899636
78d169f
dd0e73b
fe5dce6
e49c317
96b63ef
8105f56
f94cb29
66483db
64b9690
13895de
20adc3e
f0767f1
913190a
ba19351
384e096
b98076d
c03ec10
fc6c220
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -149,7 +149,7 @@ def _solve_varproblem(*args, **kwargs): | |
| "Solve variational problem a == L or F == 0" | ||
|
|
||
| # Extract arguments | ||
| eq, u, bcs, J, Jp, M, form_compiler_parameters, \ | ||
| eq, u, bcs, J, Jp, objective, form_compiler_parameters, \ | ||
| solver_parameters, nullspace, nullspace_T, \ | ||
| near_nullspace, \ | ||
| options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs) | ||
|
|
@@ -167,6 +167,8 @@ def _solve_varproblem(*args, **kwargs): | |
| raise TypeError(f"Equation LHS must be a ufl.BaseForm, not a {type(eq.lhs).__name__}") | ||
|
|
||
| if len(eq.lhs.arguments()) == 2: | ||
| if objective is not None: | ||
| raise ValueError("The objective functional only makes sense for nonlinear problems.") | ||
| # Create linear variational problem | ||
| problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp, | ||
| form_compiler_parameters=form_compiler_parameters, | ||
|
|
@@ -177,6 +179,7 @@ def _solve_varproblem(*args, **kwargs): | |
| if eq.rhs != 0: | ||
| raise ValueError(f"RHS of nonlinear Equation must be `0`, not {eq.rhs}") | ||
| problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp, | ||
| objective=objective, | ||
| form_compiler_parameters=form_compiler_parameters, | ||
| restrict=restrict) | ||
| create_solver = vs.NonlinearVariationalSolver | ||
|
|
@@ -279,7 +282,7 @@ def _extract_args(*args, **kwargs): | |
| "Extraction of arguments for _solve_varproblem" | ||
|
|
||
| # Check for use of valid kwargs | ||
| valid_kwargs = ["bcs", "J", "Jp", "M", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| valid_kwargs = ["bcs", "J", "Jp", "objective", | ||
| "form_compiler_parameters", "solver_parameters", | ||
| "nullspace", "transpose_nullspace", "near_nullspace", | ||
| "options_prefix", "appctx", "restrict", "pre_apply_bcs"] | ||
|
|
@@ -314,9 +317,7 @@ def _extract_args(*args, **kwargs): | |
| Jp = kwargs.get("Jp", None) | ||
|
|
||
| # Extract functional | ||
| M = kwargs.get("M", None) | ||
| if M is not None and not isinstance(M, ufl.Form): | ||
| raise RuntimeError("Expecting goal functional M to be a UFL Form") | ||
| objective = kwargs.get("objective", None) | ||
|
|
||
| nullspace = kwargs.get("nullspace", None) | ||
| nullspace_T = kwargs.get("transpose_nullspace", None) | ||
|
|
@@ -328,7 +329,7 @@ def _extract_args(*args, **kwargs): | |
| restrict = kwargs.get("restrict", False) | ||
| pre_apply_bcs = kwargs.get("pre_apply_bcs", True) | ||
|
|
||
| return eq, u, bcs, J, Jp, M, form_compiler_parameters, \ | ||
| return eq, u, bcs, J, Jp, objective, form_compiler_parameters, \ | ||
| solver_parameters, nullspace, nullspace_T, near_nullspace, \ | ||
| options_prefix, restrict, pre_apply_bcs | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |||||||||||||||
|
|
||||||||||||||||
| from pyop2 import op2 | ||||||||||||||||
| from firedrake import dmhooks | ||||||||||||||||
| from firedrake.bcs import bcdofs | ||||||||||||||||
| from firedrake.function import Function | ||||||||||||||||
| from firedrake.cofunction import Cofunction | ||||||||||||||||
| from firedrake.matrix import MatrixBase | ||||||||||||||||
|
|
@@ -231,6 +232,7 @@ def __init__(self, problem, | |||||||||||||||
| self.pmatfree = pmatfree | ||||||||||||||||
| self.F = problem.F | ||||||||||||||||
| self.J = problem.J | ||||||||||||||||
| self.E = problem.E | ||||||||||||||||
|
|
||||||||||||||||
| # For Jp to equal J, bc.Jp must equal bc.J for all EquationBC objects. | ||||||||||||||||
| Jp_eq_J = problem.Jp is None and all(bc.Jp_eq_J for bc in problem.bcs) | ||||||||||||||||
|
|
@@ -261,6 +263,12 @@ def __init__(self, problem, | |||||||||||||||
|
|
||||||||||||||||
| self.F -= problem.compute_bc_lifting(self.J, self._bc_residual) | ||||||||||||||||
|
|
||||||||||||||||
| self._assemble_objective = lambda *args, **kwargs: args | ||||||||||||||||
| if self.E: | ||||||||||||||||
| self._assemble_objective = get_assembler(self.E, | ||||||||||||||||
| form_compiler_parameters=self.fcp, | ||||||||||||||||
| ).assemble | ||||||||||||||||
|
|
||||||||||||||||
| self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F, | ||||||||||||||||
| form_compiler_parameters=self.fcp, | ||||||||||||||||
| zero_bc_nodes=pre_apply_bcs, | ||||||||||||||||
|
|
@@ -277,6 +285,18 @@ def __init__(self, problem, | |||||||||||||||
| self._coefficient_mapping = None | ||||||||||||||||
| self._transfer_manager = transfer_manager | ||||||||||||||||
|
|
||||||||||||||||
| @cached_property | ||||||||||||||||
| def bc_iset(self): | ||||||||||||||||
| if self._problem.restrict: | ||||||||||||||||
| return None | ||||||||||||||||
| bcs = self._problem.dirichlet_bcs() | ||||||||||||||||
| V = self._x.function_space() | ||||||||||||||||
| bc_nodes = numpy.unique(numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs], dtype=PETSc.IntType)) | ||||||||||||||||
| bc_nodes = V.dof_dset.lgmap.apply(bc_nodes) | ||||||||||||||||
| bc_is = PETSc.IS().createGeneral(bc_nodes, comm=V.comm) | ||||||||||||||||
| bc_is.sort() | ||||||||||||||||
| return bc_is | ||||||||||||||||
|
|
||||||||||||||||
| def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs): | ||||||||||||||||
| """Reconstruct this _SNESContext instance with new arguments.""" | ||||||||||||||||
| problem = problem or self._problem | ||||||||||||||||
|
|
@@ -342,6 +362,12 @@ def transfer_manager(self, manager): | |||||||||||||||
| raise ValueError("Must set transfer manager before first use.") | ||||||||||||||||
| self._transfer_manager = manager | ||||||||||||||||
|
|
||||||||||||||||
| def set_objective(self, snes): | ||||||||||||||||
| if self._problem.E: | ||||||||||||||||
| snes.setObjective(self.form_objective) | ||||||||||||||||
| else: | ||||||||||||||||
| snes.setObjective(None) | ||||||||||||||||
|
stefanozampini marked this conversation as resolved.
|
||||||||||||||||
|
|
||||||||||||||||
| def set_function(self, snes): | ||||||||||||||||
| r"""Set the residual evaluation function""" | ||||||||||||||||
| with self._F.dat.vec_wo as v: | ||||||||||||||||
|
|
@@ -360,6 +386,9 @@ def set_nullspace(self, nullspace, ises=None, transpose=False, near=False): | |||||||||||||||
| if ises is not None: | ||||||||||||||||
| nullspace._apply(ises, transpose=transpose, near=near) | ||||||||||||||||
|
|
||||||||||||||||
| def set_ksp_postsolve(self, snes): | ||||||||||||||||
| snes.getKSP().setPostSolve(self.ksp_postsolve) | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this overwrite an existing user post-solve callback? Or will they both be called in that case?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a different one. The existing ones are SNES related (pre/post Jacobian/Function)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not the ones passed to the nlvs, the ones on the actual ksp. A user might have set one, or a particular ksp type might define one itself
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a temporary fix to firedrake's model of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is set as the default when firedrake creates the SNES. Users can do whatever they want afterward. Add their own, or change the KSP. And no, KSPs don't have their own since each KSP implements its own Solve; these are user-specific callbacks As I said, this is temporary and can be removed once we understand what is going wrong with
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it should.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What do you mean. With restrict you do not have BC dofs at all in the system, so it should not cause problems either to KSP or SNES, since you will just solve the proper linear system only approximatively
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For option 2, we would change the default to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The SNES issue is unavoidable and inherent to the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened #5177 to continue with this discussion |
||||||||||||||||
|
|
||||||||||||||||
| @PETSc.Log.EventDecorator() | ||||||||||||||||
| def split(self, fields): | ||||||||||||||||
| from firedrake import replace, as_vector, split, zero | ||||||||||||||||
|
|
@@ -449,6 +478,26 @@ def split(self, fields): | |||||||||||||||
| splits.append(self.reconstruct(new_problem, options_prefix=options_prefix)) | ||||||||||||||||
| return self._splits.setdefault(tuple(fields), splits) | ||||||||||||||||
|
|
||||||||||||||||
| @staticmethod | ||||||||||||||||
| def form_objective(snes, X): | ||||||||||||||||
| r"""Form the objective for this problem | ||||||||||||||||
|
|
||||||||||||||||
| :arg snes: a PETSc SNES object | ||||||||||||||||
| :arg X: the current guess (a Vec) | ||||||||||||||||
| """ | ||||||||||||||||
| dm = snes.getDM() | ||||||||||||||||
| ctx = dmhooks.get_appctx(dm) | ||||||||||||||||
| # X may not be the same vector as the vec behind self._x, so | ||||||||||||||||
| # copy guess in from X. | ||||||||||||||||
| with ctx._x.dat.vec_wo as v: | ||||||||||||||||
| X.copy(v) | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this safe?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know, I copied that from the residual evaluation routine. At the end of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can I think you just need to save the state in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In principle it should not happen, but I agree with you better be safe. Make a suggestion in the diff and I will merge it in |
||||||||||||||||
|
|
||||||||||||||||
| # Apply DirichletBC on the solution | ||||||||||||||||
| for bc in ctx._problem.dirichlet_bcs(): | ||||||||||||||||
| bc.apply(ctx._x) | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think what you are proposing is correct, but I also don't think what I have in the current PR is 100% correct. In all my experience with nonlinear solvers, you always inject the correct BCs whenever you evaluate objective functions or residuals. Then, in case of residuals, you modify afterward the part of the vector that is associated with bc, because there you solve a different set of equations, i.e. As a byproduct, if you do Newton, the boundary conditions will be exact after the first full step (that may not be the first Newton step if you use linesearch). See e.g. how I do it in MFEM https://github.com/mfem/mfem/blob/master/linalg/petsc.cpp#L4996 If ctx._x is supposed to be immutable during each phases of a given SNES step, then you need to rethink the firedrake logic a bit |
||||||||||||||||
|
|
||||||||||||||||
| return ctx._assemble_objective() | ||||||||||||||||
|
|
||||||||||||||||
| @staticmethod | ||||||||||||||||
| def form_function(snes, X, F): | ||||||||||||||||
| r"""Form the residual for this problem | ||||||||||||||||
|
|
@@ -547,6 +596,13 @@ def compute_operators(ksp, J, P): | |||||||||||||||
| assert P.handle == ctx._pjac.petscmat.handle | ||||||||||||||||
| ctx._assemble_pjac(ctx._pjac) | ||||||||||||||||
|
|
||||||||||||||||
| @staticmethod | ||||||||||||||||
| def ksp_postsolve(ksp, rhs, sol): | ||||||||||||||||
| dm = ksp.getDM() | ||||||||||||||||
| ctx = dmhooks.get_appctx(dm) | ||||||||||||||||
| if ctx.pre_apply_bcs and not ctx._problem.restrict: | ||||||||||||||||
| sol.isset(ctx.bc_iset, 0) | ||||||||||||||||
|
|
||||||||||||||||
| @cached_property | ||||||||||||||||
| def _assembler_jac(self): | ||||||||||||||||
| from firedrake.assemble import get_assembler | ||||||||||||||||
|
|
||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.