diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 151798446..024722af0 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -32,6 +32,7 @@ from jax._src.state.types import AbstractRef import jax.experimental from jax.experimental import hijax as hjx +import jax.extend as jex import jax.tree_util as jtu import treescope # type: ignore[import-untyped] @@ -286,7 +287,14 @@ def normalize(self): leaf_types = tuple(a.normalize() for a in self.leaf_avals) return VariableQDD(leaf_types, self.treedef, self.var_type) -class VariableEffect(jax.core.Effect): ... +try: + # JAX v0.10.0 and newer. + Effect: type = jex.core.Effect +except AttributeError: + # JAX v0.9.2 and older. + Effect = jax.core.Effect + +class VariableEffect(Effect): ... variable_effect = VariableEffect()