Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs_nnx/api_reference/flax.nnx/nn/attention.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ Attention
:module: flax.nnx
:class: MultiHeadAttention

.. flax_module::
  :module: flax.nnx
  :class: RoPE

.. autofunction:: combine_masks
.. autofunction:: dot_product_attention
.. autofunction:: dot_product_attention_with_rope
.. autofunction:: make_attention_mask
.. autofunction:: make_causal_mask
.. autofunction:: make_causal_mask
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@
from .nn.attention import MultiHeadAttention as MultiHeadAttention
from .nn.attention import combine_masks as combine_masks
from .nn.attention import dot_product_attention as dot_product_attention
from .nn.attention import dot_product_attention_with_rope as dot_product_attention_with_rope
from .nn.attention import RoPE as RoPE
from .nn.attention import make_attention_mask as make_attention_mask
from .nn.attention import make_causal_mask as make_causal_mask
from .nn.recurrent import RNNCellBase as RNNCellBase
Expand Down
132 changes: 131 additions & 1 deletion flax/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,128 @@ def reshape_4d(x):
return out


class RoPE(Module):
  """Rotary Position Embedding (RoPE).

  Computes (and optionally precomputes and stores) cosine and sine frequency
  tensors for rotary positional encoding as described in "RoFormer: Enhanced
  Transformer with Rotary Position Embedding"
  (https://arxiv.org/abs/2104.09864).

  Example usage::

    >>> from flax import nnx
    >>> import functools
    >>> rope = nnx.RoPE(embedding_size=64, max_seq_len=128)
    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=8, in_features=512, qkv_features=512,
    ...   attention_fn=functools.partial(
    ...     nnx.dot_product_attention_with_rope, rope=rope),
    ...   decode=False, rngs=nnx.Rngs(0))

  Args:
    theta: Base frequency for sinusoidal functions. Default is 10000.
    embedding_size: Optional. Size of each embedding vector (i.e. head_dim).
      Must be positive and even. If provided together with ``max_seq_len``,
      sin/cos tables are precomputed and cached as module Variables.
    max_seq_len: Optional. Maximum sequence length for precomputed
      frequencies. Must be provided together with ``embedding_size``.

  Raises:
    ValueError: If exactly one of ``embedding_size`` / ``max_seq_len`` is
      given, or if ``embedding_size`` is not positive and even.
  """

  def __init__(
    self,
    theta: float = 10000.0,
    embedding_size: int | None = None,
    max_seq_len: int | None = None,
  ):
    self.theta = theta
    # The two table parameters only make sense as a pair.
    if (embedding_size is None) != (max_seq_len is None):
      raise ValueError('Either both `embedding_size` and `max_seq_len` or none of them must be provided.')
    if embedding_size is not None and max_seq_len is not None:
      positions = jnp.arange(float(max_seq_len))
      sin, cos = self._rotation_for(embedding_size, positions)
      # Tables have shape (max_seq_len, embedding_size // 2).
      self.cached_sin = nnx.Variable(sin)
      self.cached_cos = nnx.Variable(cos)

  @staticmethod
  def _rotate_half(x: Array) -> Array:
    """Map the last axis ``[a, b]`` (two equal halves) to ``[-b, a]``."""
    d_2 = x.shape[-1] // 2
    return jnp.concatenate([-x[..., d_2:], x[..., :d_2]], axis=-1)

  def _rotation_for(self, embedding_size: int, input_positions: Array):
    """Return ``(sin, cos)`` of ``position * frequency``.

    Args:
      embedding_size: Size of the rotated feature axis; positive and even.
      input_positions: Positions of shape ``(..., seq_len)``.

    Returns:
      Two arrays of shape ``(..., seq_len, embedding_size // 2)``.

    Raises:
      ValueError: If ``embedding_size`` is not positive and even.
    """
    if embedding_size <= 0 or embedding_size % 2 != 0:
      raise ValueError('`embedding_size` must be positive and even.')
    freqs = 1.0 / (
      self.theta ** (jnp.arange(0.0, embedding_size, 2) / embedding_size)
    )
    # Broadcasted product rather than `jnp.outer`: `outer` flattens its
    # inputs, which would collapse any leading batch axes of the positions.
    # For 1-D positions both forms give the same (seq_len, dim // 2) table.
    angles = jnp.asarray(input_positions)[..., None] * freqs
    return jnp.sin(angles), jnp.cos(angles)

  def __call__(self, x: Array, input_positions: Array | None = None) -> Array:
    """Apply rotary positional encoding.

    Args:
      x: Input array of shape ``(..., seq_length, embedding_size)``.
      input_positions: Optional position array of shape ``(seq_len,)`` or
        ``(..., seq_len)`` with leading axes broadcastable against those of
        ``x``. If ``None`` and precomputed tables exist (i.e.
        ``embedding_size`` and ``max_seq_len`` were given at init), the
        cached tables are used. Otherwise positions default to
        ``jnp.arange(seq_len)``.

    Returns:
      Array of same shape with RoPE applied.

    Raises:
      ValueError: If ``input_positions`` does not match the sequence length
        of ``x``, or if the cached tables are too short for ``x`` or were
        built for a different ``embedding_size``.
    """
    seq_len, embedding_size = x.shape[-2:]
    if input_positions is None and hasattr(self, 'cached_sin'):
      table_len, half_dim = self.cached_sin.value.shape
      # Fail loudly instead of relying on a later broadcasting error (or a
      # silent broadcast when the table holds a single position).
      if seq_len > table_len:
        raise ValueError(
          f'Sequence length {seq_len} exceeds the precomputed '
          f'`max_seq_len` {table_len}.'
        )
      if embedding_size != 2 * half_dim:
        raise ValueError(
          f'Input embedding size {embedding_size} does not match the '
          f'precomputed `embedding_size` {2 * half_dim}.'
        )
      sin = self.cached_sin.value[:seq_len]
      cos = self.cached_cos.value[:seq_len]
    else:
      if input_positions is None:
        input_positions = jnp.arange(float(seq_len))
      elif input_positions.shape[-1] != seq_len:
        raise ValueError('`input_positions` must have the same seq_len as `x`.')
      sin, cos = self._rotation_for(embedding_size, input_positions)
    # Each frequency rotates the feature pair (k, k + dim // 2), so the
    # half-width tables are duplicated along the last axis to full width.
    cos_full = jnp.concatenate([cos, cos], axis=-1)
    sin_full = jnp.concatenate([sin, sin], axis=-1)
    return x * cos_full + self._rotate_half(x) * sin_full


def dot_product_attention_with_rope(
  query: Array,
  key: Array,
  value: Array,
  *,
  rope: RoPE,
  input_positions: Array | None = None,
  **kwargs,
) -> Array:
  """Dot-product attention with Rotary Position Embedding applied to query and key.

  This function has the same signature as :func:`dot_product_attention`
  (plus the ``rope`` and ``input_positions`` keyword arguments) and can be
  used as the ``attention_fn`` of :class:`MultiHeadAttention` via
  ``functools.partial``::

    attention_fn=functools.partial(dot_product_attention_with_rope, rope=rope)

  Args:
    query: queries of shape ``[batch..., q_length, num_heads, head_dim]``.
    key: keys of shape ``[batch..., kv_length, num_heads, head_dim]``.
    value: values of shape ``[batch..., kv_length, num_heads, v_dim]``.
    rope: A :class:`RoPE` module instance.
    input_positions: Optional positions passed unchanged to ``rope`` for
      both ``query`` and ``key``. When ``None``, ``rope`` falls back to its
      own default positions.
    **kwargs: Additional keyword arguments forwarded to
      :func:`dot_product_attention`.

  Returns:
    Output of shape ``[batch..., q_length, num_heads, v_dim]``.
  """
  # query/key: [batch..., seq_length, num_heads, head_dim]
  # RoPE.__call__ expects (..., seq_length, head_dim), so vmap over heads.
  # Positions carry no heads axis, so they are shared across heads
  # (in_axes=None) rather than mapped; `rope` returns a single array, hence
  # a single out_axes entry. This also covers `input_positions=None`.
  apply = jax.vmap(rope, in_axes=(-2, None), out_axes=-2)
  query = apply(query, input_positions)
  key = apply(key, input_positions)
  return dot_product_attention(query, key, value, **kwargs)


class MultiHeadAttention(Module):
"""Multi-head attention.

Expand Down Expand Up @@ -588,6 +710,7 @@ def __call__(
is_causal=False,
out_sharding = None,
qkv_sharding = None,
input_positions: Array | None = None
):
"""Applies multi-head dot product attention on the input data.

Expand Down Expand Up @@ -624,6 +747,9 @@ def __call__(
the output linear layer for the output arrays.
qkv_sharding: Optional sharding specification to pass to
the QKV linear layers for the output arrays.
input_positions: Optional position indices of shape
``[batch_sizes..., length]`` forwarded to the ``attention_fn``.
Used by :func:`dot_product_attention_with_rope` for RoPE.

Returns:
output of shape `[batch_sizes..., length, features]`.
Expand Down Expand Up @@ -736,6 +862,9 @@ def __call__(
dropout_rng = None

# apply attention
attn_kwargs = {}
if input_positions is not None:
attn_kwargs["input_positions"] = input_positions
x = self.attention_fn(
query,
key,
Expand All @@ -748,7 +877,8 @@ def __call__(
dtype=self.dtype,
precision=self.precision,
module=self if sow_weights else None,
is_causal=is_causal
is_causal=is_causal,
**attn_kwargs
)
# back to the original inputs dimensions
out = self.out(x, out_sharding=out_sharding)
Expand Down
38 changes: 36 additions & 2 deletions tests/nnx/nn/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,41 @@ def _run(m):
np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
absltest.main()
class TestRoPE(absltest.TestCase):
  """Tests for ``nnx.RoPE``."""

  def test_relative_position_invariance(self):
    """Dot product of RoPE-rotated vectors depends only on relative position.

    For any offset d, <RoPE(q, pos=i), RoPE(k, pos=j)> should equal
    <RoPE(q, pos=i+d), RoPE(k, pos=j+d)>. The property is checked both for
    a RoPE that computes its tables per call and for one that uses tables
    precomputed at construction time.
    """
    head_dim = 64
    max_len = 128

    k1, k2 = jax.random.split(jax.random.key(7))
    q_vec = jax.random.normal(k1, (head_dim,))
    k_vec = jax.random.normal(k2, (head_dim,))

    # Place q at position i and k at position j, then shift both by d.
    # All positions used (up to j + d = 49) stay below max_len.
    i, j = 5, 12
    d = 37

    def rope_at(rope, vec, pos):
      # Build a dummy sequence long enough, apply RoPE, extract one position.
      seq = jnp.zeros((pos + 1, head_dim)).at[pos].set(vec)
      return rope(seq)[pos]

    ropes = {
      'on_the_fly': nnx.RoPE(),
      'cached': nnx.RoPE(embedding_size=head_dim, max_seq_len=max_len),
    }
    for name, rope in ropes.items():
      with self.subTest(name):
        q_rot = rope_at(rope, q_vec, i)
        k_rot = rope_at(rope, k_vec, j)
        dot_original = jnp.dot(q_rot, k_rot)

        q_rot_shifted = rope_at(rope, q_vec, i + d)
        k_rot_shifted = rope_at(rope, k_vec, j + d)
        dot_shifted = jnp.dot(q_rot_shifted, k_rot_shifted)

        np.testing.assert_allclose(dot_original, dot_shifted, atol=1e-5)


# Run the absl test runner when this file is executed directly as a script.
if __name__ == '__main__':
  absltest.main()
Loading