def get_local_shape(mesh_shape, sharding=None):
    """Return the per-device shape of an array with global shape ``mesh_shape``.

    Every array dimension is floor-divided by the size of the mesh axis (or
    axes) that the sharding spec assigns to it. Dimensions without a spec
    entry, or whose entry names no mesh axis, keep their global size.

    Args:
        mesh_shape: Global array shape, any sequence of ints.
        sharding: Optional sharding exposing ``.mesh`` and ``.spec``. ``None``
            or an empty mesh means the array is not distributed.

    Returns:
        A new list holding the local size of each dimension.
    """
    gpu_mesh = sharding.mesh if sharding is not None else None
    local_shape = list(mesh_shape)
    if gpu_mesh is None or gpu_mesh.empty:
        return local_shape

    # ``mesh.shape`` maps axis name -> size and is available on both Mesh and
    # AbstractMesh, unlike ``mesh.devices``.
    axis_sizes = gpu_mesh.shape

    # zip() stops at the shorter of the two, so a spec that is longer than the
    # array rank (e.g. a 2-axis spec used with a 1D array) is truncated.
    for dim_idx, entry in zip(range(len(local_shape)), sharding.spec):
        if isinstance(entry, str):
            names = (entry, )
        elif isinstance(entry, (tuple, list)):
            # One dimension sharded over several mesh axes at once.
            names = entry
        else:
            # No mesh axis named for this dimension: it is not split.
            continue
        for name in names:
            local_shape[dim_idx] //= axis_sizes[name]
    return local_shape


def get_sharding_for_shape(shape, sharding=None):
    """Trim a sharding spec so it has no more entries than ``shape`` has dims.

    E.g. a NamedSharding with spec ``P('x', 'y')`` becomes ``P('x')`` for a
    1D array such as a HEALPix map.

    Args:
        shape: Shape of the array the sharding will be applied to.
        sharding: Optional sharding exposing ``.mesh`` and ``.spec``.

    Returns:
        ``None`` if ``sharding`` is ``None``; ``sharding`` itself when its mesh
        is missing/empty or its spec already fits; otherwise a new
        ``NamedSharding`` on the same mesh with the spec cut to ``len(shape)``
        entries.
    """
    if sharding is None:
        return None
    gpu_mesh = sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        return sharding

    ndim = len(shape)
    spec = sharding.spec
    if ndim >= len(spec):
        return sharding

    # Carry the memory kind over so trimming only changes the spec.
    return NamedSharding(gpu_mesh,
                         P(*spec[:ndim]),
                         memory_kind=getattr(sharding, "memory_kind", None))
@pytest.mark.distributed
@pytest.mark.parametrize("dim", [1, 2, 3])
@pytest.mark.parametrize("pdims", pdims)
def test_normal_field(dim, pdims):
    """Check the shape and sharding of ``normal_field`` for 1D, 2D and 3D.

    The field is always requested with a two-axis ``P('x', 'y')`` sharding.
    For a 1D field the result is expected to be laid out as ``P('x')``; for
    2D and 3D the requested sharding is expected back unchanged.
    """
    axis_names = ('x', 'y')
    shape = tuple(16 for _ in range(dim))

    mesh = jax.make_mesh(pdims,
                         axis_names,
                         axis_types=(AxisType.Auto, AxisType.Auto))
    requested = NamedSharding(mesh, P(*axis_names))

    dist_field = normal_field(seed=jax.random.PRNGKey(42),
                              shape=shape,
                              sharding=requested)

    # Only the 1D case has fewer dimensions than the spec has entries.
    expected = NamedSharding(mesh, P('x')) if dim == 1 else requested

    assert dist_field.shape == shape
    assert expected.is_equivalent_to(dist_field.sharding, ndim=dim)