
alternative: drop model-aware flops, use n^2 proxy #2726

Draft
samsja wants to merge 1 commit into seq_packing from sami/simplify-packing-cost-model

@samsja samsja commented Jun 7, 2026

Member

Proposed simplification on top of #2723.

Idea

After FFD packs bins to ~max_seq_len, the linear part of the forward pass (QKV projection, attention output projection, FFN, LM head) is approximately constant across bins, because every bin holds roughly the same number of tokens. Inter-bin work variance is therefore dominated by attention, which costs sum(n^2) per bin, since each sequence costs O(n^2) under packed (cu_seqlens) attention.

Replacing the model-aware FLOPs estimator with bin_cost(seqlens) = sum(n*n for n in seqlens) ranks bins the same way for standard MHA/GQA setups, with none of the HF-config introspection.
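
A minimal sketch of what such a proxy could look like, based only on the formula quoted above. The type hints, docstring, and example bins are illustrative assumptions, not the contents of the actual cost_model.py:

```python
from collections.abc import Sequence


def bin_cost(seqlens: Sequence[int]) -> int:
    """Attention-dominated proxy for the work in one packed bin.

    Each sequence of length n contributes n^2, because packed attention
    only attends within a sequence, never across sequence boundaries.
    """
    return sum(n * n for n in seqlens)


# Two bins holding the same 8192 tokens, packed differently (made-up lengths).
one_long = [8192]
many_short = [1024] * 8

print(bin_cost(one_long))    # 67108864
print(bin_cost(many_short))  # 8388608
```

Both bins carry the same token count, so their linear cost is roughly equal, yet the proxy ranks the single long sequence as 8x heavier. That gap is the inter-bin variance the balancer needs to see.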

What changes vs #2723

  • New prime_rl/trainer/cost_model.py (~15 lines): one function bin_cost.
  • Deletes calculate_packing_fwd_flops, _calculate_qkv_projection_flops, _calculate_attention_flops, _calculate_layer_flops, _is_mla, get_packing_flops_config.
  • Deletes flops_config parameter on packed_samples_into_micro_bs, _MicroBatchBin.workload, _MicroBatchBin._sample_workload, _MicroBatchBin.split_by_workload, _expand_bins_by_splitting, _distribute_group, prepare_batch.
  • Deletes balance_by_flops flag in DataLoaderConfig.
  • Deletes model.config passing into the packer / DataLoader / setup_packer.
  • Net diff vs #2723 (improve sequence packing): 44 insertions, 233 deletions.

Algorithmic content kept intact from #2723:

  • sequence_lengths field on MicroBatch
  • _MicroBatchBin + _materialize_bin packer refactor
  • _expand_bins_by_splitting for the case where there are fewer bins than DP workers
  • KK + swap-improvement balanced partitioning
  • Text/multimodal balanced separately
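
To illustrate how a scalar per-bin cost feeds a balanced partition across DP workers, here is a rough sketch. It deliberately uses a simpler technique than the PR: greedy longest-processing-time (LPT) assignment rather than KK with swap improvement. The function name assign_bins_greedy and the example bins are hypothetical.

```python
import heapq
from collections.abc import Sequence


def bin_cost(seqlens: Sequence[int]) -> int:
    # Same n^2 proxy as described in the PR text.
    return sum(n * n for n in seqlens)


def assign_bins_greedy(bins: Sequence[Sequence[int]], num_workers: int) -> list[list[int]]:
    """Hypothetical helper: greedy LPT assignment of bins to workers.

    NOT the KK + swap-improvement partitioner kept from #2723; it only shows
    that the partitioner needs nothing beyond a per-bin scalar cost.
    Returns, for each worker, the indices of the bins assigned to it.
    """
    # Min-heap of (current load, worker id); the lightest worker is popped first.
    heap = [(0, w) for w in range(num_workers)]
    assignment: list[list[int]] = [[] for _ in range(num_workers)]
    # Place the heaviest bins first so the light ones can even out the loads.
    order = sorted(range(len(bins)), key=lambda i: bin_cost(bins[i]), reverse=True)
    for i in order:
        load, w = heapq.heappop(heap)
        assignment[w].append(i)
        heapq.heappush(heap, (load + bin_cost(bins[i]), w))
    return assignment


# Four made-up bins of 4096 tokens each, balanced across two workers.
bins = [[4096], [2048, 2048], [1024] * 4, [3072, 1024]]
assignment = assign_bins_greedy(bins, num_workers=2)
loads = [sum(bin_cost(bins[i]) for i in idxs) for idxs in assignment]
print(assignment, loads)
```

Since the partitioner consumes only bin_cost, replacing the proxy with a richer estimator later would not touch the balancing code.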

Why

batch.py stops importing model knowledge, so the packer becomes a pure data-shape transform. If a model class is later shown to need a richer cost (a measured wallclock regression on a real workload, not a theoretical one), it goes in one file (cost_model.py), behind one function.

The randomized packing-invariants test (test_randomized_packing_invariants, 80 cases) and the _pairs_long_and_short_workloads / _splits_heaviest_bin scheduling tests all pass.

🤖 Generated with Claude Code

Centralize per-bin workload math in a new prime_rl.trainer.cost_model
module that returns sum(n*n for n in seqlens). After FFD packs bins to
~max_seq_len, the linear part of the forward is approximately constant
across bins, so attention (Sum n^2) dominates the inter-bin variance.

This drops the model-aware FLOPs estimator (calculate_packing_fwd_flops
and friends), the balance_by_flops config flag, and all model.config /
flops_config plumbing through the packer.

If a future model class shows measured wallclock skew that the n^2
proxy mis-ranks, add a model-aware estimator inside cost_model.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>