alternative: drop model-aware flops, use n^2 proxy#2726
Draft
samsja wants to merge 1 commit into
Draft
Conversation
Centralize per-bin workload math in a new prime_rl.trainer.cost_model module that returns sum(n*n for n in seqlens). After FFD packs bins to ~max_seq_len, the linear part of the forward is approximately constant across bins, so attention (Sum n^2) dominates the inter-bin variance. This drops the model-aware FLOPs estimator (calculate_packing_fwd_flops and friends), the balance_by_flops config flag, and all model.config / flops_config plumbing through the packer. If a future model class shows measured wallclock skew that the n^2 proxy mis-ranks, add a model-aware estimator inside cost_model.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Proposed simplification on top of #2723.
Idea
After FFD packs bins to ~
max_seq_len, the linear part of the forward (QKV proj, attn-out, FFN, LM head) is approximately constant across bins. Inter-bin work variance is dominated by attention, which isSum n^2per bin (per-sequence O(n^2) under packed / cu_seqlens attention).Replacing the model-aware FLOPs estimator with
bin_cost(seqlens) = sum(n*n for n in seqlens)ranks bins the same way for standard MHA/GQA setups, with none of the HF-config introspection.What changes vs #2723
prime_rl/trainer/cost_model.py(~15 lines): one functionbin_cost.calculate_packing_fwd_flops,_calculate_qkv_projection_flops,_calculate_attention_flops,_calculate_layer_flops,_is_mla,get_packing_flops_config.flops_configparameter onpacked_samples_into_micro_bs,_MicroBatchBin.workload,_MicroBatchBin._sample_workload,_MicroBatchBin.split_by_workload,_expand_bins_by_splitting,_distribute_group,prepare_batch.balance_by_flopsflag inDataLoaderConfig.model.configpassing into the packer /DataLoader/setup_packer.Algorithmic content kept intact from #2723:
sequence_lengthsfield onMicroBatch_MicroBatchBin+_materialize_binpacker refactor_expand_bins_by_splittingfor thefewer bins than DP workerscaseWhy
batch.pystops importing model knowledge — packer becomes a pure data-shape transform. If a model class is later shown to need a richer cost (measured wallclock regression on real workload, not theoretical), it goes in one file (cost_model.py), behind one function.The randomized packing-invariants test (
test_randomized_packing_invariants, 80 cases) and the_pairs_long_and_short_workloads/_splits_heaviest_binscheduling tests all pass.🤖 Generated with Claude Code