From c8e3120fa8156066f21b47ba9d5e083327d3301f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 15 Apr 2026 11:38:44 -0700 Subject: [PATCH] Avoid usage of deprecated `jax.core` APIs. These APIs are deprecated as of JAX v0.10.0, replaced by equivalents in `jax.extend.core` (see https://docs.jax.dev/en/latest/jax.extend.html for details). PiperOrigin-RevId: 900270194 --- flax/nnx/variablelib.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 151798446..024722af0 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -32,6 +32,7 @@ from jax._src.state.types import AbstractRef import jax.experimental from jax.experimental import hijax as hjx +import jax.extend as jex import jax.tree_util as jtu import treescope # type: ignore[import-untyped] @@ -286,7 +287,14 @@ def normalize(self): leaf_types = tuple(a.normalize() for a in self.leaf_avals) return VariableQDD(leaf_types, self.treedef, self.var_type) -class VariableEffect(jax.core.Effect): ... +try: + # JAX v0.10.0 and newer. + Effect: type = jex.core.Effect +except AttributeError: + # JAX v0.9.2 and older. + Effect = jax.core.Effect + +class VariableEffect(Effect): ... variable_effect = VariableEffect()