diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 8a25daf2..3fc78b05 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -1,4 +1,6 @@ +from abc import abstractmethod from enum import Enum +from functools import cached_property from numbers import Complex import numpy as np @@ -170,18 +172,24 @@ class RFOperation(Enum): class ReducedFunctionalMatBase: """ - PETSc.Mat Python context to apply the action of a pyadjoint.ReducedFunctional. + Base class for PETSc.Mat Python contexts for applying the action of a ReducedFunctional. - If V is the control space and U is the functional space, each action has the following map: - Jhat : V -> U - TLM : V -> U - Adjoint : U* -> V* - Hessian : V x U* -> V* | V -> V* + If V is the control space and U is the functional space, then the ReducedFunctional + Jhat and its methods map between the following spaces: + * Jhat : V -> U + * TLM : V -> U + * Adjoint : U* -> V* + * Hessian : V x U* -> V* | V -> V* + Child classes implement the matrix action for a particular method of Jhat. + + For the matrix action Ax=y the input x and output y will live in either + V, U, V*, or U* (e.g. for the tlm action x is in V and y is in U). - Child classes must implement: - - mult_impl - - multHermitian_impl + Child classes must implement the following (see docstrings for details): + - mult_impl, multHermitian_impl - update_adjoint + - x, y + - xinterface, yinterface Args: rf (ReducedFunctional): Defines the forward model. Used to compute Mat actions. @@ -189,7 +197,7 @@ class ReducedFunctionalMatBase: result of the action to PETSc. appctx (Optional[dict]): User provided context. always_update_tape (bool): Whether to force reevaluation of the forward model - every time `mult` is called. If needs_adjoint_update then this will also force + every time `mult` is called. 
If update_adjoint is True then this will also force the adjoint model to be reevaluated at every call to `mult`. needs_functional_interface: Whether to create a PETScVecInterface for rf.functional. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. @@ -217,6 +225,44 @@ def __init__(self, rf, *, self._shift = 0 self.always_update_tape = always_update_tape + @property + @abstractmethod + def x(self): + """An instance of (OverloadedType | list[OverloadedType]) that is suitable to + be the input to the matrix action. + + e.g. the tlm action is (V -> U) so x would be in V, the control space. + """ + pass + + @property + @abstractmethod + def y(self): + """An instance of (OverloadedType | list[OverloadedType]) that is suitable to + be the output of the matrix action. + + e.g. the tlm action is (V -> U) so y would be in U, the functional space. + """ + pass + + @property + @abstractmethod + def xinterface(self): + """A PETScVecInterface for x to transfer data to/from PETSc Vecs. + + Should be either self.control_interface or self.functional_interface. + """ + pass + + @property + @abstractmethod + def yinterface(self): + """A PETScVecInterface for y to transfer data to/from PETSc Vecs. + + Should be either self.control_interface or self.functional_interface. 
+ """ + pass + @classmethod def update(cls, obj, x, A, P): ctx = A.getPythonContext() @@ -240,13 +286,13 @@ def mult(self, A, x, y): if self._shift != 0: y.axpy(self._shift, x) - def multHermitian(self, A, x, y): - self.yinterface.from_petsc(x, self.x) - out = self.multHermitian_impl(A, self.x) - self.xinterface.to_petsc(y, out) + def multHermitian(self, AT, y, x): + self.yinterface.from_petsc(y, self.y) + out = self.multHermitian_impl(AT, self.y) + self.xinterface.to_petsc(x, out) if self._shift != 0: - y.axpy(self._shift, x) + x.axpy(self._shift, y) def mult_impl(self, A, x): """ @@ -258,6 +304,9 @@ def mult_impl(self, A, x): A (PETSc.Mat): The Mat that this python context is attached to. x (Union[OverloadedType, list[OverloadedType]]): An element in either the control or functional space of the ReducedFunctional that this Mat will act on. + + Returns: + (Union[OverloadedType, list[OverloadedType]]): The result of the matrix action. """ raise NotImplementedError( "Must provide implementation of the action of this matrix on an OverloadedType") @@ -272,6 +321,9 @@ def multHermitian_impl(self, A, y): A (PETSc.Mat): The Mat that this python context is attached to. y (Union[OverloadedType, list[OverloadedType]]): An element in either the control or functional space of the ReducedFunctional that this Mat will act on. + + Returns: + (Union[OverloadedType, list[OverloadedType]]): The result of the Hermitian matrix action. 
""" raise NotImplementedError( "Must provide implementation of the Hermitian action of this matrix on an OverloadedType") @@ -310,9 +362,21 @@ def __init__(self, rf, *, apply_riesz=False, appctx=None, needs_functional_interface=False, always_update_tape=always_update_tape, comm=comm) - self.xinterface = self.control_interface - self.yinterface = self.control_interface - self.x = new_control_variable(rf) + @cached_property + def x(self): + return new_control_variable(self.rf) + + @cached_property + def y(self): + return new_control_variable(self.rf, dual=not self.apply_riesz) + + @property + def xinterface(self): + return self.control_interface + + @property + def yinterface(self): + return self.control_interface @classmethod def update_adjoint(self): @@ -353,9 +417,21 @@ def __init__(self, rf, *, apply_riesz=False, appctx=None, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) - self.xinterface = self.functional_interface - self.yinterface = self.control_interface - self.x = rf.functional._ad_copy() + @cached_property + def x(self): + return self.rf.functional._ad_init_zero(dual=True) + + @cached_property + def y(self): + return new_control_variable(self.rf, dual=not self.apply_riesz) + + @property + def xinterface(self): + return self.functional_interface + + @property + def yinterface(self): + return self.control_interface @classmethod def update_adjoint(self): @@ -390,12 +466,25 @@ class ReducedFunctionalTLMMat(ReducedFunctionalMatBase): """ def __init__(self, rf, *, appctx=None, always_update_tape=False, comm=None): + super().__init__(rf, appctx=appctx, needs_functional_interface=True, always_update_tape=always_update_tape, comm=comm) - self.xinterface = self.control_interface - self.yinterface = self.functional_interface - self.x = new_control_variable(rf) + @cached_property + def x(self): + return new_control_variable(self.rf) + + @cached_property + def y(self): + return self.rf.functional._ad_init_zero() + + @property + def 
xinterface(self): + return self.control_interface + + @property + def yinterface(self): + return self.functional_interface @classmethod def update_adjoint(self):