recurrent : support equal splits for recurrent-state rollback #25004
Open
arielgindi wants to merge 1 commit into
Resolves the [TAG_RECURRENT_ROLLBACK_SPLITS] TODO ("recurrent state rollback
does not support equal splits") in llama_memory_recurrent/hybrid::init_batch.
Speculative decoding on hybrid recurrent models (e.g. Qwen3.6 Gated-DeltaNet)
keeps per-sequence recurrent-state snapshots (n_rs_seq) so a rejected draft can
be rolled back. Until now those sequences were split one-per-ubatch (split_seq),
so concurrent requests ran the recurrent graph serially. This enables equal
splits (split_equal) for the rollback case, so concurrent sequences share a
single graph pass (~269 -> ~425 tok/s at --parallel 5 on Qwen3.6-35B-A3B with
draft-mtp; single sequence unchanged).
Equal splits can leave only n_written = min(n_seq_tokens, K) fresh snapshot
planes per ubatch, so to keep rollback correct:
- A rolling-history carry (build_rs_rollback_carry + s_copy_hist) carries the
deeper snapshot planes forward across ubatches. The writers assemble all K
planes as a single op result and write the cache once. Carry slot 0 reuses the
current input state already gathered by build_rs() (rather than re-reading
plane 0 from the cache, which build_rs() may have overwritten via its
extra-cell write), and deeper slots are gathered from the cache.
- A per-sequence valid depth (rs_valid_depth) bounds how deep a rollback can
reach after entering on a non-zero plane; seq_rm refuses (and accumulates onto)
a rollback deeper than the kept history. s_copy() consumes the rollback index
for all owners of a shared cell. seq_cp() mirrors/clears rollback metadata, and
partial seq_rm on a shared cell is refused. Extra (bracketed) cells materialize
their current state on relocation and reset their rollback index.
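The two points above can be sketched with a toy model. This is an illustration under stated assumptions, not the PR's code: `toy_rs_history`, `request_rollback`, `entry_state` and `decode` are hypothetical names, states are plain integers, and the valid-depth update rule is one reading of the description (plane r holds the state r tokens before the tail, and a decode that enters on plane p loses p planes of reachable history).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy per-sequence history (hypothetical, not llama.cpp code).
// planes[r] is meant to hold the state r tokens before the tail.
struct toy_rs_history {
    std::vector<int> planes;
    int valid_depth = 0; // deepest plane index that still holds real history
    int pending     = 0; // accumulated rollback depth not yet consumed by a decode

    toy_rs_history(int K, int initial_state) : planes(K, initial_state) {
        assert(K >= 1);
    }

    // seq_rm-like request: accumulate, but refuse to reach past the kept history.
    bool request_rollback(int depth) {
        if (depth < 0 || pending + depth > valid_depth) {
            return false;
        }
        pending += depth;
        return true;
    }

    // State the next decode would start from.
    int entry_state() const { return planes[pending]; }

    // Process one ubatch; states[i] is the state after its i-th token (oldest first).
    void decode(const std::vector<int> & states) {
        const int K         = (int) planes.size();
        const int n         = (int) states.size();
        const int n_written = std::min(n, K);

        std::vector<int> next(K);
        for (int r = 0; r < K; ++r) {
            if (r < n_written) {
                next[r] = states[n - 1 - r]; // fresh plane from this ubatch
            } else {
                // carried plane: shifted forward by n, starting at the entry plane;
                // indices past K - 1 are clamped and excluded by valid_depth below
                next[r] = planes[std::min(pending + r - n, K - 1)];
            }
        }
        planes      = next; // all K planes assembled first, then written once
        valid_depth = std::min(K - 1, valid_depth - pending + n);
        pending     = 0;
    }
};
```

With K = 3, decoding four tokens and then one more leaves the planes at {5, 4, 3}. A rollback of depth 1 is accepted, a further rollback of depth 2 is refused, and the next decode starts from state 4 while planes 1 and 2 keep the history behind it.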
Also: 64-bit guard on the int32 snapshot row index, and shape asserts on the
assembled snapshot tensor.
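As a rough illustration of what a 64-bit guard on an int32 row index can look like, the sketch below checks the range in 64-bit arithmetic before narrowing. The row layout (plane * n_cells + cell) and the name `toy_snapshot_row` are assumptions made for the example, not taken from the PR. It needs C++17 for std::optional.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical layout: row = plane * n_cells + cell.
// Returns the row as int32_t, or std::nullopt if it would not fit.
std::optional<int32_t> toy_snapshot_row(int64_t plane, int64_t n_cells, int64_t cell) {
    assert(n_cells > 0);
    const int64_t max_row = std::numeric_limits<int32_t>::max();
    if (plane < 0 || cell < 0 || cell >= n_cells || cell > max_row) {
        return std::nullopt;
    }
    // Rearranged so the check itself cannot overflow:
    // plane * n_cells + cell <= max_row  <=>  plane <= (max_row - cell) / n_cells
    if (plane > (max_row - cell) / n_cells) {
        return std::nullopt;
    }
    return static_cast<int32_t>(plane * n_cells + cell);
}
```

The comparison is done before the multiplication, so an out-of-range index is rejected instead of being silently truncated on the cast.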
Extends test-recurrent-state-rollback with multi-sequence and active-gap
rollback cases.
5e05f71 to
7ff2fd8
Hi @arielgindi, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Summary
This resolves the [TAG_RECURRENT_ROLLBACK_SPLITS] TODO in llama_memory_recurrent::init_batch: recurrent-state rollback (used by speculative decoding on hybrid models) previously forced split_seq, processing one sequence per ubatch. As a result the recurrent graph ran once per sequence and concurrent speculative decoding did not scale; under load it could be slower than non-speculative decoding.
This PR allows the rollback path to use split_equal, so concurrent sequences share a single graph pass while keeping the per-sequence rollback snapshots correct.
Scope is intentionally narrow: it touches only the recurrent/hybrid memory and the delta-net graph builder. There are no public API additions, no server-logic changes, and dense / non-speculative paths are untouched. It is CPU/graph-side only, with no CUDA changes.
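To make the difference between the two split strategies concrete, here is a minimal toy model. It is only a sketch: `toy_split_seq` and `toy_split_equal` are hypothetical helpers that mimic the grouping behaviour described above, not the llama.cpp implementations, and a ubatch is reduced to a per-sequence token count.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model, not llama.cpp code: ubatch[s] = number of tokens taken from sequence s.
using toy_ubatch = std::vector<int>;

// One sequence per ubatch, in the spirit of split_seq.
std::vector<toy_ubatch> toy_split_seq(const std::vector<int> & n_tokens, int n_ubatch) {
    assert(n_ubatch >= 1);
    std::vector<toy_ubatch> out;
    for (size_t s = 0; s < n_tokens.size(); ++s) {
        for (int left = n_tokens[s]; left > 0; ) {
            const int take = std::min(left, n_ubatch);
            toy_ubatch ub(n_tokens.size(), 0);
            ub[s] = take;
            out.push_back(ub);
            left -= take;
        }
    }
    return out;
}

// The same number of tokens from every sequence that still has tokens,
// in the spirit of split_equal.
std::vector<toy_ubatch> toy_split_equal(std::vector<int> left, int n_ubatch) {
    assert(n_ubatch >= 1);
    std::vector<toy_ubatch> out;
    for (;;) {
        // collect at most n_ubatch sequences that still have tokens
        std::vector<size_t> active;
        for (size_t s = 0; s < left.size() && (int) active.size() < n_ubatch; ++s) {
            if (left[s] > 0) {
                active.push_back(s);
            }
        }
        if (active.empty()) {
            break;
        }
        int per = n_ubatch / (int) active.size(); // at least 1
        for (size_t s : active) {
            per = std::min(per, left[s]);
        }
        toy_ubatch ub(left.size(), 0);
        for (size_t s : active) {
            ub[s]    = per;
            left[s] -= per;
        }
        out.push_back(ub);
    }
    return out;
}
```

With five sequences of two tokens each and n_ubatch = 512, the toy split_seq yields five ubatches (five graph passes), while the toy split_equal yields a single ubatch containing all five sequences.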
Motivation
Hybrid recurrent models (Gated-DeltaNet families, e.g. Qwen3-Next / Qwen3.6) keep a rolling recurrent state rather than per-token KV. To support speculative decoding, the memory keeps n_rs_seq rollback snapshot planes per sequence so a rejected draft can restore the prior state. That machinery was only correct when each sequence owned a whole ubatch, hence the split_seq restriction and the standing TODO.
What changed
- init_batch (recurrent + hybrid): use split_equal(n_ubatch, true) for the rollback case (n_rs_seq > 0) instead of split_seq, removing the TODO. The sequential equal split co-batches consecutive active seq_ids; coupled-sequence batches are unsupported under rollback and fall back to FAILED_PREPARE via the empty split.
- Rolling-history carry (build_rs_rollback_carry + new s_copy_hist graph input): with equal splits a step only regenerates the most-recent planes, so the deeper snapshot planes are shifted forward each ubatch (a shift-register), keeping the invariant that plane r holds the state r tokens before the new tail. The carried planes and the freshly-written planes are assembled and committed to the cache in a single ggml_cpy, so the carry's read of the old cache is ordered before the write via the graph's data dependency (not node insertion order). Carry slot 0 reuses the already-gathered current state rather than re-reading plane 0, which a relocated extra cell's write may have touched.
- Rollback-metadata bookkeeping in the recurrent memory (seq_rm is a public memory API and may be called with arbitrary rollback depths):
  - rs_valid_depth tracks the deepest still-valid snapshot plane; seq_rm accumulates and bounds the requested rollback against it (a deeper rollback is the caller's responsibility to restore from a checkpoint). It is maintained in lockstep with rs_idx across prepare() (saved/restored around the placement dry-run), clear(), seq_cp(), and seq_keep().
  - s_copy consumes the rollback index for all owners of a shared cell and aborts if they disagree.
  - [...] rollback depth, since build_rs copies only plane 0 for them.
Performance
Setup
- UD-IQ3_XXS quant (~14 GB). Self-speculative MTP draft head built into the GGUF (n_rs_seq == draft n_max), --spec-type draft-mtp --spec-draft-n-max 1.
- llama-server -c 8192 --parallel N. Aggregate wall-clock decode throughput (total generated tokens / wall time, all N slots decoding in lockstep from a shared cached prompt [...]
Result (aggregate decode throughput at --parallel 5):
Full GPU offload (MoE on GPU):
[table: columns --parallel 1 and --parallel 5]
MoE partly on CPU (--n-cpu-moe 16), the config required to run a 35B MoE on 16 GB:
[table: columns --parallel 1 and --parallel 5]
Under master, --parallel 5 ≈ --parallel 1 in both configs: split_seq serializes the concurrent sequences through one recurrent forward each (so each of N streams runs at ~1/N speed and aggregate throughput is flat). This PR batches the N sequences into one ubatch, restoring concurrency scaling: 1.72× on full GPU, 1.56× with CPU-offloaded MoE, at 5 concurrent sequences. Single-sequence throughput is within noise of master.
Testing
Extended tests/test-recurrent-state-rollback (all sections pass):
- [...] so the gather/relocation path runs in one ubatch; the rolled-back result matches a clean reference;
- [...] n_ubatch < prompt length); the rolled-back replay matches a single-ubatch reference. This is what exercises the rolling carry across ubatch boundaries (the PR's core);
- [...] seq_rm survives a decode that follows a pending rollback, guarding the prepare()/apply() depth accounting.
Also exercised end-to-end under --parallel 5 with MTP speculative decoding (including reasoning/"thinking" generation and prompt-cache reuse across turns): no asserts/aborts; output matches the serial path modulo expected batch-size floating-point nondeterminism.
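The multi-ubatch check can be pictured with a small stand-alone analogue. This is a sketch, not the actual test: `toy_step` is a made-up integer recurrence standing in for the recurrent layer, and `toy_run` is a hypothetical helper that maintains K snapshot planes while feeding a prompt in chunks of at most n_ubatch tokens.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Made-up recurrence standing in for the recurrent layer.
static uint64_t toy_step(uint64_t state, int token) {
    return state * 31u + (uint64_t) token;
}

// Feed `tokens` in ubatches of at most n_ubatch tokens and return the K planes,
// where planes[r] is the state r tokens before the tail.
std::vector<uint64_t> toy_run(const std::vector<int> & tokens, size_t n_ubatch, size_t K) {
    assert(n_ubatch >= 1 && K >= 1);
    std::vector<uint64_t> planes(K, 0); // every plane starts at the initial state 0
    for (size_t i0 = 0; i0 < tokens.size(); i0 += n_ubatch) {
        const size_t n = std::min(n_ubatch, tokens.size() - i0);

        // states[i] = state after the i-th token of this ubatch
        std::vector<uint64_t> states(n);
        uint64_t s = planes[0];
        for (size_t i = 0; i < n; ++i) {
            s = toy_step(s, tokens[i0 + i]);
            states[i] = s;
        }

        // fresh planes for r < n, carried (shifted) planes for r >= n
        std::vector<uint64_t> next(K);
        for (size_t r = 0; r < K; ++r) {
            next[r] = r < n ? states[n - 1 - r] : planes[r - n];
        }
        planes = next;
    }
    return planes;
}
```

In this toy, splitting the same prompt into ubatches of 3 or 1 tokens gives the same planes as a single ubatch, and plane 1 equals the tail state of a run over the prompt with its last token dropped, which is the property a rollback of depth 1 relies on.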
Notes for reviewers
The metadata-correctness changes (rs_valid_depth lifetime across prepare/clear/seq_keep, shared-cell index consumption, extra-cell materialization) are needed because the equal split now puts multiple rolled-back sequences in one ubatch. Happy to split them into a separate PR if you'd prefer the split_equal enablement to land on its own.