Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2878,9 +2878,21 @@ def pop(
return states


def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node:
def clone(
node: Node,
variables: bool = True,
*,
arrays: bool = False,
graph: bool | None = None,
) -> Node:
"""Create a deep copy of the given graph node.

By default the cloned :class:`Variable` wrappers are new objects but
share their underlying ``jax.Array`` buffers with the original. This
is cheap, but ``jit(donate_argnums=...)`` will reject the clone
because JAX refuses to donate the same buffer twice. Pass
``arrays=True`` to also copy the underlying buffers.

Example usage::

>>> from flax import nnx
Expand All @@ -2894,6 +2906,11 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
node: A graph node object.
variables: If ``True`` (default) copies of the :class:`Variable` objects are created,
otherwise the Variables are shared between the original and cloned node.
arrays: If ``True``, the underlying ``jax.Array`` buffers of each
:class:`Variable` are also copied so the clone has independent
memory and is compatible with ``jit(donate_argnums=...)``. Implies
``variables=True``. Defaults to ``False`` (buffers shared with the
original).
graph: If ``True`` (default), uses graph-mode which supports the full
NNX feature set including shared references. If ``False``, uses
tree-mode which treats Modules as regular JAX pytrees, avoiding
Expand All @@ -2902,6 +2919,9 @@ def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> N
A deep copy of the :class:`Module` object.
"""
graphdef, state = split(node, graph=graph)
if arrays:
state = jax.tree.map(jax.numpy.copy, state)
variables = True
return merge(graphdef, state, copy=variables)


Expand Down
37 changes: 37 additions & 0 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,43 @@ def test_clone(self):
assert m.b.c.get_value() == m2.b.c.get_value()
assert m.b.d.get_value() == m2.b.d.get_value()

def test_clone_arrays_distinct_buffers(self):
m = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(0))

shared = nnx.clone(m)
distinct = nnx.clone(m, arrays=True)

# Default: buffers shared with the original.
assert (
m.kernel[...].unsafe_buffer_pointer()
== shared.kernel[...].unsafe_buffer_pointer()
)
# arrays=True: buffers are independent.
assert (
m.kernel[...].unsafe_buffer_pointer()
!= distinct.kernel[...].unsafe_buffer_pointer()
)
# Values still match.
np.testing.assert_array_equal(m.kernel[...], distinct.kernel[...])

def test_clone_arrays_donate_argnums(self):
class TestModule(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(in_features=4, out_features=2, rngs=rngs)
self.linear_copy = nnx.clone(self.linear, arrays=True)

m = TestModule(nnx.Rngs(0))

@jax.jit
def update(state):
return jax.tree.map(lambda x: x + 0.01, state)

update_donate = jax.jit(update, donate_argnums=(0,))
# Pre-fix this raised XlaRuntimeError ("donate the same buffer twice")
# because nnx.clone shared the linear and linear_copy buffers.
state = update_donate(nnx.state(m))
assert state is not None

def test_sow_existing_non_variable_field(self):
class Foo(nnx.Module):
def __init__(self) -> None:
Expand Down
Loading