diff --git a/docs_nnx/api_reference/flax.nnx/nn/attention.rst b/docs_nnx/api_reference/flax.nnx/nn/attention.rst index 0317f60e9..1f60d70ea 100644 --- a/docs_nnx/api_reference/flax.nnx/nn/attention.rst +++ b/docs_nnx/api_reference/flax.nnx/nn/attention.rst @@ -8,7 +8,12 @@ Attention :module: flax.nnx :class: MultiHeadAttention +.. flax_module:: + :module: flax.nnx + :class: RoPE + .. autofunction:: combine_masks .. autofunction:: dot_product_attention +.. autofunction:: dot_product_attention_with_rope .. autofunction:: make_attention_mask -.. autofunction:: make_causal_mask \ No newline at end of file +.. autofunction:: make_causal_mask diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 78a550411..fd02e8e39 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -115,6 +115,8 @@ from .nn.attention import MultiHeadAttention as MultiHeadAttention from .nn.attention import combine_masks as combine_masks from .nn.attention import dot_product_attention as dot_product_attention +from .nn.attention import dot_product_attention_with_rope as dot_product_attention_with_rope +from .nn.attention import RoPE as RoPE from .nn.attention import make_attention_mask as make_attention_mask from .nn.attention import make_causal_mask as make_causal_mask from .nn.recurrent import RNNCellBase as RNNCellBase diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 4f0c4f0cd..01dbe56a5 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -319,6 +319,128 @@ def reshape_4d(x): return out +class RoPE(Module): + """Rotary Position Embedding (RoPE). + + Precomputes and stores cosine and sine frequency tensors for rotary + positional encoding as described in "RoFormer: Enhanced Transformer + with Rotary Position Embedding" (https://arxiv.org/abs/2104.09864). + + Example usage:: + + >>> from flax import nnx + >>> import functools + >>> rope = nnx.RoPE(embedding_size=64, max_seq_len=128) + >>> layer = nnx.MultiHeadAttention( + ... num_heads=8, in_features=512, qkv_features=512, + ... attention_fn=functools.partial( + ... nnx.dot_product_attention_with_rope, rope=rope), + ... decode=False, rngs=nnx.Rngs(0)) + + Args: + theta: Base frequency for sinusoidal functions. Default is 10000. + embedding_size: Optional. Size of each embedding vector (i.e. head_dim). Must be + even. If provided together with ``max_seq_len``, sin/cos tables are + precomputed and cached as module Variables. + max_seq_len: Optional. Maximum sequence length for precomputed frequencies. Must + be provided together with ``embedding_size``. + """ + + def __init__( + self, + theta: float = 10000.0, + embedding_size: int | None = None, + max_seq_len: int | None = None, + ): + self.theta = theta + if sum([embedding_size is None, max_seq_len is None]) % 2 != 0: + raise ValueError('Either both `embedding_size` and `max_seq_len` or none of them must be provided.') + if embedding_size is not None and max_seq_len is not None: + positions = jnp.arange(float(max_seq_len)) + sin, cos = self._rotation_for(embedding_size, positions) + self.cached_sin = nnx.Variable(sin) + self.cached_cos = nnx.Variable(cos) + + @staticmethod + def _rotate_half(x: Array) -> Array: + d_2 = x.shape[-1] // 2 + return jnp.concatenate([-x[..., d_2:], x[..., :d_2]], axis=-1) + + def _rotation_for(self, embedding_size: int, input_positions: Array): + seq_len = input_positions.shape[-1] + if embedding_size <= 0 or embedding_size % 2 != 0: + raise ValueError('`embedding_size` must be positive and even.') + freqs = 1.0 / ( + self.theta ** (jnp.arange(0.0, embedding_size, 2) / embedding_size) + ) + freqs_outer = jnp.outer(input_positions, freqs) + return jnp.sin(freqs_outer), jnp.cos(freqs_outer) + + def __call__(self, x: Array, input_positions: Array | None = None) -> Array: + """Apply rotary positional encoding. + + Args: + x: Input array of shape ``(..., seq_length, embedding_size)``. + input_positions: Optional position array of shape ``(seq_len,)``. + If ``None`` and precomputed tables exist (i.e. ``embedding_size`` + and ``max_seq_len`` were given at init), the cached tables are + used. Otherwise positions default to ``jnp.arange(seq_len)``. + + Returns: + Array of same shape with RoPE applied. + """ + seq_len, embedding_size = x.shape[-2:] + if input_positions is None and hasattr(self, 'cached_sin'): + sin = self.cached_sin.value[:seq_len] + cos = self.cached_cos.value[:seq_len] + else: + if input_positions is None: + input_positions = jnp.arange(float(seq_len)) + elif input_positions.shape[-1] != seq_len: + raise ValueError('`input_positions` must have the same seq_len as `x`.') + sin, cos = self._rotation_for(embedding_size, input_positions) + return x * jnp.tile(cos, (1, 2)) + self._rotate_half(x) * jnp.tile(sin, (1, 2)) + + +def dot_product_attention_with_rope( + query: Array, + key: Array, + value: Array, + *, + rope: RoPE, + input_positions: Array | None = None, + **kwargs, +) -> Array: + """Dot-product attention with Rotary Position Embedding applied to query and key. + + This function has the same signature as :func:`dot_product_attention` + (plus a ``rope`` keyword argument) and can be used as the ``attention_fn`` + of :class:`MultiHeadAttention` via ``functools.partial``:: + + attention_fn=functools.partial(dot_product_attention_with_rope, rope=rope) + + Args: + query: queries of shape ``[batch..., q_length, num_heads, head_dim]``. + key: keys of shape ``[batch..., kv_length, num_heads, head_dim]``. + value: values of shape ``[batch..., kv_length, num_heads, v_dim]``. + rope: A :class:`RoPE` module instance. + **kwargs: Additional keyword arguments forwarded to + :func:`dot_product_attention`. + + Returns: + Output of shape ``[batch..., q_length, num_heads, v_dim]``. + """ + # query/key: [batch..., seq_length, num_heads, head_dim] + # RoPE.__call__ expects (..., seq_length, head_dim), so vmap over heads. + if input_positions is None: + apply = jax.vmap(rope, in_axes=-2, out_axes=-2) + else: + apply = jax.vmap(rope, in_axes=(-2, -1), out_axes=(-2, -1)) + query = apply(query, input_positions) + key = apply(key, input_positions) + return dot_product_attention(query, key, value, **kwargs) + + class MultiHeadAttention(Module): """Multi-head attention. @@ -588,6 +710,7 @@ def __call__( is_causal=False, out_sharding = None, qkv_sharding = None, + input_positions: Array | None = None ): """Applies multi-head dot product attention on the input data. @@ -624,6 +747,9 @@ def __call__( the output linear layer for the output arrays. qkv_sharding: Optional sharding specification to pass to the QKV linear layers for the output arrays. + input_positions: Optional position indices of shape + ``[batch_sizes..., length]`` forwarded to the ``attention_fn``. + Used by :func:`dot_product_attention_with_rope` for RoPE. Returns: output of shape `[batch_sizes..., length, features]`. @@ -736,6 +862,9 @@ def __call__( dropout_rng = None # apply attention + attn_kwargs = {} + if input_positions is not None: + attn_kwargs["input_positions"] = input_positions x = self.attention_fn( query, key, @@ -748,7 +877,8 @@ def __call__( dtype=self.dtype, precision=self.precision, module=self if sow_weights else None, - is_causal=is_causal + is_causal=is_causal, + **attn_kwargs ) # back to the original inputs dimensions out = self.out(x, out_sharding=out_sharding) diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index 167fbf8cf..82aa52344 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -383,7 +383,41 @@ def _run(m): np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3) -if __name__ == '__main__': - absltest.main() +class TestRoPE(absltest.TestCase): + + def test_relative_position_invariance(self): + """Dot product of RoPE-rotated vectors depends only on relative position. + + For any offset d, should equal + . + """ + head_dim = 64 + max_len = 128 + rope = nnx.RoPE() + + k1, k2 = jax.random.split(jax.random.key(7)) + q_vec = jax.random.normal(k1, (head_dim,)) + k_vec = jax.random.normal(k2, (head_dim,)) + + # Place q at position i and k at position j, then shift both by d. + i, j = 5, 12 + d = 37 + def rope_at(vec, pos): + # Build a dummy sequence long enough, apply RoPE, extract one position. + seq = jnp.zeros((pos + 1, head_dim)).at[pos].set(vec) + return rope(seq)[pos] + q_rot = rope_at(q_vec, i) + k_rot = rope_at(k_vec, j) + dot_original = jnp.dot(q_rot, k_rot) + + q_rot_shifted = rope_at(q_vec, i + d) + k_rot_shifted = rope_at(k_vec, j + d) + dot_shifted = jnp.dot(q_rot_shifted, k_rot_shifted) + + np.testing.assert_allclose(dot_original, dot_shifted, atol=1e-5) + + +if __name__ == '__main__': + absltest.main()