-
Notifications
You must be signed in to change notification settings - Fork 191
Add support for objective function in SNES #5155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
add29d3
d55a3bb
d41cc7e
bf6f6c7
5835420
a4c3dc3
6654120
0ea74a0
ef82c67
1cb9b75
c3d50f2
b899636
78d169f
dd0e73b
fe5dce6
e49c317
96b63ef
8105f56
f94cb29
66483db
64b9690
13895de
20adc3e
f0767f1
913190a
ba19351
384e096
b98076d
c03ec10
fc6c220
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
|
|
||
| from pyop2 import op2 | ||
| from firedrake import dmhooks | ||
| from firedrake.bcs import bcdofs | ||
| from firedrake.function import Function | ||
| from firedrake.cofunction import Cofunction | ||
| from firedrake.matrix import MatrixBase | ||
|
|
@@ -231,6 +232,7 @@ def __init__(self, problem, | |
| self.pmatfree = pmatfree | ||
| self.F = problem.F | ||
| self.J = problem.J | ||
| self.E = problem.E | ||
|
|
||
| # For Jp to equal J, bc.Jp must equal bc.J for all EquationBC objects. | ||
| Jp_eq_J = problem.Jp is None and all(bc.Jp_eq_J for bc in problem.bcs) | ||
|
|
@@ -261,6 +263,12 @@ def __init__(self, problem, | |
|
|
||
| self.F -= problem.compute_bc_lifting(self.J, self._bc_residual) | ||
|
|
||
| self._assemble_objective = lambda *args, **kwargs: args | ||
| if self.E: | ||
| self._assemble_objective = get_assembler(self.E, | ||
| form_compiler_parameters=self.fcp, | ||
| ).assemble | ||
|
|
||
| self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F, | ||
| form_compiler_parameters=self.fcp, | ||
| zero_bc_nodes=pre_apply_bcs, | ||
|
|
@@ -277,6 +285,25 @@ def __init__(self, problem, | |
| self._coefficient_mapping = None | ||
| self._transfer_manager = transfer_manager | ||
|
|
||
| @cached_property | ||
| def bc_iset(self): | ||
| if self._problem.restrict: | ||
| return None | ||
| bcs = tuple(self._problem.dirichlet_bcs()) | ||
| V = self._x.function_space() | ||
| if len(bcs) > 0: | ||
| bc_nodes = numpy.unique( | ||
| numpy.concatenate([bcdofs(bc, ghost=False) for bc in bcs], | ||
| dtype=PETSc.IntType) | ||
| ) | ||
| bc_nodes = V.dof_dset.lgmap.apply(bc_nodes) | ||
| else: | ||
| bc_nodes = numpy.empty(0, dtype=PETSc.IntType) | ||
|
|
||
| bc_is = PETSc.IS().createGeneral(bc_nodes, comm=V.comm) | ||
| bc_is.sort() | ||
| return bc_is | ||
|
|
||
| def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs): | ||
| """Reconstruct this _SNESContext instance with new arguments.""" | ||
| problem = problem or self._problem | ||
|
|
@@ -342,6 +369,12 @@ def transfer_manager(self, manager): | |
| raise ValueError("Must set transfer manager before first use.") | ||
| self._transfer_manager = manager | ||
|
|
||
| def set_objective(self, snes): | ||
| if self._problem.E: | ||
| snes.setObjective(self.form_objective) | ||
| else: | ||
| snes.setObjective(None) | ||
|
stefanozampini marked this conversation as resolved.
|
||
|
|
||
| def set_function(self, snes): | ||
| r"""Set the residual evaluation function""" | ||
| with self._F.dat.vec_wo as v: | ||
|
|
@@ -360,6 +393,9 @@ def set_nullspace(self, nullspace, ises=None, transpose=False, near=False): | |
| if ises is not None: | ||
| nullspace._apply(ises, transpose=transpose, near=near) | ||
|
|
||
| def set_ksp_postsolve(self, snes): | ||
| snes.getKSP().setPostSolve(self.ksp_postsolve) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this overwrite an existing user post-solve callback? Or will they both be called in that case?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a different one. The existing ones are SNES related (pre/post Jacobian/Function)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not the ones passed to the nlvs, the ones on the actual ksp. A user might have set one, or a particular ksp type might define one itself
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a temporary fix to firedrake's model of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is set as the default when firedrake creates the SNES. Users can do whatever they want afterward. Add their own, or change the KSP. And no, KSPs don't have their own since each KSP implements its own Solve; these are user-specific callbacks As I said, this is temporary and can be removed once we understand what is going wrong with
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it should.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What do you mean. With restrict you do not have BC dofs at all in the system, so it should not cause problems either to KSP or SNES, since you will just solve the proper linear system only approximatively
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For option 2, we would change the default to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The SNES issue is unavoidable and inherent to the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened #5177 to continue with this discussion |
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def split(self, fields): | ||
| from firedrake import replace, as_vector, split, zero | ||
|
|
@@ -449,6 +485,27 @@ def split(self, fields): | |
| splits.append(self.reconstruct(new_problem, options_prefix=options_prefix)) | ||
| return self._splits.setdefault(tuple(fields), splits) | ||
|
|
||
| @staticmethod | ||
| def form_objective(snes, X): | ||
| r"""Form the objective for this problem | ||
|
|
||
| :arg snes: a PETSc SNES object | ||
| :arg X: the current guess (a Vec) | ||
| """ | ||
| dm = snes.getDM() | ||
| ctx = dmhooks.get_appctx(dm) | ||
| # X may not be the same vector as the vec behind self._x, so | ||
| # copy guess in from X. | ||
| with ctx._x.dat.vec_wo as v: | ||
| X.copy(v) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this safe?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know, I copied that from the residual evaluation routine. At the end of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can I think you just need to save the state in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In principle it should not happen, but I agree with you better be safe. Make a suggestion in the diff and I will merge it in |
||
|
|
||
| # Apply DirichletBC on the solution | ||
| if not ctx._problem.restrict: | ||
| for bc in ctx._problem.dirichlet_bcs(): | ||
| bc.apply(ctx._x) | ||
|
|
||
| return ctx._assemble_objective() | ||
|
|
||
| @staticmethod | ||
| def form_function(snes, X, F): | ||
| r"""Form the residual for this problem | ||
|
|
@@ -547,6 +604,13 @@ def compute_operators(ksp, J, P): | |
| assert P.handle == ctx._pjac.petscmat.handle | ||
| ctx._assemble_pjac(ctx._pjac) | ||
|
|
||
| @staticmethod | ||
| def ksp_postsolve(ksp, rhs, sol): | ||
| dm = ksp.getDM() | ||
| ctx = dmhooks.get_appctx(dm) | ||
| if ctx.pre_apply_bcs and not ctx._problem.restrict: | ||
| sol.isset(ctx.bc_iset, 0) | ||
|
|
||
| @cached_property | ||
| def _assembler_jac(self): | ||
| from firedrake.assemble import get_assembler | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mhad been used before as the objective, however we were completely ignoring it