Skip to content

update spherical and 2d painting to work better for sharding#53

Open
ASKabalan wants to merge 6 commits intomainfrom
better-sharded-spherical-painting
Open

update spherical and 2d painting to work better for sharding#53
ASKabalan wants to merge 6 commits intomainfrom
better-sharded-spherical-painting

Conversation

@ASKabalan
Copy link
Copy Markdown
Member

@ASKabalan ASKabalan commented Feb 12, 2026

Updating the painting functions to better support sharding and be more memory efficient

2D Painting

2D Painting now will give a sharded flat sky projection and will only use communications for the scattering done in the end to aggregate contribution of particles

Spherical Painting

Now uses all devices for computations (healpy ang2pix 2vec etc ..) and will never all gather .. will use communications at the very end to produce the fully replicated healpix map

I could consider even sharding the healpix map on 1d, however this might not be very efficient since we will eventually need to gather it for the alm2map_spin (so I will not consider sharding spherical maps for now)

Results

for spherical painting for example

=== Analyzing Scheme: ngp [OLD] ===
[NGP] pixels sharding: PartitionSpec('x', 'y') before bincount
Memory Allocations: 512.00 KB
Result Sharding for ngp [OLD]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
WARNING: all-gather detected in HLO

=== Analyzing Scheme: ngp [NEW] ===
[NGP] pixels sharding: PartitionSpec('x', 'y') before scatter
Memory Allocations: 352.00 KB
Result Sharding for ngp [NEW]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
No all-gather detected

=== Analyzing Scheme: rbf_neighbor [OLD] ===
/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax/_src/ops/scatter.py:104: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion=standard. In future JAX releases this will result in an error.
  warnings.warn(
[RBF] pixels sharding: PartitionSpec() before scatter
Memory Allocations: 17.38 MB
Result Sharding for rbf_neighbor [OLD]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
WARNING: all-gather detected in HLO

=== Analyzing Scheme: rbf_neighbor [NEW] ===
[RBF] pixels sharding: PartitionSpec(None, 'x', 'y') before scatter
Memory Allocations: 3.34 MB
Result Sharding for rbf_neighbor [NEW]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
No all-gather detected

=== Analyzing Scheme: bilinear [OLD] ===
[BILINEAR] pixels sharding: PartitionSpec(None, 'x', 'y') before scatter
Memory Allocations: 2.50 MB
Result Sharding for bilinear [OLD]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
WARNING: all-gather detected in HLO

=== Analyzing Scheme: bilinear [NEW] ===
[BILINEAR] pixels sharding: PartitionSpec(None, 'x', 'y') before scatter
Memory Allocations: 832.00 KB
Result Sharding for bilinear [NEW]: NamedSharding(mesh=Mesh('x': 2, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(), memory_kind=device)
No all-gather detected

this PR depends on CMBSciPol/jax-healpy#2 and it should be merged after

@ASKabalan ASKabalan mentioned this pull request Mar 2, 2026
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant