Skip to content
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0a60877
add lensing and basic Spherical painting functionality
ASKabalan Sep 23, 2025
efcfcf2
Allow Cosmo Cache deactivation
ASKabalan Sep 23, 2025
8ce412d
add Symplectic ODE terms for fast PM integration
ASKabalan Sep 23, 2025
e21ebe4
add test against glass convergence maps
ASKabalan Sep 23, 2025
1c036ec
format
ASKabalan Sep 23, 2025
6f9b02b
add glass[examples] to requirements-test.txt
ASKabalan Sep 23, 2025
f5507ae
format
ASKabalan Sep 23, 2025
36ce358
format
ASKabalan Sep 24, 2025
e8264cc
add Jax healpy as a dependency
ASKabalan Sep 24, 2025
c692638
Roll back growth.py cache
ASKabalan Oct 7, 2025
8f60be6
Update testing dependencies
ASKabalan Oct 7, 2025
b746fe8
update lensing.py and test_convergence_vs_glass to fix jax-cosmo cach…
ASKabalan Oct 7, 2025
70d661e
Revert the fastpm odes to use Growth functions
ASKabalan Oct 7, 2025
72647d9
Add safe division by zero in lensing.py
ASKabalan Oct 7, 2025
5ca4caa
add save division by zero in spherical.py
ASKabalan Oct 7, 2025
e56ccca
Allow deactivating the usage of growth factors in ode.py in the FastP…
ASKabalan Oct 7, 2025
5336b36
Merge branch 'main' into 41-spherical-lensing
ASKabalan Oct 9, 2025
491e0ae
Fix flat sky lensing issues and update the painting method to use the…
ASKabalan Oct 14, 2025
ada5187
merge remote-tracking branch 'origin/main' into 41-spherical-lensing
ASKabalan Oct 14, 2025
0e41e9a
Merge remote-tracking branch 'origin/main' into 41-spherical-lensing
ASKabalan Oct 14, 2025
0bf8059
Add an analytical function to compute visibility mask from observer p…
ASKabalan Oct 22, 2025
e99502c
Merge branch 'main' into 41-spherical-lensing
ASKabalan Oct 22, 2025
0d4e2f0
format
ASKabalan Oct 24, 2025
a129acd
Update jax.experimental.shard_map to jax.shard_map and bump jax version
ASKabalan Oct 28, 2025
7351fc9
format
ASKabalan Oct 28, 2025
ae7c04e
Update shardmap signature
ASKabalan Oct 28, 2025
d0d29c9
Update JAX dependencies and migrate from experimental shard_map API (…
Copilot Oct 28, 2025
ec1af7d
Merge remote-tracking branch 'origin/update-shardmap-import' into 41-…
ASKabalan Oct 30, 2025
a430f4e
Update test workflow install order
ASKabalan Oct 30, 2025
b9eff44
Merge remote-tracking branch 'origin/update-shardmap-import' into upd…
ASKabalan Oct 30, 2025
9ff2eae
Add Set type import in distributed.py
ASKabalan Oct 30, 2025
81cd08a
Merge remote-tracking branch 'origin/update-shardmap-import' into 41-…
ASKabalan Oct 30, 2025
76b7713
add
ASKabalan Nov 6, 2025
9a672ad
temp commit
ASKabalan Nov 6, 2025
abd66a5
Commit
ASKabalan Nov 25, 2025
1622e9f
update jax-cosmo usage to match update in ASKabalan:add_growth_second…
ASKabalan Nov 27, 2025
21ba798
Simplify spherical painting (flatten later to use more devices)
ASKabalan Dec 15, 2025
3aecf0b
small fix
ASKabalan Dec 15, 2025
7017b33
remove print statements used for debugging
ASKabalan Dec 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,15 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
# Install JAX first as it's a key dependency
pip install jax
# Install packages with test dependencies
pip install -e .[test]
# Install build dependencies
pip install setuptools cython mpi4py
# Install test requirements with no-build-isolation for faster builds
# Install test requirements with no-build-isolation for PFFT
pip install -r requirements-test.txt --no-build-isolation
# Install additional test dependencies
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The installation order may cause issues: pip install -e .[test] is run before installing requirements-test.txt which contains numpy==2.2.6. This could lead to dependency conflicts if the project dependencies have numpy version constraints. Consider installing requirements-test.txt before running pip install -e .[test] to establish the numpy version first.

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest is being installed redundantly here. It's already listed as a test dependency in pyproject.toml under [project.optional-dependencies].test and would have been installed by the earlier pip install -e .[test] command on line 55.

Suggested change
pip install pytest diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'
pip install diffrax 'glass[examples] @ git+https://github.com/glass-dev/glass'

Copilot uses AI. Check for mistakes.
# Install packages with test dependencies
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment 'Install packages with test dependencies' appears twice (lines 54 and 62) with different actual operations, which is confusing. The second occurrence at line 62 should be removed as it's followed by a simple echo statement, not another pip install command.

Suggested change
# Install packages with test dependencies

Copilot uses AI. Check for mistakes.
pip install -e .[test]
echo "numpy version installed:"
python -c "import numpy; print(numpy.__version__)"

Expand Down
15 changes: 10 additions & 5 deletions jaxpm/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
Specs = Any
AxisName = Hashable

from collections.abc import Set
from functools import partial

import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax import lax, shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P

Expand All @@ -19,14 +19,19 @@ def autoshmap(
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
auto: frozenset[AxisName] = frozenset()) -> Callable:
check_vma: bool = False,
axis_names: Set[AxisName] = frozenset()) -> Callable:
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
if gpu_mesh is None or gpu_mesh.empty:
return f
else:
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
return shard_map(f,
mesh=gpu_mesh,
in_specs=in_specs,
out_specs=out_specs,
check_vma=check_vma,
axis_names=axis_names)


def fft3d(x):
Expand Down
29 changes: 29 additions & 0 deletions jaxpm/growth.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,3 +599,32 @@ def dGf2a(cosmo, a):
E_a = E(cosmo, a)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E_a * D2f)


def gp(cosmo, a):
r""" Derivative of D1 against a

Parameters
----------
cosmo: dict
Cosmology dictionary.

a : array_like
Scale factor.

Returns
-------
Scalar float Tensor : the derivative of D1 against a.

Notes
-----

The expression for :math:`gp(a)` is:

.. math::
gp(a)=\frac{dD1}{da}= D'_{1norm}/a
"""
f1 = growth_rate(cosmo, a)
g1 = growth_factor(cosmo, a)
D1f = f1 * g1 / a
return D1f
255 changes: 193 additions & 62 deletions jaxpm/lensing.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,213 @@
import jax
import jax.numpy as jnp
import jax_cosmo
import jax_cosmo as jc
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates

from jaxpm.painting import cic_paint_2d
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import of 'cic_paint' is not used.
Import of 'cic_paint_dx' is not used.

Suggested change
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
from jaxpm.painting import cic_paint_2d

Copilot uses AI. Check for mistakes.
from jaxpm.spherical import paint_particles_spherical
from jaxpm.utils import gaussian_smoothing


def density_plane(positions,
box_shape,
center,
width,
plane_resolution,
smoothing_sigma=None):
""" Extacts a density plane from the simulation
"""
nx, ny, nz = box_shape
xy = positions[..., :2]
d = positions[..., 2]
def density_plane_fn(box_shape,
box_size,
density_plane_width,
density_plane_npix,
sharding=None):

def f(t, y, args):
positions = y
cosmo = args
nx, ny, nz = box_shape

# Converts time t to comoving distance in voxel coordinates
w = density_plane_width / box_size[2] * box_shape[2]
center = jc.background.radial_comoving_distance(
cosmo, t) / box_size[2] * box_shape[2]
# Clear workspace to avoid memory issues and tracer leaks
# due to the caching system in jax-cosmo
cosmo._workspace = {}
positions = uniform_particles(box_shape) + positions
xy = positions[..., :2]
d = positions[..., 2]

# Apply 2d periodic conditions
xy = jnp.mod(xy, nx)

# Rescaling positions to target grid
xy = xy / nx * density_plane_npix
# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - w / 2)) & (d <= (center + w / 2)),
1.0, 0.0)
# Painting density plane
zero_mesh = jnp.zeros([density_plane_npix, density_plane_npix])
# Apply sharding in order to recover sharding when taking gradients
if sharding is not None:
xy = jax.lax.with_sharding_constraint(xy, sharding)
# Apply CIC painting
density_plane = cic_paint_2d(zero_mesh, xy, weight)

# Calculate physical volume per pixel
pixel_area = (box_size[0] / density_plane_npix) * (box_size[1] /
density_plane_npix)
shell_thickness_physical = density_plane_width # Already in physical units
pixel_volume = pixel_area * shell_thickness_physical

# Convert counts to density (particles per unit volume)
density_plane = density_plane / pixel_volume

return density_plane

return f


def spherical_density_fn(mesh_shape,
box_size,
nside,
observer_position,
density_plane_width,
method="RBF_NEIGHBOR",
kernel_width_arcmin=None,
sharding=None):

def f(t, y, args):
positions = y
cosmo = args

# Apply 2d periodic conditions
xy = jnp.mod(xy, nx)
positions = uniform_particles(mesh_shape) + positions

# Rescaling positions to target grid
xy = xy / nx * plane_resolution
# Calculate comoving distance range for this shell
r_center = jc.background.radial_comoving_distance(cosmo, t)
# Clear workspace to avoid memory issues
# due to the caching system in jax-cosmo
cosmo._workspace = {}
r_max = jnp.clip(r_center + density_plane_width / 2, 0, box_size[2])
r_min = jnp.clip(r_center - density_plane_width / 2, 0, box_size[2])

# Selecting only particles that fall inside the volume of interest
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
if sharding is not None:
positions = jax.lax.with_sharding_constraint(positions, sharding)

# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
(ny / plane_resolution) * (width))
# Paint particles in this shell onto a HEALPix map
spherical_map = paint_particles_spherical(
positions,
nside=nside,
method=method,
kernel_width_arcmin=kernel_width_arcmin,
observer_position=observer_position,
R_min=r_min,
R_max=r_max,
box_size=box_size,
mesh_shape=mesh_shape)

# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
return spherical_map

return density_plane
return f


def convergence_Born(cosmo, density_planes, coords, z_source):
# ==========================================================
# Weak Lensing Born Approximation
# ==========================================================
def convergence_Born(cosmo,
density_planes,
r,
a,
z_source,
d_r,
dx=None,
coords=None):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
Born approximation convergence for both spherical and flat geometries.

Parameters
----------
cosmo : jc.Cosmology
Cosmology object
density_planes : ndarray
- Spherical: [n_planes, npix] - density on HEALPix grid
- Flat: [n_planes, nx, ny] - density on Cartesian grid
Note: d_R is already included in the density normalization
r : ndarray
Comoving distances to plane centers [n_planes]
a : ndarray
Scale factors at plane centers [n_planes]
z_source : float or ndarray
Source redshift(s)
dx : float, optional
Pixel size for flat-sky case (required for flat)
coords : ndarray, optional
Angular coordinates for flat-sky (required for flat)

Returns
-------
convergence : ndarray
Convergence map
"""
# Constants
Comment thread
ASKabalan marked this conversation as resolved.
# --- 1. Pre-computation and Shape Setup ---
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))

convergence = 0
for entry in density_planes:
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization

# Interpolate at the density plane coordinates
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")

convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
chi_s = jc.background.radial_comoving_distance(cosmo,
jc.utils.z2a(z_source))
n_planes = len(r)

# Detect geometry from input shape
is_spherical = density_planes.ndim == 2 # [n_planes, npix]

if not is_spherical:
assert dx is not None and coords is not None, "dx and coords are required for flat geometry."

# Reshape 1D arrays to [n_planes, 1, 1] for broadcasting with [n_planes, nx, ny]
# Or to [n_planes, 1] for spherical geometry
r_b = r.reshape(n_planes, *((1, ) * (density_planes.ndim - 1)))
a_b = a.reshape(n_planes, *((1, ) * (density_planes.ndim - 1)))

# --- 2. Vectorized Overdensity Calculation ---
# Calculate mean density across spatial dimensions for each plane
mean_axes = tuple(range(1, density_planes.ndim))
rho_mean = jnp.mean(density_planes, axis=mean_axes, keepdims=True)
# Avoid division by zero by adding a small epsilon where mean density is zero
eps = jnp.finfo(rho_mean.dtype).eps
safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean)
Comment thread
ASKabalan marked this conversation as resolved.
delta = density_planes / safe_rho_mean - 1
Comment on lines +165 to +166
Copy link

Copilot AI Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When rho_mean is exactly 0, safe_rho_mean is set to eps, making delta = density_planes/eps - 1. If density_planes contains zeros (which is expected when rho_mean is 0), this results in -1.0. However, if density_planes contains non-zero values while rho_mean is 0, this will create artificially large delta values. Consider using jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1) to handle the zero-mean case more directly.

Suggested change
safe_rho_mean = jnp.where(rho_mean == 0, eps, rho_mean)
delta = density_planes / safe_rho_mean - 1
# Use jnp.where to set delta=0 when rho_mean==0, avoiding artificially large values
delta = jnp.where(rho_mean == 0, 0.0, density_planes / rho_mean - 1)

Copilot uses AI. Check for mistakes.

# --- 3. Vectorized Lensing Kernel and Weighting ---
# Combine all factors except interpolation
# This includes the geometric term: dχ * χ / a(χ)
kappa_contributions = delta * (d_r * r_b / a_b)
kappa_contributions *= constant_factor
# --- 4. Interpolation (for Flat-Sky only) ---
if not is_spherical:
# Define the interpolation function for a SINGLE plane
def interpolate_plane(delta_plane, chi_plane):
physical_coords = coords * chi_plane / dx
return map_coordinates(delta_plane,
physical_coords - 0.5,
order=1,
mode="wrap")

# Use vmap to apply the function across all planes efficiently
kappa_contributions = jax.vmap(interpolate_plane)(kappa_contributions,
r)

# --- 5. Final Assembly ---
# In case of multiple source redshifts, and a flat sky approximation,
# We need to add a dimension to match the 2D shape of the kappa contributions
if jnp.ndim(z_source) > 0 and not is_spherical:
chi_s = jnp.expand_dims(chi_s, axis=1)
# Apply the constant factor and the lensing efficiency kernel: (χs - χ) / χs
lensing_efficiency = jnp.clip(1.0 - (r_b / chi_s), 0, 1000)
# Add a dimension for broadcasting the redshift dimension
lensing_efficiency = jnp.expand_dims(lensing_efficiency, axis=-1)
kappa_contributions = jnp.expand_dims(kappa_contributions, axis=1)
# Multiply the weighted delta by the lensing kernel and constant
final_contributions = lensing_efficiency * kappa_contributions

# Sum contributions from all planes to get the final convergence map
# For multiple redshifts, preserve the redshift dimension
convergence = jnp.sum(final_contributions, axis=0)

# Handle single vs multiple redshift cases
if jnp.ndim(z_source) == 0: # Single redshift case
convergence = jnp.squeeze(convergence, axis=0)

return convergence
Loading
Loading