From 3afee985dc6f834f44a6228c576fa05f3fd38a4f Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Tue, 18 Nov 2025 16:37:30 -0700 Subject: [PATCH 1/3] fix logic error when zeros_outside was True --- discretize/_extensions/tree_ext.pyx | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/discretize/_extensions/tree_ext.pyx b/discretize/_extensions/tree_ext.pyx index dd1f75b1a..887db1e67 100644 --- a/discretize/_extensions/tree_ext.pyx +++ b/discretize/_extensions/tree_ext.pyx @@ -5813,26 +5813,33 @@ cdef class _TreeMesh: cell = self.tree.containing_cell(x, y, z) row_inds = indices[indptr[i]:indptr[i+1]] row_data = data[indptr[i]:indptr[i+1]] + was_outside = False if zeros_out: if x < cell.points[0].location[0]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif x > cell.points[3].location[0]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y < cell.points[0].location[1]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y > cell.points[3].location[1]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z < cell.points[0].location[2]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z > cell.points[7].location[2]+eps: row_data[:] = 0.0 row_inds[:] = 0 - else: + was_outside = True + if not was_outside: # look + dir and - dir away if ( locations[i, dir] < cell.location[dir] @@ -5964,26 +5971,33 @@ cdef class _TreeMesh: cell = self.tree.containing_cell(x, y, z) row_inds = indices[indptr[i]:indptr[i+1]] row_data = data[indptr[i]:indptr[i+1]] + was_outside = False if zeros_out: if x < cell.points[0].location[0]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif x > cell.points[3].location[0]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y < cell.points[0].location[1]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y > cell.points[3].location[1]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z < cell.points[0].location[2]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z > cell.points[7].location[2]+eps: row_data[:] = 0.0 row_inds[:] = 0 - else: + was_outside = True + if not was_outside: # Find containing cells # Decide order to search based on which face it is closest to if dim == 3: @@ -6219,26 +6233,33 @@ cdef class _TreeMesh: cell = self.tree.containing_cell(x, y, z) row_inds = indices[indptr[i]:indptr[i + 1]] row_data = data[indptr[i]:indptr[i + 1]] + was_outside = False if zeros_out: if x < cell.points[0].location[0]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif x > cell.points[3].location[0]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y < cell.points[0].location[1]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif y > cell.points[3].location[1]+eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z < cell.points[0].location[2]-eps: row_data[:] = 0.0 row_inds[:] = 0 + was_outside = True elif dim == 3 and z > cell.points[7].location[2]+eps: row_data[:] = 0.0 row_inds[:] = 0 - else: + was_outside = True + if not was_outside: # decide order to search based on distance to each faces # if ( From 503af5e4a1d515b562ce655461e41944bb8af888 Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Tue, 2 Dec 2025 10:32:24 -0700 Subject: [PATCH 2/3] refactor interpolation tests --- tests/tree/test_tree.py | 104 +------- tests/tree/test_tree_interpolation.py | 370 +++++++++++++------------- 2 files changed, 192 insertions(+), 282 deletions(-) diff --git a/tests/tree/test_tree.py b/tests/tree/test_tree.py index 2a27fb19d..d9d50556f 100644 --- a/tests/tree/test_tree.py +++ b/tests/tree/test_tree.py @@ -438,99 +438,6 @@ def test_cell_bounds(self, mesh): np.testing.assert_equal(cell_bounds, cell_bounds_slow) -class Test2DInterpolation(unittest.TestCase): - def setUp(self): - def topo(x): - return np.sin(x * (2.0 * np.pi)) * 0.3 + 0.5 - - def function(cell): - r = cell.center - np.array([0.5] * len(cell.center)) - dist1 = np.sqrt(r.dot(r)) - 0.08 - dist2 = np.abs(cell.center[-1] - topo(cell.center[0])) - - dist = min([dist1, dist2]) - # if dist < 0.05: - # return 5 - if dist < 0.05: - return 6 - if dist < 0.2: - return 5 - if dist < 0.3: - return 4 - if dist < 1.0: - return 3 - else: - return 0 - - M = discretize.TreeMesh([64, 64], levels=6) - M.refine(function) - self.M = M - - def test_fx(self): - r = rng.random(self.M.nFx) - P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx") - self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL) - - def test_fy(self): - r = rng.random(self.M.nFy) - P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy") - self.assertLess(np.abs(P[:, self.M.nFx :] * r - r).max(), TOL) - - -class Test3DInterpolation(unittest.TestCase): - def setUp(self): - def function(cell): - r = cell.center - np.array([0.5] * len(cell.center)) - dist = np.sqrt(r.dot(r)) - if dist < 0.2: - return 4 - if dist < 0.3: - return 3 - if dist < 1.0: - return 2 - else: - return 0 - - M = discretize.TreeMesh([16, 16, 16], levels=4) - M.refine(function) - # M.plot_grid(show_it=True) - self.M = M - - def test_Fx(self): - r = rng.random(self.M.nFx) - P = self.M.get_interpolation_matrix(self.M.gridFx, "Fx") - self.assertLess(np.abs(P[:, : self.M.nFx] * r - r).max(), TOL) - - def test_Fy(self): - r = rng.random(self.M.nFy) - P = self.M.get_interpolation_matrix(self.M.gridFy, "Fy") - self.assertLess( - np.abs(P[:, self.M.nFx : (self.M.nFx + self.M.nFy)] * r - r).max(), TOL - ) - - def test_Fz(self): - r = rng.random(self.M.nFz) - P = self.M.get_interpolation_matrix(self.M.gridFz, "Fz") - self.assertLess(np.abs(P[:, (self.M.nFx + self.M.nFy) :] * r - r).max(), TOL) - - def test_Ex(self): - r = rng.random(self.M.nEx) - P = self.M.get_interpolation_matrix(self.M.gridEx, "Ex") - self.assertLess(np.abs(P[:, : self.M.nEx] * r - r).max(), TOL) - - def test_Ey(self): - r = rng.random(self.M.nEy) - P = self.M.get_interpolation_matrix(self.M.gridEy, "Ey") - self.assertLess( - np.abs(P[:, self.M.nEx : (self.M.nEx + self.M.nEy)] * r - r).max(), TOL - ) - - def test_Ez(self): - r = rng.random(self.M.nEz) - P = self.M.get_interpolation_matrix(self.M.gridEz, "Ez") - self.assertLess(np.abs(P[:, (self.M.nEx + self.M.nEy) :] * r - r).max(), TOL) - - class TestWrapAroundLevels(unittest.TestCase): def test_refine_func(self): mesh1 = discretize.TreeMesh((16, 16, 16)) @@ -649,5 +556,16 @@ def test_repr_html(self, mesh, finalize): assert len(output) != 0 +@pytest.mark.parametrize("attr", ["average_edge_to_face"]) +def test_caching(attr): + mesh = discretize.TreeMesh([4, 4, 4]) + mesh.refine(-1) + + attr1 = getattr(mesh, attr) + attr2 = getattr(mesh, attr) + + assert attr1 is attr2 + + if __name__ == "__main__": unittest.main() diff --git a/tests/tree/test_tree_interpolation.py b/tests/tree/test_tree_interpolation.py index d4688364b..4614cc91f 100644 --- a/tests/tree/test_tree_interpolation.py +++ b/tests/tree/test_tree_interpolation.py @@ -1,7 +1,8 @@ import numpy as np -import unittest import discretize +import pytest + MESHTYPES = ["uniformTree"] # ['randomTree', 'uniformTree'] call2 = lambda fun, xyz: fun(xyz[:, 0], xyz[:, 1]) call3 = lambda fun, xyz: fun(xyz[:, 0], xyz[:, 1], xyz[:, 2]) @@ -36,191 +37,182 @@ MESHTYPES = ["uniformTree", "notatreeTree"] -class TestInterpolation2d(discretize.tests.OrderTest): - """Face interpolation is O(h) - Edge interpolation is O(h^2) - """ - - name = "Interpolation 2D" - # location_type = 'Ex' - X, Y = np.mgrid[0:1:250j, 0:1:250j] - LOCS = np.c_[X.reshape(-1), Y.reshape(-1)] - # LOCS = np.c_[np.ones(100)*0.51, np.linspace(0.3, 0.7, 100)] - meshTypes = MESHTYPES - # tolerance = TOLERANCES - meshDimension = 2 - meshSizes = [8, 16, 32] - expectedOrders = 1 - - def getError(self): - funX = lambda x, y: np.cos(2.0 * np.pi * y) * np.cos(2.0 * np.pi * x) + x - funY = lambda x, y: np.cos(2.0 * np.pi * x) * np.cos(2.0 * np.pi * y) + y - - # self.LOCS = self.M.gridCC - - if "x" in self.type: - ana = call2(funX, self.LOCS) - elif "y" in self.type: - ana = call2(funY, self.LOCS) - else: - ana = call2(funX, self.LOCS) - - if "F" in self.type: - Fc = cartF2(self.M, funX, funY) - grid = self.M.project_face_vector(Fc) - elif "E" in self.type: - Ec = cartE2(self.M, funX, funY) - grid = self.M.project_edge_vector(Ec) - elif "CC" == self.type: - grid = call2(funX, self.M.gridCC) - elif "N" == self.type: - grid = call2(funX, self.M.gridN) - - comp = self.M.get_interpolation_matrix(self.LOCS, self.type) * grid - - err = np.linalg.norm((comp - ana), np.inf) - if plotIt: - import matplotlib.pyplot as plt - - ax = plt.subplot(211) - self.M.plot_grid(ax=ax) - plt.plot(self.LOCS[:, 0], self.LOCS[:, 1], "mx") - # ax = plt.subplot(111) - # self.M.plot_image(call2(funX, self.M.gridCC), ax=ax) - ax = plt.subplot(212) - plt.plot(self.LOCS[:, 1], comp, "bx") - plt.plot(self.LOCS[:, 1], ana, "ro") - plt.show() - return err - - def test_orderCC(self): - self.type = "CC" - self.name = "Interpolation 2D: CC" - self.orderTest() - - def test_orderN(self): - self.type = "N" - self.name = "Interpolation 2D: N" - self.expectedOrders = 2 - self.orderTest() - self.expectedOrders = 1 - - def test_orderFx(self): - self.type = "Fx" - self.name = "TreeMesh Interpolation 2D: Fx" - self.orderTest() - - def test_orderFy(self): - self.type = "Fy" - self.name = "TreeMesh Interpolation 2D: Fy" - self.orderTest() - - def test_orderEx(self): - self.type = "Ex" - self.name = "TreeMesh Interpolation 2D: Ex" - self.orderTest() - - def test_orderEy(self): - self.type = "Ey" - self.name = "TreeMesh Interpolation 2D: Ey" - self.orderTest() - - -class TestInterpolation3D(discretize.tests.OrderTest): - name = "Interpolation" - X, Y, Z = np.mgrid[0:1:50j, 0:1:50j, 0:1:50j] - LOCS = np.c_[X.reshape(-1), Y.reshape(-1), Z.reshape(-1)] - meshTypes = MESHTYPES - # tolerance = TOLERANCES - meshDimension = 3 - meshSizes = [8, 16] - - def getError(self): - funX = lambda x, y, z: np.cos(2 * np.pi * y) - funY = lambda x, y, z: np.cos(2 * np.pi * z) - funZ = lambda x, y, z: np.cos(2 * np.pi * x) - - if "x" in self.type: - ana = call3(funX, self.LOCS) - elif "y" in self.type: - ana = call3(funY, self.LOCS) - elif "z" in self.type: - ana = call3(funZ, self.LOCS) - else: - ana = call3(funX, self.LOCS) - - if "F" in self.type: - Fc = cartF3(self.M, funX, funY, funZ) - grid = self.M.project_face_vector(Fc) - elif "E" in self.type: - Ec = cartE3(self.M, funX, funY, funZ) - grid = self.M.project_edge_vector(Ec) - elif "CC" == self.type: - grid = call3(funX, self.M.gridCC) - elif "N" == self.type: - grid = call3(funX, self.M.gridN) - - A = self.M.get_interpolation_matrix(self.LOCS, self.type) - comp = A * grid - - err = np.linalg.norm((comp - ana), np.inf) - return err - - def test_orderCC(self): - self.type = "CC" - self.expectedOrders = 1 - self.name = "Interpolation 3D: CC" - self.orderTest() - self.expectedOrders = 2 - - def test_orderN(self): - self.type = "N" - self.name = "Interpolation 3D: N" - self.orderTest() - - def test_orderFx(self): - self.type = "Fx" - self.name = "Interpolation 3D: Fx" - self.expectedOrders = 1 - self.orderTest() - self.expectedOrders = 2 - - def test_orderFy(self): - self.type = "Fy" - self.name = "Interpolation 3D: Fy" - self.expectedOrders = 1 - self.orderTest() - self.expectedOrders = 2 - - def test_orderFz(self): - self.type = "Fz" - self.name = "Interpolation 3D: Fz" - self.expectedOrders = 1 - self.orderTest() - self.expectedOrders = 2 - - def test_orderEx(self): - self.type = "Ex" - self.name = "Interpolation 3D: Ex" - self.orderTest() - - def test_orderEy(self): - self.type = "Ey" - self.name = "Interpolation 3D: Ey" - self.orderTest() - - def test_orderEz(self): - self.type = "Ez" - self.name = "Interpolation 3D: Ez" - self.orderTest() - - -class TestCaching(unittest.TestCase): - def setUp(self): - self.mesh, maxh = discretize.tests.setup_mesh("uniformTree", 32, 3) - - def testCaching(self): - mesh = self.mesh - A1 = mesh.average_edge_to_face - A2 = mesh.average_edge_to_face - self.assertIs(A1, A2) +@pytest.mark.parametrize("tree_type", ["uniformTree", "notatreeTree"]) +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("zeros_outside", [True, False]) +@pytest.mark.parametrize( + "mesh_locs", + [ + "cell_centers", + "nodes", + "edges_x", + "edges_y", + "edges_z", + "faces_x", + "faces_y", + "faces_z", + ], +) +def test_order(tree_type, dim, mesh_locs, zeros_outside): + if dim == 2 and "z" in mesh_locs: + pytest.skip() + + locs = ( + np.mgrid[ + *[ + slice(0.25, 0.75, 50j), + ] + * dim + ] + .reshape(dim, -1) + .transpose() + ) + + if "notatree" in tree_type: + expected_order = 2 + elif mesh_locs == "nodes": + expected_order = 2 + else: + expected_order = 1 + + def ana_func(locs): + return locs**2 * [2, -3, 4][:dim] + locs * [-4, 3, 2][:dim] + [2, 3, 4][:dim] + + ana_vals = ana_func(locs) + + if "faces" in mesh_locs: + source_attr = "faces" + elif "edges" in mesh_locs: + source_attr = "edges" + else: + source_attr = mesh_locs + + def order_func(n): + mesh, h = discretize.tests.setup_mesh(tree_type, n, dim) + interp_mat = mesh.get_interpolation_matrix( + locs, mesh_locs, zeros_outside=zeros_outside + ) + grid_vals = ana_func(getattr(mesh, source_attr)) + interp_vals = interp_mat @ grid_vals + + return np.linalg.norm(interp_vals - ana_vals), h + + discretize.tests.assert_expected_order( + order_func, + [8, 16, 32], + expected_order=expected_order, + test_type="mean_at_least", + ) + + +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize( + "mesh_locs", + [ + "cell_centers", + "nodes", + "edges_x", + "edges_y", + "edges_z", + "faces_x", + "faces_y", + "faces_z", + ], +) +def test_zeros_outside(dim, mesh_locs, zeros_outside): + if dim == 2 and "z" in mesh_locs: + pytest.skip() + + locs = ( + np.mgrid[ + *[ + slice(-1, 2, 3j), + ] + * dim + ] + .reshape(dim, -1) + .transpose() + ) + mesh = discretize.TreeMesh([16, 16, 16][:dim]) + mesh.refine(-1) + + is_outside = np.any((locs < 0) | (locs > 1), axis=1) + locs = locs[is_outside] + + interp_mat = mesh.get_interpolation_matrix(locs, mesh_locs, zeros_outside=True) + + if "faces" in mesh_locs: + n = mesh.n_faces + elif "edges" in mesh_locs: + n = mesh.n_edges + elif "nodes" in mesh_locs: + n = mesh.n_nodes + else: + n = mesh.n_cells + + vs = interp_mat @ np.ones(n) + + np.testing.assert_equal(vs, 0) + + +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize( + "mesh_locs", + [ + "cell_centers", + "nodes", + "edges_x", + "edges_y", + "edges_z", + "faces_x", + "faces_y", + "faces_z", + ], +) +def test_project_outside(dim, mesh_locs): + if dim == 2 and "z" in mesh_locs: + pytest.skip() + + locs = ( + np.mgrid[ + *[ + slice(-1, 2, 3j), + ] + * dim + ] + .reshape(dim, -1) + .transpose() + ) + mesh = discretize.TreeMesh([16, 16, 16][:dim], diagonal_balance=True) + mesh.refine(-1) + + is_outside = np.any((locs < 0) | (locs > 1), axis=1) + locs = locs[is_outside] + + grid_locs = getattr(mesh, mesh_locs) + source_bounds = [ + grid_locs.min(axis=0), + grid_locs.max(axis=0), + ] + interp_mat = mesh.get_interpolation_matrix(locs, mesh_locs, zeros_outside=False) + + def ana_func(locs): + locs = np.clip(locs, a_min=source_bounds[0], a_max=source_bounds[1]) + return locs * [-4, 3, 2][:dim] + [2, 3, 4][:dim] + + ana_vals = ana_func(locs) + + # get the full list of locations associate with mesh_locs + if "faces" in mesh_locs: + source_attr = "faces" + elif "edges" in mesh_locs: + source_attr = "edges" + else: + source_attr = mesh_locs + + source_locs = getattr(mesh, source_attr) + grid_vals = ana_func(source_locs) + + vs = interp_mat @ grid_vals + + np.testing.assert_equal(vs, ana_vals) From 0253e0a645c1ad04c32eabacd6b47dd31f2c5349 Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Tue, 2 Dec 2025 12:35:44 -0700 Subject: [PATCH 3/3] remove unused argument in test --- tests/tree/test_tree_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tree/test_tree_interpolation.py b/tests/tree/test_tree_interpolation.py index 4614cc91f..c4631d956 100644 --- a/tests/tree/test_tree_interpolation.py +++ b/tests/tree/test_tree_interpolation.py @@ -119,7 +119,7 @@ def order_func(n): "faces_z", ], ) -def test_zeros_outside(dim, mesh_locs, zeros_outside): +def test_zeros_outside(dim, mesh_locs): if dim == 2 and "z" in mesh_locs: pytest.skip()