Skip to content

Allow 1d and 2d normal fields in distributed mode#54

Open
ASKabalan wants to merge 2 commits intomainfrom
extend-normal_field
Open

Allow 1d and 2d normal fields in distributed mode#54
ASKabalan wants to merge 2 commits intomainfrom
extend-normal_field

Conversation

@ASKabalan
Copy link
Copy Markdown
Member

This pull request introduces improvements to the handling of distributed array shapes and sharding specifications in jaxpm/distributed.py. The main focus is on making sharding logic more robust and flexible, especially when working with arrays of varying dimensionality. The most important changes are grouped below:

Enhancements to sharding and shape handling:

  • Improved get_local_shape to more accurately compute local shapes by mapping axis names to device sizes and handling cases where the mesh or sharding spec does not match the array dimensionality.
  • Added get_sharding_for_shape, a new helper function that trims the sharding specification to match the dimensionality of the target array, ensuring correct sharding for lower-dimensional arrays.
  • Updated normal_field to call get_sharding_for_shape before using sharding, ensuring the sharding spec matches the array shape.

API and import improvements:

  • Extended imports to include NamedSharding from jax.sharding, supporting the new sharding logic.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR enhances the distributed array handling in jaxpm/distributed.py to support 1D and 2D arrays in distributed mode, particularly for HEALPix spherical arrays. The changes make sharding logic more flexible and robust when working with arrays of varying dimensionality.

Changes:

  • Refactored get_local_shape to dynamically map axis names to device sizes and handle arrays with fewer dimensions than the sharding specification
  • Added get_sharding_for_shape helper function to trim sharding specs to match array dimensionality
  • Updated normal_field to use the new trimming logic before applying sharding

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread jaxpm/distributed.py
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@ASKabalan ASKabalan mentioned this pull request Mar 2, 2026
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants