diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..fc00c83ba 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -3627,7 +3627,7 @@ class Static(tp.Generic[A]): class GenericPytree: ... -from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY + def is_pytree_node( @@ -3637,7 +3637,7 @@ def is_pytree_node( return False elif isinstance(x, Variable): return False - elif type(x) in JAX_PYTREE_REGISTRY: + elif jax.tree_util.is_tree_node(type(x)): return True elif isinstance(x, tuple): return True