recurrent : support equal splits for recurrent-state rollback #25004

Open: arielgindi wants to merge 1 commit into ggml-org:master from arielgindi:mtp-parallel-splits

@arielgindi


Summary

This resolves the [TAG_RECURRENT_ROLLBACK_SPLITS] TODO in llama_memory_recurrent::init_batch:
recurrent-state rollback (used by speculative decoding on hybrid models) previously forced
split_seq, processing one sequence per ubatch. As a result the recurrent graph ran once per
sequence and concurrent speculative decoding did not scale — under load it could be slower than
non-speculative decoding.

This PR allows the rollback path to use split_equal, so concurrent sequences share a single
graph pass while keeping the per-sequence rollback snapshots correct.
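
To make the difference between the two split strategies concrete, here is a minimal toy sketch that only counts ubatches (one graph pass each). It is not the llama.cpp batch splitter: `count_ubatches_per_seq` and `count_ubatches_equal` are hypothetical helpers, and the real `split_equal` has more rules (for example around coupled sequences) than this model captures.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model: pending[i] is the number of tokens sequence i wants to decode.
// These helpers are illustrative only and are not part of llama.cpp.

// split_seq-like: every ubatch holds tokens of exactly one sequence.
static int count_ubatches_per_seq(const std::vector<int> & pending, int n_ubatch) {
    assert(n_ubatch > 0);
    int n = 0;
    for (int t : pending) {
        n += (t + n_ubatch - 1) / n_ubatch; // ceil(t / n_ubatch)
    }
    return n;
}

// split_equal-like: every ubatch takes the same number of tokens from each
// sequence that still has tokens left, with at most n_ubatch tokens in total.
static int count_ubatches_equal(std::vector<int> pending, int n_ubatch) {
    assert(n_ubatch > 0);
    int n = 0;
    for (;;) {
        // pick the sequences that still have work, at most n_ubatch of them
        std::vector<size_t> active;
        for (size_t i = 0; i < pending.size() && (int) active.size() < n_ubatch; ++i) {
            if (pending[i] > 0) {
                active.push_back(i);
            }
        }
        if (active.empty()) {
            break;
        }
        // equal share per sequence, limited by the shortest active sequence
        int per_seq = n_ubatch / (int) active.size();
        for (size_t i : active) {
            per_seq = std::min(per_seq, pending[i]);
        }
        for (size_t i : active) {
            pending[i] -= per_seq;
        }
        n++;
    }
    return n;
}
```

With five sequences that each decode two tokens and `n_ubatch = 512`, the per-sequence split needs five passes while the equal split needs one, which is the scaling effect this PR is after.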

Scope is intentionally narrow: it touches only the recurrent/hybrid memory and the delta-net
graph builder. There are no public API additions, no server-logic changes, and dense /
non-speculative paths are untouched. It is CPU/graph-side only — no CUDA changes.

Motivation

Hybrid recurrent models (Gated-DeltaNet families, e.g. Qwen3-Next / Qwen3.6) keep a rolling
recurrent state rather than per-token KV. To support speculative decoding, the memory keeps
n_rs_seq rollback snapshot planes per sequence so a rejected draft can restore the prior state.
That machinery was only correct when each sequence owned a whole ubatch, hence the split_seq
restriction and the standing TODO.
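
As a rough illustration of what the snapshot planes are for, the sketch below models one sequence with a scalar stand-in for the recurrent state. `ToySeqState`, its members, and the recurrence itself are assumptions made for the example; the real state is a tensor and the real plane count comes from n_rs_seq.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy sketch for a single sequence (hypothetical type, not a llama.cpp struct).
// planes[0] is the current state; planes[r] is the state r tokens earlier.
struct ToySeqState {
    std::vector<long> planes;

    explicit ToySeqState(size_t k) : planes(k, 0) { assert(k > 0); }

    // Consume one token: push the history back by one plane, then update plane 0.
    void step(long token) {
        for (size_t r = planes.size() - 1; r > 0; --r) {
            planes[r] = planes[r - 1];
        }
        planes[0] = planes[0] * 31 + token; // stand-in for the real recurrence
    }

    // Drop the last n tokens by making plane n the current state.
    // Returns false when the kept history is not deep enough.
    bool rollback(size_t n) {
        if (n >= planes.size()) {
            return false;
        }
        for (size_t r = 0; r + n < planes.size(); ++r) {
            planes[r] = planes[r + n];
        }
        // planes past the shifted range are stale in this toy; the PR tracks
        // that kind of limit with rs_valid_depth
        return true;
    }
};
```

Usage: after `step(1); step(2)` the accepted state sits in `planes[0]`; a drafted `step(3)` that is then rejected can be undone with `rollback(1)`, which brings `planes[0]` back to the accepted value without recomputing anything.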

What changed

  • init_batch (recurrent + hybrid): use split_equal(n_ubatch, true) for the rollback case
    (n_rs_seq > 0) instead of split_seq, removing the TODO. The sequential equal split co-batches
    consecutive active seq_ids; coupled-sequence batches are unsupported under rollback and fall back
    to FAILED_PREPARE via the empty split.

  • Rolling-history carry (build_rs_rollback_carry + new s_copy_hist graph input): with equal
    splits a step only regenerates the most-recent planes, so the deeper snapshot planes are shifted
    forward each ubatch (a shift-register), keeping the invariant that plane r holds the state r
    tokens before the new tail. The carried planes and the freshly-written planes are assembled and
    committed to the cache in a single ggml_cpy, so the carry's read of the old cache is ordered
    before the write via the graph's data dependency (not node insertion order). Carry slot 0 reuses
    the already-gathered current state rather than re-reading plane 0, which a relocated extra cell's
    write may have touched.

  • Rollback-metadata bookkeeping in the recurrent memory (seq_rm is a public memory API and may
    be called with arbitrary rollback depths):

    • rs_valid_depth tracks the deepest still-valid snapshot plane; seq_rm accumulates and bounds
      the requested rollback against it (a deeper rollback is the caller's responsibility to restore
      from a checkpoint). It is maintained in lockstep with rs_idx across prepare() (saved/restored
      around the placement dry-run), clear(), seq_cp(), and seq_keep().
    • s_copy consumes the rollback index for all owners of a shared cell and aborts if they disagree.
    • relocated "extra" cells (live seqs not active in the ubatch) materialize plane 0 and reset their
      rollback depth, since build_rs copies only plane 0 for them.
    • 64-bit guard on the snapshot row index (throws on overflow); shape asserts on the assembled tensor.
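
The shift-register invariant from the rolling-history carry can be sketched for a single sequence as follows. This is a toy with scalar states, assuming K planes and a ubatch of n tokens; `assemble_planes` is a hypothetical name, and the real code works on ggml tensors and commits the result with one ggml_cpy.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy sketch of the carry (illustrative only, not the build_rs_rollback_carry code).
// old_planes[r]: state r tokens before the previous tail (K planes).
// fresh[i]:      state after the i-th token of this ubatch (n tokens).
// The result is built only from the OLD planes and the fresh states, so no read
// depends on a plane that this ubatch has already overwritten.
static std::vector<long> assemble_planes(const std::vector<long> & old_planes,
                                         const std::vector<long> & fresh) {
    const size_t K = old_planes.size();
    const size_t n = fresh.size();
    assert(K > 0 && n > 0);

    const size_t n_written = std::min(n, K);
    std::vector<long> out(K);

    // freshly written planes: plane r is the state r tokens before the new tail
    for (size_t r = 0; r < n_written; ++r) {
        out[r] = fresh[n - 1 - r];
    }
    // carried planes: r tokens before the new tail is (r - n) tokens before the old tail
    for (size_t r = n_written; r < K; ++r) {
        out[r] = old_planes[r - n];
    }
    return out;
}
```

With old planes `{30, 20, 10}` and one new state `{40}`, the result is `{40, 30, 20}`: the old plane 0 (the "carry slot 0" input) moves to plane 1, and the deepest plane drops off the end. Building the whole result before writing it back mirrors the single-write ordering argument above.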

Performance

Setup

  • Model: Qwen3.6-35B-A3B (Gated-DeltaNet + attention + MoE, ~3B active params), UD-IQ3_XXS
    quant (~14 GB). Self-speculative MTP draft head built into the GGUF (n_rs_seq == draft n_max),
    --spec-type draft-mtp --spec-draft-n-max 1.
  • Hardware: NVIDIA GeForce RTX 5080 (16 GB), CUDA 13.1, release build.
  • Workload: llama-server -c 8192 --parallel N. Aggregate wall-clock decode throughput
    (total generated tokens / wall time, all N slots decoding in lockstep from a shared cached
    prompt + 256-token generation), so the metric is unaffected by per-request prefill or timing skew.

Result: aggregate decode throughput at --parallel 1 and --parallel 5.

Full GPU offload (MoE on GPU):

| concurrency | master | this PR |
|---|---|---|
| --parallel 1 | 250 tok/s | 230 tok/s |
| --parallel 5 | 249 tok/s | 429 tok/s (1.72× vs master) |

MoE partly on CPU (--n-cpu-moe 16) — the config required to run a 35B MoE on 16 GB:

| concurrency | master | this PR |
|---|---|---|
| --parallel 1 | 136 tok/s | 141 tok/s |
| --parallel 5 | 144 tok/s | 225 tok/s (1.56× vs master) |

Under master, --parallel 5 ≈ --parallel 1 in both configs: split_seq gives each concurrent
sequence its own recurrent forward pass, so each of N streams runs at ~1/N speed and aggregate
throughput is flat. This PR batches the N sequences into one ubatch, restoring concurrency
scaling: 1.72× on full GPU, 1.56× with CPU-offloaded MoE, at 5 concurrent sequences.
Single-sequence throughput stays close to master (230 vs 250 tok/s on full GPU, 141 vs 136 tok/s
with CPU-offloaded MoE).
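
For reference, the quoted ratios follow from the --parallel 5 rows of the two tables. The per-stream figures assume the five streams share the aggregate equally, which matches the lockstep workload described in the setup:

```latex
\begin{aligned}
\text{full GPU:}\quad & \frac{429}{249} \approx 1.72, &
  \frac{249}{5} = 49.8 \;\to\; \frac{429}{5} = 85.8 \ \text{tok/s per stream} \\
\text{CPU-offloaded MoE:}\quad & \frac{225}{144} \approx 1.56, &
  \frac{144}{5} = 28.8 \;\to\; \frac{225}{5} = 45.0 \ \text{tok/s per stream}
\end{aligned}
```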

Testing

Extended tests/test-recurrent-state-rollback (all sections pass):

  • checkpoint restore after a forced rollback (existing);
  • multi-sequence parallel rollback matches a single-sequence reference;
  • active-gap alias: an idle sequence is bracketed inside a consecutive active pair's cell range
    so the gather/relocation path runs in one ubatch; the rolled-back result matches a clean reference;
  • cross-ubatch carry: a prompt split across multiple ubatches (n_ubatch < prompt length); the
    rolled-back replay matches a single-ubatch reference — this is what exercises the rolling carry
    across ubatch boundaries (the PR's core);
  • rollback-depth bookkeeping: asserts the accept/refuse boundary of seq_rm survives a decode
    that follows a pending rollback, guarding the prepare()/apply() depth accounting.

# unit test (any hybrid recurrent gguf with rollback support; self-skips otherwise)
./bin/test-recurrent-state-rollback --model <model.gguf> -ngl 0

# throughput
./bin/llama-server -m <model.gguf> -ngl 99 --parallel 5 -c 8192 \
    --spec-type draft-mtp --spec-draft-n-max 1
# then N concurrent /v1/chat/completions; aggregate = total generated tokens / wall time

Also exercised end-to-end under --parallel 5 with MTP speculative decoding (including reasoning/
"thinking" generation and prompt-cache reuse across turns): no asserts/aborts; output matches the
serial path modulo expected batch-size floating-point nondeterminism.
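
The aggregate figure used in the throughput runs (total generated tokens / wall time) can be computed from per-request records along these lines. This is a sketch under assumptions: `ToyRequest` and `aggregate_tok_per_s` are hypothetical names for a benchmark client, not part of llama-server, and wall time is taken as the span from the earliest start to the latest finish.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-request record collected by a benchmark client.
struct ToyRequest {
    double t_start_s;   // wall-clock time the request was sent
    double t_end_s;     // wall-clock time the last token arrived
    int    n_generated; // tokens generated for this request
};

// Aggregate throughput = total generated tokens / wall time spanned by all requests.
static double aggregate_tok_per_s(const std::vector<ToyRequest> & reqs) {
    assert(!reqs.empty());
    double t0 = reqs[0].t_start_s;
    double t1 = reqs[0].t_end_s;
    long total = 0;
    for (const ToyRequest & r : reqs) {
        t0 = std::min(t0, r.t_start_s);
        t1 = std::max(t1, r.t_end_s);
        total += r.n_generated;
    }
    assert(t1 > t0);
    return total / (t1 - t0);
}
```

Dividing by the shared wall-clock span, rather than averaging per-request rates, is what keeps the number insensitive to small timing skew between slots.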

Notes for reviewers

The metadata-correctness changes (rs_valid_depth lifetime across prepare/clear/seq_keep,
shared-cell index consumption, extra-cell materialization) are needed because the equal split now
puts multiple rolled-back sequences in one ubatch. Happy to split them into a separate PR if you'd
prefer the split_equal enablement to land on its own.

github-actions bot added the model (Model specific) and testing (Everything test related) labels on Jun 25, 2026.

Resolves the [TAG_RECURRENT_ROLLBACK_SPLITS] TODO ("recurrent state rollback
does not support equal splits") in llama_memory_recurrent/hybrid::init_batch.

Speculative decoding on hybrid recurrent models (e.g. Qwen3.6 Gated-DeltaNet)
keeps per-sequence recurrent-state snapshots (n_rs_seq) so a rejected draft can
be rolled back. Until now those sequences were split one-per-ubatch (split_seq),
so concurrent requests ran the recurrent graph serially. This enables equal
splits (split_equal) for the rollback case, so concurrent sequences share a
single graph pass (~269 -> ~425 tok/s at --parallel 5 on Qwen3.6-35B-A3B with
draft-mtp; single sequence unchanged).

Equal splits can leave only n_written = min(n_seq_tokens, K) fresh snapshot
planes per ubatch, so to keep rollback correct:

- A rolling-history carry (build_rs_rollback_carry + s_copy_hist) carries the
  deeper snapshot planes forward across ubatches. The writers assemble all K
  planes as a single op result and write the cache once. Carry slot 0 reuses the
  current input state already gathered by build_rs() (rather than re-reading
  plane 0 from the cache, which build_rs() may have overwritten via its
  extra-cell write), and deeper slots are gathered from the cache.

- A per-sequence valid depth (rs_valid_depth) bounds how deep a rollback can
  reach after entering on a non-zero plane; seq_rm refuses (and accumulates onto)
  a rollback deeper than the kept history. s_copy() consumes the rollback index
  for all owners of a shared cell. seq_cp() mirrors/clears rollback metadata, and
  partial seq_rm on a shared cell is refused. Extra (bracketed) cells materialize
  their current state on relocation and reset their rollback index.

Also: 64-bit guard on the int32 snapshot row index, and shape asserts on the
assembled snapshot tensor.

Extends test-recurrent-state-rollback with multi-sequence and active-gap
rollback cases.

arielgindi force-pushed the mtp-parallel-splits branch from 5e05f71 to 7ff2fd8 on June 25, 2026, 12:41.

ggml-gh-bot (bot) commented on Jun 25, 2026:

Hi @arielgindi, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • AI-generated content: This project does not accept PRs, descriptions or commit messages that are fully or predominantly AI-generated. If you have used AI to assist you in writing code, please make sure to disclose that explicitly.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

arielgindi closed this on Jun 25, 2026.
arielgindi reopened this on Jun 25, 2026.