diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index f45563ee6f..9f0670f1d6 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -563,10 +563,28 @@ def _interpolate_from_quadrature(self) -> Interpolate: elif self.ufl_interpolate.is_adjoint: return interpolate(TestFunction(self.target_space), self.dual_arg) + def _bc_mask(self, space: WithGeometry, bcs: Iterable[DirichletBC]) -> Function: + """Return a 0/1 mask over `space` which is zero at boundary condition nodes + """ + space = space.dual() if is_dual(space) else space + f = Function(space).assign(1.0) + for bc in bcs: + if bc.function_space() == space: + bc.zero(f) + return f + + def apply_bcs(self, mat: PETSc.Mat, bcs: Iterable[DirichletBC]) -> PETSc.Mat: + """Zero the rows and columns of `mat` associated with boundary condition nodes. + """ + row_arg, col_arg = self.ufl_interpolate.arguments() + row_mask = self._bc_mask(row_arg.function_space(), bcs) + col_mask = self._bc_mask(col_arg.function_space(), bcs) + with row_mask.dat.vec_ro as r, col_mask.dat.vec_ro as c: + mat.diagonalScale(r, c) + return mat + def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None): from firedrake.assemble import assemble - if bcs: - raise NotImplementedError("bcs not implemented for cross-mesh interpolation.") mat_type = mat_type or "aij" if self.into_quadrature_space: @@ -593,12 +611,13 @@ def callable() -> PETSc.Mat: source_space = self.operand.function_space() if self.ufl_interpolate.is_adjoint: I = Matrix(interpolate(TestFunction(source_space), self.target_space), res) - return assemble(action(I, self._interpolate_from_quadrature)).petscmat + res = assemble(action(I, self._interpolate_from_quadrature)).petscmat else: I = Matrix(interpolate(TrialFunction(source_space), self.target_space), res) - return assemble(action(self._interpolate_from_quadrature, I)).petscmat - else: - return res + res = assemble(action(self._interpolate_from_quadrature, I)).petscmat + if bcs: + res = self.apply_bcs(res, bcs) + return res elif self.ufl_interpolate.is_adjoint: assert self.rank == 1 diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index c41f37dbab..9910de296d 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -756,3 +756,38 @@ def test_mixed_interpolator_cross_mesh(): assert isinstance(interp_ij, Interpolate) res_block = assemble(interpolate(TrialFunction(W.sub(j)), U.sub(i), allow_missing_dofs=True)) assert np.allclose(res.petscmat.getNestSubMatrix(i, j)[:, :], res_block.petscmat[:, :]) + + +@pytest.mark.parallel([1, 3]) +@pytest.mark.parametrize("variant", ["source", "target", "both"]) +def test_interpolate_cross_mesh_bcs(variant): + source_mesh = UnitSquareMesh(2, 2) + target_mesh = UnitSquareMesh(3, 3) + U = FunctionSpace(source_mesh, "CG", 1) + V = FunctionSpace(target_mesh, "CG", 1) + + x, y = SpatialCoordinate(source_mesh) + f = Function(U).interpolate(1 + x + 2*y) + + source_bcs = [DirichletBC(U, 0, 1), DirichletBC(U, 0, 3)] + target_bcs = [DirichletBC(V, 0, 2), DirichletBC(V, 0, 4)] + + if variant == "both": + bcs = source_bcs + target_bcs + elif variant == "source": + bcs = source_bcs + elif variant == "target": + bcs = target_bcs + + interp = assemble(interpolate(TrialFunction(U), V), bcs=bcs) + result = assemble(interp @ f) + + if variant in ("source", "both"): + for bc in source_bcs: + bc.zero(f) + expected = assemble(interpolate(f, V)) + if variant in ("target", "both"): + for bc in target_bcs: + bc.zero(expected) + + assert np.allclose(result.dat.data_ro, expected.dat.data_ro)