Skip to content

Add RoPE embeddings#5481

Open
samanklesaria wants to merge 1 commit into
google:mainfrom
samanklesaria:rope
Open

Add RoPE embeddings#5481
samanklesaria wants to merge 1 commit into
google:mainfrom
samanklesaria:rope

Conversation

@samanklesaria
Copy link
Copy Markdown
Collaborator

This PR adds support for RoPE. Specifically, a new function dot_product_attention_with_rope can be used as the attention_fn argument for nnx.MultiHeadAttention.

Comment thread flax/nnx/nn/attention.py
Comment thread flax/nnx/nn/attention.py Outdated
Comment on lines +363 to +364
self.cos_cached = nnx.Variable(jnp.cos(freqs_outer).astype(dtype))
self.sin_cached = nnx.Variable(jnp.sin(freqs_outer).astype(dtype))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In multiple JAX repositories I saw that sin and cos are constructed from segment_pos (absolute token positions):

The idea is that if the input sequence is not packed, but just padded, our implementation would mostly work as expected. In case of packed sequence where multiple sentences are inserted in the same sequence:

[<bos>, 1, 2, 3, <eos>, <bos>, 4, 5, <eos>, <pad>, <pad>]

then the input positions (segment_pos) would be:

[0, 1, 2, 3, 4, 0, 1, 2, 3, 0, 0]

so, RoPE may be computed differently.
On the other hand, current MHA.__call__ does not accept any positions arg, so we can't pass it to RoPE...

In PyTorch, basic implementation does something similar to your implementation, but cos and sin cached from the input x.

Copy link
Copy Markdown
Collaborator Author

@samanklesaria samanklesaria Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My implementation was based on the one in equinox, which doesn't pass explicit positions.

  • We could give MHA a new optional positions argument which we thread through to the attention_fn. But this could break users' custom attention_fn implementations if they weren't expecting the argument.
  • We could make a MHA subclass with a call method that accepts a positions argument. Say, PackedMHA? This is a little more ugly, but wouldn't be a breaking change.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do the following?

class MHA:

  def __call__(self, ..., input_positions: Array | None = None):
    ...
    attn_kwargs = {}
    if input_positions is not None:
       attn_kwargs["input_positions"] = input_positions
    x = self.attention_fn(
      query,
      key,
      value,
      mask=mask,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision,
      module=self if sow_weights else None,
      is_causal=is_causal,
      **attn_kwargs,
    )


def dot_product_attention_with_rope(..., rope, input_positions: Array | None = None, **kwargs)
  # handle properly input_positions is None
  # input_positions: (B, S)
  apply = jax.vmap(rope, in_axes=(-2, -1), out_axes=(-2, -1))
  query = apply(query, input_positions)
  key = apply(key, input_positions)
  ...

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that! So users that aren't already calling MHA with input positions won't have their custom attention_fn break!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we still keep the cached sin and cos vectors for use when input_positions=None? Might be slightly faster for the non-packed case, as we wouldn't need to rebuild them. But the interface could be nicer if we just rebuilt them every time, so that the RoPE constructor wouldn't need max_seq_len and embedding_size arguments (dynamically getting them from the input x). What do you think?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can cache them after the first call ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we'd check if the input_positions is the same as the cached one at each call. If so, we'd use the cache (populated on the first call) and otherwise, we'd generate it dynamically?

To work in tree mode, we'd need to create the Variables for the cache at initialization. But then we could write to these variables on the first __call__.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're doing cross attention, things get trickier. The sequence lengths for the keys might not be the same as the lengths for the values. Caching a single value wouldn't account for this. Caching everything would break if the packing of each batch might be different.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, actually, not sure if seen RoPE for cross-attention, but it makes sense that we may need to have q_positions and kv_positions or even k_positions, v_positions. But it can become rather cumbersome the API finally

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up at a compromise API. If the user wants caching, they can specify the max size during initialization, like I had before. Otherwise, we don't cache, and compute the rotation matrices on the fly.

This also means we don't have to use QDD for the cache, which might become deprecated if we end up switching to hijax variables and QDD is removed. One less thing to worry about down the road.

@samanklesaria samanklesaria force-pushed the rope branch 2 times, most recently from 66c99da to 319550d Compare June 3, 2026 18:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants