diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..903a415d3 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -2878,9 +2878,21 @@ def pop( return states -def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node: +def clone( + node: Node, + variables: bool = True, + *, + arrays: bool = False, + graph: bool | None = None, +) -> Node: """Create a deep copy of the given graph node. + By default the cloned :class:`Variable` wrappers are new objects but + share their underlying ``jax.Array`` buffers with the original. This + is cheap, but ``jit(donate_argnums=...)`` will reject the clone + because JAX refuses to donate the same buffer twice. Pass + ``arrays=True`` to also copy the underlying buffers. + Example usage:: >>> from flax import nnx @@ -2894,6 +2906,11 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N node: A graph node object. variables: If ``True`` (default) copies of the :class:`Variable` objects are created, otherwise the Variables are shared between the original and cloned node. + arrays: If ``True``, the underlying ``jax.Array`` buffers of each + :class:`Variable` are also copied so the clone has independent + memory and is compatible with ``jit(donate_argnums=...)``. Implies + ``variables=True``. Defaults to ``False`` (buffers shared with the + original). graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding @@ -2902,6 +2919,9 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N A deep copy of the :class:`Module` object. """ graphdef, state = split(node, graph=graph) + if arrays: + state = jax.tree.map(jax.numpy.copy, state) + variables = True return merge(graphdef, state, copy=variables) diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index be350e8a9..5421ebdb9 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -574,6 +574,43 @@ def test_clone(self): assert m.b.c.get_value() == m2.b.c.get_value() assert m.b.d.get_value() == m2.b.d.get_value() + def test_clone_arrays_distinct_buffers(self): + m = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(0)) + + shared = nnx.clone(m) + distinct = nnx.clone(m, arrays=True) + + # Default: buffers shared with the original. + assert ( + m.kernel[...].unsafe_buffer_pointer() + == shared.kernel[...].unsafe_buffer_pointer() + ) + # arrays=True: buffers are independent. + assert ( + m.kernel[...].unsafe_buffer_pointer() + != distinct.kernel[...].unsafe_buffer_pointer() + ) + # Values still match. + np.testing.assert_array_equal(m.kernel[...], distinct.kernel[...]) + + def test_clone_arrays_donate_argnums(self): + class TestModule(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(in_features=4, out_features=2, rngs=rngs) + self.linear_copy = nnx.clone(self.linear, arrays=True) + + m = TestModule(nnx.Rngs(0)) + + @jax.jit + def update(state): + return jax.tree.map(lambda x: x + 0.01, state) + + update_donate = jax.jit(update, donate_argnums=(0,)) + # Pre-fix this raised XlaRuntimeError ("donate the same buffer twice") + # because nnx.clone shared the linear and linear_copy buffers. + state = update_donate(nnx.state(m)) + assert state is not None + def test_sow_existing_non_variable_field(self): class Foo(nnx.Module): def __init__(self) -> None: