Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions jaxpm/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp
import jaxdecomp
from jax import lax, shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import AbstractMesh, Mesh, NamedSharding
from jax.sharding import PartitionSpec as P


Expand Down Expand Up @@ -114,17 +114,36 @@ def slice_unpad(x, pad_width, sharding):


def get_local_shape(mesh_shape, sharding=None):
    """Compute the per-device (local) shape of a globally sharded array.

    Parameters
    ----------
    mesh_shape : sequence of int
        Global shape of the array.
    sharding : NamedSharding, optional
        Describes how the array is distributed over a device mesh. When
        ``None``, or when the sharding's mesh is empty, the array is not
        distributed and the global shape is returned unchanged.

    Returns
    -------
    list of int
        Local shape on each device: every array dimension named in the
        sharding spec is divided by the size of the corresponding mesh
        axis (or the product of sizes, for a tuple of axis names).
    """
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is None or gpu_mesh.empty:
        # No device mesh: the whole array lives on a single device.
        return list(mesh_shape)
    # Map each mesh axis name to the number of devices along that axis.
    axis_name_to_size = {
        name: gpu_mesh.devices.shape[i]
        for i, name in enumerate(gpu_mesh.axis_names)
    }
    local_shape = list(mesh_shape)
    for dim_idx, spec_entry in enumerate(sharding.spec):
        # Spec may be shorter or longer than the array rank; ignore
        # unsharded (None) entries and entries beyond the array's rank.
        if spec_entry is None or dim_idx >= len(local_shape):
            continue
        # A PartitionSpec entry is either one axis name or a tuple of
        # axis names; the dimension is split across the product of them.
        names = spec_entry if isinstance(spec_entry, tuple) else (spec_entry,)
        for name in names:
            local_shape[dim_idx] //= axis_name_to_size[name]
    return local_shape


def get_sharding_for_shape(shape, sharding=None):
    """Trim a sharding spec to match the dimensionality of ``shape``.

    E.g. a 3D ``NamedSharding`` with spec ``P('x', 'y')`` becomes
    ``P('x')`` for a 1D array (such as a HEALPix map).

    Parameters
    ----------
    shape : sequence of int
        Shape of the array the sharding will be applied to.
    sharding : NamedSharding, optional
        Sharding to adapt. ``None`` is passed through unchanged.

    Returns
    -------
    NamedSharding or None
        The original sharding when it already fits (or is not backed by a
        real mesh); otherwise a new sharding on the same mesh whose spec
        is truncated to ``len(shape)`` entries.
    """
    if sharding is None:
        return None
    gpu_mesh = sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        # Nothing is actually distributed; keep the sharding as-is.
        return sharding
    ndim = len(shape)
    spec = sharding.spec
    if ndim < len(spec):
        # Drop trailing spec entries that have no matching array dimension.
        return NamedSharding(gpu_mesh, P(*spec[:ndim]))
    return sharding


def _axis_names(spec):
Expand Down Expand Up @@ -176,6 +195,7 @@ def normal_field(seed, shape, sharding=None, dtype=float):
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(shape, sharding)
sharding = get_sharding_for_shape(shape, sharding)

size = jax.device_count()
# rank = jax.process_index()
Expand Down
Loading