diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index b336d855ea..00f4e07deb 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -450,9 +450,6 @@ class DataLoaderConfig(BaseConfig): fake: FakeDataLoaderConfig | None = None """Use a fake data loader sampling random micro-batches (for debugging).""" - balance_by_flops: bool = True - """Balance packed micro-batches across data-parallel ranks using model-estimated FLOPs.""" - class BaseWeightBroadcastConfig(BaseConfig): pass diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index a5d7c0c263..8aa90160d0 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -1,8 +1,8 @@ import copy import heapq from dataclasses import dataclass, field -from typing import Any +from prime_rl.trainer.cost_model import bin_cost from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample ROUTED_EXPERTS_DTYPE_ITEMSIZE = { @@ -12,122 +12,6 @@ } -def get_packing_flops_config(model_config: Any) -> Any: - """Return the text config used for model-aware packing FLOP estimates.""" - return getattr(model_config, "text_config", model_config) - - -def _is_mla(config: Any) -> bool: - return bool(getattr(config, "multi_latent_attention", False) or hasattr(config, "q_lora_rank")) - - -def _calculate_qkv_projection_flops(config: Any, seqlen: int) -> float: - hidden_size = config.hidden_size - num_attention_heads = config.num_attention_heads - kv_channels = getattr(config, "kv_channels", getattr(config, "head_dim", hidden_size // num_attention_heads)) - is_mla = _is_mla(config) - if is_mla and getattr(config, "q_lora_rank", None) is not None: - q_flops = ( - 2 - * seqlen - * config.q_lora_rank - * ( - hidden_size - + num_attention_heads * (getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0)) - ) - ) - else: - q_head_dim = ( - getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0) if is_mla else kv_channels - ) - q_flops = 2 * seqlen * hidden_size * num_attention_heads * q_head_dim - - if is_mla and getattr(config, "kv_lora_rank", None) is not None: - kv_flops = ( - 2 - * seqlen - * ( - config.kv_lora_rank - * ( - hidden_size - + num_attention_heads * (getattr(config, "qk_head_dim", 0) + getattr(config, "v_head_dim", 0)) - ) - + hidden_size * getattr(config, "qk_pos_emb_head_dim", 0) - ) - ) - else: - num_query_groups = getattr( - config, "num_query_groups", getattr(config, "num_key_value_heads", num_attention_heads) - ) - kv_flops = 4 * seqlen * hidden_size * num_query_groups * kv_channels - return q_flops + kv_flops - - -def _calculate_attention_flops(config: Any, seqlen: int) -> float: - num_attention_heads = config.num_attention_heads - kv_channels = getattr(config, "kv_channels", getattr(config, "head_dim", config.hidden_size // num_attention_heads)) - if _is_mla(config): - flops = ( - num_attention_heads - * seqlen - * seqlen - * (getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0)) - ) - flops += num_attention_heads * seqlen * seqlen * getattr(config, "v_head_dim", kv_channels) - else: - flops = 2 * num_attention_heads * seqlen * seqlen * kv_channels - return flops - - -def _calculate_layer_flops(config: Any, seqlen: int, ffn_hidden_size: int) -> float: - hidden_size = config.hidden_size - return ( - _calculate_qkv_projection_flops(config, seqlen) - + _calculate_attention_flops(config, seqlen) - + 2 * seqlen * hidden_size * hidden_size - + 6 * seqlen * hidden_size * ffn_hidden_size - ) - - -def calculate_packing_fwd_flops(seqlens: list[int], config: Any) -> float: - """Model-aware forward FLOP estimate copied in spirit from slime.""" - num_experts = getattr(config, "num_experts", getattr(config, "n_routed_experts", None)) - dense_ffn = getattr( - config, "ffn_hidden_size", getattr(config, "intermediate_size", getattr(config, "moe_intermediate_size", 0)) - ) - if num_experts is None: - num_dense_layers = config.num_hidden_layers - num_moe_layers = 0 - moe_ffn = dense_ffn - else: - moe_layer_freq = getattr(config, "moe_layer_freq", None) - if isinstance(moe_layer_freq, list): - num_dense_layers = sum(1 for freq in moe_layer_freq if freq == 0) - num_moe_layers = sum(1 for freq in moe_layer_freq if freq > 0) - elif isinstance(moe_layer_freq, int): - num_dense_layers = sum(1 for i in range(config.num_hidden_layers) if i % moe_layer_freq != 0) - num_moe_layers = config.num_hidden_layers - num_dense_layers - elif getattr(config, "first_k_dense_replace", None) is not None: - num_dense_layers = config.first_k_dense_replace - num_moe_layers = config.num_hidden_layers - num_dense_layers - else: - num_dense_layers = 0 - num_moe_layers = config.num_hidden_layers - - routed_topk = getattr(config, "moe_router_topk", getattr(config, "num_experts_per_tok", 1)) - moe_ffn = getattr(config, "moe_ffn_hidden_size", getattr(config, "moe_intermediate_size", dense_ffn)) - moe_ffn = moe_ffn * routed_topk + (getattr(config, "moe_shared_expert_intermediate_size", None) or 0) - - total_flops = 0.0 - for seqlen in seqlens: - if num_dense_layers > 0: - total_flops += _calculate_layer_flops(config, seqlen, dense_ffn) * num_dense_layers - if num_moe_layers > 0: - total_flops += _calculate_layer_flops(config, seqlen, moe_ffn) * num_moe_layers - total_flops += 2 * seqlen * config.hidden_size * config.vocab_size - return total_flops - - def _copy_routed_experts(routed_experts: RoutedExperts) -> RoutedExperts: return RoutedExperts( data=routed_experts.data, @@ -277,23 +161,17 @@ def add(self, lora_idx: int, sample: MicroBatch) -> None: self.samples.append((lora_idx, sample)) self.length += len(sample.input_ids) - def workload(self, flops_config: Any | None) -> float: - if flops_config is None: - return self.length - return calculate_packing_fwd_flops([len(sample.input_ids) for _, sample in self.samples], flops_config) - - def _sample_workload(self, sample: MicroBatch, flops_config: Any | None) -> float: - if flops_config is None: - return len(sample.input_ids) - return calculate_packing_fwd_flops([len(sample.input_ids)], flops_config) + @property + def workload(self) -> int: + return bin_cost(len(sample.input_ids) for _, sample in self.samples) - def split_by_workload(self, flops_config: Any | None) -> tuple["_MicroBatchBin", "_MicroBatchBin"]: + def split_by_workload(self) -> tuple["_MicroBatchBin", "_MicroBatchBin"]: left: list[tuple[int, MicroBatch]] = [] right: list[tuple[int, MicroBatch]] = [] - left_workload = 0.0 - right_workload = 0.0 - for lora_idx, sample in sorted(self.samples, key=lambda x: -self._sample_workload(x[1], flops_config)): - sample_workload = self._sample_workload(sample, flops_config) + left_workload = 0 + right_workload = 0 + for lora_idx, sample in sorted(self.samples, key=lambda x: -len(x[1].input_ids) ** 2): + sample_workload = len(sample.input_ids) ** 2 if left_workload <= right_workload: left.append((lora_idx, sample)) left_workload += sample_workload @@ -472,17 +350,15 @@ def _improve_partitions_by_swapping(weights: list[float], partitions: list[list[ ) -def _expand_bins_by_splitting(bins: list[_MicroBatchBin], target_count: int, flops_config: Any | None) -> None: +def _expand_bins_by_splitting(bins: list[_MicroBatchBin], target_count: int) -> None: while len(bins) < target_count: candidates = [ - (bin_content.workload(flops_config), idx) - for idx, bin_content in enumerate(bins) - if len(bin_content.samples) > 1 + (bin_content.workload, idx) for idx, bin_content in enumerate(bins) if len(bin_content.samples) > 1 ] if not candidates: break _, idx = max(candidates) - left, right = bins[idx].split_by_workload(flops_config) + left, right = bins[idx].split_by_workload() bins[idx] = left bins.append(right) @@ -492,7 +368,6 @@ def packed_samples_into_micro_bs( max_seq_len: int, num_loras: int, num_train_workers: int, - flops_config: Any | None = None, ) -> list[MicroBatch]: """ Pack samples into micro_batch efficiently. @@ -521,27 +396,18 @@ def packed_samples_into_micro_bs( ((len(bins) + num_train_workers - 1) // num_train_workers) * num_train_workers, num_train_workers, ) - _expand_bins_by_splitting(bins, target_count, flops_config) + _expand_bins_by_splitting(bins, target_count) return [_materialize_bin(bin_content, num_loras) for bin_content in bins] -def _distribute_group( - group: list[MicroBatch], - num_train_workers: int, - flops_config: Any | None, -) -> list[list[MicroBatch]]: +def _distribute_group(group: list[MicroBatch], num_train_workers: int) -> list[list[MicroBatch]]: assert len(group) % num_train_workers == 0, "Number of micro batches is not divisible by number of data ranks" if not group: return [[] for _ in range(num_train_workers)] if len(group) >= num_train_workers: - weights = [ - calculate_packing_fwd_flops(micro_batch.sequence_lengths, flops_config) - if flops_config is not None - else len(micro_batch.input_ids) - for micro_batch in group - ] + weights = [bin_cost(micro_batch.sequence_lengths) for micro_batch in group] partitions = _balanced_partitions(weights, num_train_workers) partitions = _improve_partitions_by_swapping(weights, partitions) return [[group[i] for i in partition] for partition in partitions] @@ -620,7 +486,6 @@ def prepare_batch( idxs: list[int], num_loras: int, pad_to_multiple_of: int = 1, - flops_config: Any | None = None, ) -> list[list[MicroBatch]]: """ Prepare a batch of problems for each GPU. Each batch is a list of micro batches. @@ -633,13 +498,7 @@ def prepare_batch( """ all_samples = [(idx, prepare_sample(rollout, seq_len)) for idx, rollout in zip(idxs, rollouts)] - micro_batches = packed_samples_into_micro_bs( - all_samples, - seq_len, - num_loras, - num_train_workers, - flops_config=flops_config, - ) + micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras, num_train_workers) micro_batches = [pad_micro_batch(micro_batch, pad_to_multiple_of) for micro_batch in micro_batches] # Separate by modality so each step index has uniform modality across all ranks @@ -652,11 +511,7 @@ def prepare_batch( batches_per_gpu: list[list[MicroBatch]] = [[] for _ in range(num_train_workers)] for group in (mm_batches, text_batches): - group_batches_per_gpu = _distribute_group( - group, - num_train_workers, - flops_config, - ) + group_batches_per_gpu = _distribute_group(group, num_train_workers) for worker_idx, worker_batches in enumerate(group_batches_per_gpu): batches_per_gpu[worker_idx].extend(worker_batches) diff --git a/src/prime_rl/trainer/cost_model.py b/src/prime_rl/trainer/cost_model.py new file mode 100644 index 0000000000..cc8924d19c --- /dev/null +++ b/src/prime_rl/trainer/cost_model.py @@ -0,0 +1,18 @@ +"""Per-sequence forward-cost proxy used to balance packed micro-batches. + +In a packed micro-batch with sequence-masked attention, each sequence +attends only within itself, so attention is O(n^2) per sequence while +linear ops (QKV proj, FFN, attn-out) are O(n). FFD packs bins to +~max_seq_len, so the linear term is approximately constant across bins +and inter-bin work variance is dominated by attention. + +`bin_cost` returns just the n^2 term. It ranks bins correctly for +standard MHA / GQA setups. Add a model-aware estimator here only if a +specific model shows measured wallclock skew that this proxy misses. +""" + +from collections.abc import Iterable + + +def bin_cost(seqlens: Iterable[int]) -> int: + return sum(n * n for n in seqlens) diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index a1fd8bc80a..ba836b8dc4 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -167,7 +167,6 @@ def __init__( seq_len: int, pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, - flops_config, config: TransportConfig, ): self.world = get_world() @@ -179,7 +178,6 @@ def __init__( tokenizer=tokenizer, transport_config=config, pad_to_multiple_of=pad_to_multiple_of, - flops_config=flops_config, start_step=start_step, ) diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index 501205e578..cf9dcfa02e 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -7,7 +7,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.trainer.batch import get_packing_flops_config, prepare_batch +from prime_rl.trainer.batch import prepare_batch from prime_rl.trainer.runs import get_multi_run_manager from prime_rl.transport import ( MicroBatch, @@ -32,7 +32,6 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, - flops_config=None, start_step: int = 0, ): self.logger = get_logger() @@ -41,7 +40,6 @@ def __init__( self.seq_len = seq_len self.pad_to_multiple_of = pad_to_multiple_of self.tokenizer = tokenizer - self.flops_config = get_packing_flops_config(flops_config) if flops_config is not None else None self.receiver = setup_training_batch_receiver(config) shutil.rmtree(get_rollout_dir(self.multi_run_manager.output_dir), ignore_errors=True) self.sender: MicroBatchSender = setup_micro_batch_sender( @@ -86,10 +84,9 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, - flops_config=None, start_step: int = 0, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, flops_config, start_step) + super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run" def pack(self): @@ -113,7 +110,6 @@ def pack(self): num_train_workers=self.dp_world_size, idxs=[0] * len(batch.examples), num_loras=self.multi_run_manager.max_runs, - flops_config=self.flops_config, ) self.sender.send(micro_batch_grid) @@ -127,10 +123,9 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, - flops_config=None, start_step: int = 0, ): - super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, flops_config, start_step) + super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -332,7 +327,6 @@ def pack(self): num_train_workers=self.dp_world_size, idxs=[run_idx] * len(run_samples), num_loras=self.multi_run_manager.max_runs, - flops_config=self.flops_config, ) # Merge into combined grid for worker_idx, worker_batches in enumerate(run_micro_batch_grid): @@ -347,15 +341,10 @@ def setup_packer( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, transport_config: TransportConfig, - flops_config=None, start_step: int = 0, ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: - return SinglePacker( - dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, flops_config, start_step - ) + return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) else: - return MultiPacker( - dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, flops_config, start_step - ) + return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 7e363eb73c..b3c5e6b6f5 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -236,7 +236,6 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.model.seq_len, config.model.cp, tokenizer, - model.config if config.data.balance_by_flops else None, config.rollout_transport, ) diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index 6d8aed1a8a..1930fd0572 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -1,10 +1,7 @@ -from types import SimpleNamespace - import numpy as np import pytest from prime_rl.trainer.batch import ( - calculate_packing_fwd_flops, pad_micro_batch, prepare_batch, prepare_sample, @@ -63,26 +60,10 @@ def _flatten_batches(batches_per_gpu): return [batch for worker_batches in batches_per_gpu for batch in worker_batches] -def _worker_token_sums(batches_per_gpu) -> list[int]: - return [sum(len(batch.input_ids) for batch in worker_batches) for worker_batches in batches_per_gpu] - - def _has_loss_tokens(batch: MicroBatch) -> bool: return any(batch.loss_mask) -def make_flops_config(): - return SimpleNamespace( - hidden_size=16, - num_attention_heads=2, - num_key_value_heads=2, - vocab_size=32, - intermediate_size=64, - num_hidden_layers=2, - head_dim=8, - ) - - def test_randomized_packing_invariants(): rng = np.random.default_rng(0) @@ -92,7 +73,6 @@ def test_randomized_packing_invariants(): num_samples = int(rng.integers(1, 65)) lengths = [int(x) for x in rng.integers(1, seq_len + 1, size=num_samples)] examples = [make_sized_training_example(length, env_name=f"env-{case_idx}") for length in lengths] - flops_config = make_flops_config() if case_idx % 2 == 0 else None batches_per_gpu = prepare_batch( rollouts=examples, @@ -100,7 +80,6 @@ def test_randomized_packing_invariants(): num_train_workers=num_train_workers, idxs=[0] * len(examples), num_loras=1, - flops_config=flops_config, ) flat_batches = _flatten_batches(batches_per_gpu) real_batches = [batch for batch in flat_batches if _has_loss_tokens(batch)] @@ -190,23 +169,8 @@ def test_split_to_align_avoids_dummy_micro_batches(): assert len(_flatten_batches(batches_per_gpu)) == 4 -def test_pack_first_then_balance_distributes_micro_batches_by_tokens_without_model_config(): - examples = [make_sized_training_example(length) for length in [100, 90, 80, 70]] - - balanced = prepare_batch( - rollouts=examples, - seq_len=100, - num_train_workers=2, - idxs=[0] * len(examples), - num_loras=1, - ) - - assert _worker_token_sums(balanced) == [170, 170] - - -def test_flop_aware_balancing_pairs_long_and_short_sequence_workloads(): +def test_pack_first_then_balance_pairs_long_and_short_sequence_workloads(): examples = [make_sized_training_example(length) for length in [32, 32, 16, 16, 16, 16]] - flops_config = make_flops_config() balanced = prepare_batch( rollouts=examples, @@ -214,15 +178,13 @@ def test_flop_aware_balancing_pairs_long_and_short_sequence_workloads(): num_train_workers=2, idxs=[0] * len(examples), num_loras=1, - flops_config=flops_config, ) assert sorted([sorted(batch.sequence_lengths) for batch in balanced[0]]) == [[16, 16], [32]] assert sorted([sorted(batch.sequence_lengths) for batch in balanced[1]]) == [[16, 16], [32]] - assert calculate_packing_fwd_flops([32], flops_config) > calculate_packing_fwd_flops([16, 16], flops_config) -def test_flop_aware_split_to_align_splits_heaviest_flop_bin(): +def test_split_to_align_splits_heaviest_bin(): examples = [make_sized_training_example(length) for length in [20, 18, 9, 9, 8, 8, 8]] batches_per_gpu = prepare_batch( @@ -231,13 +193,13 @@ def test_flop_aware_split_to_align_splits_heaviest_flop_bin(): num_train_workers=4, idxs=[0] * len(examples), num_loras=1, - flops_config=make_flops_config(), ) real_batches = [batch for batch in _flatten_batches(batches_per_gpu) if _has_loss_tokens(batch)] assert len(real_batches) == 4 assert sorted(length for batch in real_batches for length in batch.sequence_lengths) == [8, 8, 8, 9, 9, 18, 20] - assert sum(len(batch.sequence_lengths) > 1 for batch in real_batches) == 3 + # The two longest samples land alone after splitting the heaviest bin by n^2 workload. + assert any(batch.sequence_lengths == [20] for batch in real_batches) def test_prepare_sample_truncates_routed_experts(): diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index 291f88263f..fcdee7a843 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -102,13 +102,6 @@ def test_defaults(): assert config.variant.alpha == 0.1 -def test_trainer_data_balance_by_flops_can_be_disabled(): - assert cli(TrainerConfig, args=[]).data.balance_by_flops is True - - config = cli(TrainerConfig, args=["--data.balance-by-flops", "false"]) - assert config.data.balance_by_flops is False - - def test_toml_partial_nested_override(tmp_path): """Partially overriding a nested model preserves unset field defaults.""" write_toml(tmp_path / "cfg.toml", {"nested": {"lr": 3e-4}})