Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,6 @@ class DataLoaderConfig(BaseConfig):
fake: FakeDataLoaderConfig | None = None
"""Use a fake data loader sampling random micro-batches (for debugging)."""

balance_by_flops: bool = True
"""Balance packed micro-batches across data-parallel ranks using model-estimated FLOPs."""


class BaseWeightBroadcastConfig(BaseConfig):
pass
Expand Down
179 changes: 17 additions & 162 deletions src/prime_rl/trainer/batch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import copy
import heapq
from dataclasses import dataclass, field
from typing import Any

from prime_rl.trainer.cost_model import bin_cost
from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample

ROUTED_EXPERTS_DTYPE_ITEMSIZE = {
Expand All @@ -12,122 +12,6 @@
}


def get_packing_flops_config(model_config: Any) -> Any:
"""Return the text config used for model-aware packing FLOP estimates."""
return getattr(model_config, "text_config", model_config)


def _is_mla(config: Any) -> bool:
return bool(getattr(config, "multi_latent_attention", False) or hasattr(config, "q_lora_rank"))


def _calculate_qkv_projection_flops(config: Any, seqlen: int) -> float:
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
kv_channels = getattr(config, "kv_channels", getattr(config, "head_dim", hidden_size // num_attention_heads))
is_mla = _is_mla(config)
if is_mla and getattr(config, "q_lora_rank", None) is not None:
q_flops = (
2
* seqlen
* config.q_lora_rank
* (
hidden_size
+ num_attention_heads * (getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0))
)
)
else:
q_head_dim = (
getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0) if is_mla else kv_channels
)
q_flops = 2 * seqlen * hidden_size * num_attention_heads * q_head_dim

if is_mla and getattr(config, "kv_lora_rank", None) is not None:
kv_flops = (
2
* seqlen
* (
config.kv_lora_rank
* (
hidden_size
+ num_attention_heads * (getattr(config, "qk_head_dim", 0) + getattr(config, "v_head_dim", 0))
)
+ hidden_size * getattr(config, "qk_pos_emb_head_dim", 0)
)
)
else:
num_query_groups = getattr(
config, "num_query_groups", getattr(config, "num_key_value_heads", num_attention_heads)
)
kv_flops = 4 * seqlen * hidden_size * num_query_groups * kv_channels
return q_flops + kv_flops


def _calculate_attention_flops(config: Any, seqlen: int) -> float:
num_attention_heads = config.num_attention_heads
kv_channels = getattr(config, "kv_channels", getattr(config, "head_dim", config.hidden_size // num_attention_heads))
if _is_mla(config):
flops = (
num_attention_heads
* seqlen
* seqlen
* (getattr(config, "qk_head_dim", 0) + getattr(config, "qk_pos_emb_head_dim", 0))
)
flops += num_attention_heads * seqlen * seqlen * getattr(config, "v_head_dim", kv_channels)
else:
flops = 2 * num_attention_heads * seqlen * seqlen * kv_channels
return flops


def _calculate_layer_flops(config: Any, seqlen: int, ffn_hidden_size: int) -> float:
hidden_size = config.hidden_size
return (
_calculate_qkv_projection_flops(config, seqlen)
+ _calculate_attention_flops(config, seqlen)
+ 2 * seqlen * hidden_size * hidden_size
+ 6 * seqlen * hidden_size * ffn_hidden_size
)


def calculate_packing_fwd_flops(seqlens: list[int], config: Any) -> float:
"""Model-aware forward FLOP estimate copied in spirit from slime."""
num_experts = getattr(config, "num_experts", getattr(config, "n_routed_experts", None))
dense_ffn = getattr(
config, "ffn_hidden_size", getattr(config, "intermediate_size", getattr(config, "moe_intermediate_size", 0))
)
if num_experts is None:
num_dense_layers = config.num_hidden_layers
num_moe_layers = 0
moe_ffn = dense_ffn
else:
moe_layer_freq = getattr(config, "moe_layer_freq", None)
if isinstance(moe_layer_freq, list):
num_dense_layers = sum(1 for freq in moe_layer_freq if freq == 0)
num_moe_layers = sum(1 for freq in moe_layer_freq if freq > 0)
elif isinstance(moe_layer_freq, int):
num_dense_layers = sum(1 for i in range(config.num_hidden_layers) if i % moe_layer_freq != 0)
num_moe_layers = config.num_hidden_layers - num_dense_layers
elif getattr(config, "first_k_dense_replace", None) is not None:
num_dense_layers = config.first_k_dense_replace
num_moe_layers = config.num_hidden_layers - num_dense_layers
else:
num_dense_layers = 0
num_moe_layers = config.num_hidden_layers

routed_topk = getattr(config, "moe_router_topk", getattr(config, "num_experts_per_tok", 1))
moe_ffn = getattr(config, "moe_ffn_hidden_size", getattr(config, "moe_intermediate_size", dense_ffn))
moe_ffn = moe_ffn * routed_topk + (getattr(config, "moe_shared_expert_intermediate_size", None) or 0)

total_flops = 0.0
for seqlen in seqlens:
if num_dense_layers > 0:
total_flops += _calculate_layer_flops(config, seqlen, dense_ffn) * num_dense_layers
if num_moe_layers > 0:
total_flops += _calculate_layer_flops(config, seqlen, moe_ffn) * num_moe_layers
total_flops += 2 * seqlen * config.hidden_size * config.vocab_size
return total_flops


def _copy_routed_experts(routed_experts: RoutedExperts) -> RoutedExperts:
return RoutedExperts(
data=routed_experts.data,
Expand Down Expand Up @@ -277,23 +161,17 @@ def add(self, lora_idx: int, sample: MicroBatch) -> None:
self.samples.append((lora_idx, sample))
self.length += len(sample.input_ids)

def workload(self, flops_config: Any | None) -> float:
if flops_config is None:
return self.length
return calculate_packing_fwd_flops([len(sample.input_ids) for _, sample in self.samples], flops_config)

def _sample_workload(self, sample: MicroBatch, flops_config: Any | None) -> float:
if flops_config is None:
return len(sample.input_ids)
return calculate_packing_fwd_flops([len(sample.input_ids)], flops_config)
@property
def workload(self) -> int:
return bin_cost(len(sample.input_ids) for _, sample in self.samples)

def split_by_workload(self, flops_config: Any | None) -> tuple["_MicroBatchBin", "_MicroBatchBin"]:
def split_by_workload(self) -> tuple["_MicroBatchBin", "_MicroBatchBin"]:
left: list[tuple[int, MicroBatch]] = []
right: list[tuple[int, MicroBatch]] = []
left_workload = 0.0
right_workload = 0.0
for lora_idx, sample in sorted(self.samples, key=lambda x: -self._sample_workload(x[1], flops_config)):
sample_workload = self._sample_workload(sample, flops_config)
left_workload = 0
right_workload = 0
for lora_idx, sample in sorted(self.samples, key=lambda x: -len(x[1].input_ids) ** 2):
sample_workload = len(sample.input_ids) ** 2
if left_workload <= right_workload:
left.append((lora_idx, sample))
left_workload += sample_workload
Expand Down Expand Up @@ -472,17 +350,15 @@ def _improve_partitions_by_swapping(weights: list[float], partitions: list[list[
)


def _expand_bins_by_splitting(bins: list[_MicroBatchBin], target_count: int, flops_config: Any | None) -> None:
def _expand_bins_by_splitting(bins: list[_MicroBatchBin], target_count: int) -> None:
while len(bins) < target_count:
candidates = [
(bin_content.workload(flops_config), idx)
for idx, bin_content in enumerate(bins)
if len(bin_content.samples) > 1
(bin_content.workload, idx) for idx, bin_content in enumerate(bins) if len(bin_content.samples) > 1
]
if not candidates:
break
_, idx = max(candidates)
left, right = bins[idx].split_by_workload(flops_config)
left, right = bins[idx].split_by_workload()
bins[idx] = left
bins.append(right)

Expand All @@ -492,7 +368,6 @@ def packed_samples_into_micro_bs(
max_seq_len: int,
num_loras: int,
num_train_workers: int,
flops_config: Any | None = None,
) -> list[MicroBatch]:
"""
Pack samples into micro_batch efficiently.
Expand Down Expand Up @@ -521,27 +396,18 @@ def packed_samples_into_micro_bs(
((len(bins) + num_train_workers - 1) // num_train_workers) * num_train_workers,
num_train_workers,
)
_expand_bins_by_splitting(bins, target_count, flops_config)
_expand_bins_by_splitting(bins, target_count)

return [_materialize_bin(bin_content, num_loras) for bin_content in bins]


def _distribute_group(
group: list[MicroBatch],
num_train_workers: int,
flops_config: Any | None,
) -> list[list[MicroBatch]]:
def _distribute_group(group: list[MicroBatch], num_train_workers: int) -> list[list[MicroBatch]]:
assert len(group) % num_train_workers == 0, "Number of micro batches is not divisible by number of data ranks"
if not group:
return [[] for _ in range(num_train_workers)]

if len(group) >= num_train_workers:
weights = [
calculate_packing_fwd_flops(micro_batch.sequence_lengths, flops_config)
if flops_config is not None
else len(micro_batch.input_ids)
for micro_batch in group
]
weights = [bin_cost(micro_batch.sequence_lengths) for micro_batch in group]
partitions = _balanced_partitions(weights, num_train_workers)
partitions = _improve_partitions_by_swapping(weights, partitions)
return [[group[i] for i in partition] for partition in partitions]
Expand Down Expand Up @@ -620,7 +486,6 @@ def prepare_batch(
idxs: list[int],
num_loras: int,
pad_to_multiple_of: int = 1,
flops_config: Any | None = None,
) -> list[list[MicroBatch]]:
"""
Prepare a batch of problems for each GPU. Each batch is a list of micro batches.
Expand All @@ -633,13 +498,7 @@ def prepare_batch(
"""
all_samples = [(idx, prepare_sample(rollout, seq_len)) for idx, rollout in zip(idxs, rollouts)]

micro_batches = packed_samples_into_micro_bs(
all_samples,
seq_len,
num_loras,
num_train_workers,
flops_config=flops_config,
)
micro_batches = packed_samples_into_micro_bs(all_samples, seq_len, num_loras, num_train_workers)
micro_batches = [pad_micro_batch(micro_batch, pad_to_multiple_of) for micro_batch in micro_batches]

# Separate by modality so each step index has uniform modality across all ranks
Expand All @@ -652,11 +511,7 @@ def prepare_batch(

batches_per_gpu: list[list[MicroBatch]] = [[] for _ in range(num_train_workers)]
for group in (mm_batches, text_batches):
group_batches_per_gpu = _distribute_group(
group,
num_train_workers,
flops_config,
)
group_batches_per_gpu = _distribute_group(group, num_train_workers)
for worker_idx, worker_batches in enumerate(group_batches_per_gpu):
batches_per_gpu[worker_idx].extend(worker_batches)

Expand Down
18 changes: 18 additions & 0 deletions src/prime_rl/trainer/cost_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Per-sequence forward-cost proxy used to balance packed micro-batches.

In a packed micro-batch with sequence-masked attention, each sequence
attends only within itself, so attention is O(n^2) per sequence while
linear ops (QKV proj, FFN, attn-out) are O(n). FFD packs bins to
~max_seq_len, so the linear term is approximately constant across bins
and inter-bin work variance is dominated by attention.

`bin_cost` returns just the n^2 term. It ranks bins correctly for
standard MHA / GQA setups. Add a model-aware estimator here only if a
specific model shows measured wallclock skew that this proxy misses.
"""

from collections.abc import Iterable


def bin_cost(seqlens: Iterable[int]) -> int:
return sum(n * n for n in seqlens)
2 changes: 0 additions & 2 deletions src/prime_rl/trainer/rl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def __init__(
seq_len: int,
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
flops_config,
config: TransportConfig,
):
self.world = get_world()
Expand All @@ -179,7 +178,6 @@ def __init__(
tokenizer=tokenizer,
transport_config=config,
pad_to_multiple_of=pad_to_multiple_of,
flops_config=flops_config,
start_step=start_step,
)

Expand Down
21 changes: 5 additions & 16 deletions src/prime_rl/trainer/rl/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from transformers.tokenization_utils import PreTrainedTokenizer

from prime_rl.trainer.batch import get_packing_flops_config, prepare_batch
from prime_rl.trainer.batch import prepare_batch
from prime_rl.trainer.runs import get_multi_run_manager
from prime_rl.transport import (
MicroBatch,
Expand All @@ -32,7 +32,6 @@ def __init__(
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
flops_config=None,
start_step: int = 0,
):
self.logger = get_logger()
Expand All @@ -41,7 +40,6 @@ def __init__(
self.seq_len = seq_len
self.pad_to_multiple_of = pad_to_multiple_of
self.tokenizer = tokenizer
self.flops_config = get_packing_flops_config(flops_config) if flops_config is not None else None
self.receiver = setup_training_batch_receiver(config)
shutil.rmtree(get_rollout_dir(self.multi_run_manager.output_dir), ignore_errors=True)
self.sender: MicroBatchSender = setup_micro_batch_sender(
Expand Down Expand Up @@ -86,10 +84,9 @@ def __init__(
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
flops_config=None,
start_step: int = 0,
):
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, flops_config, start_step)
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step)
assert self.multi_run_manager.max_runs == 1, "SinglePacker only supports one run"

def pack(self):
Expand All @@ -113,7 +110,6 @@ def pack(self):
num_train_workers=self.dp_world_size,
idxs=[0] * len(batch.examples),
num_loras=self.multi_run_manager.max_runs,
flops_config=self.flops_config,
)

self.sender.send(micro_batch_grid)
Expand All @@ -127,10 +123,9 @@ def __init__(
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
config: TransportConfig,
flops_config=None,
start_step: int = 0,
):
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, flops_config, start_step)
super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step)
# Per-run buffer: stores (TrainingSample, step) tuples
self.buffers: list[deque[tuple[TrainingSample, int]]] = [
deque() for _ in range(self.multi_run_manager.max_runs)
Expand Down Expand Up @@ -332,7 +327,6 @@ def pack(self):
num_train_workers=self.dp_world_size,
idxs=[run_idx] * len(run_samples),
num_loras=self.multi_run_manager.max_runs,
flops_config=self.flops_config,
)
# Merge into combined grid
for worker_idx, worker_batches in enumerate(run_micro_batch_grid):
Expand All @@ -347,15 +341,10 @@ def setup_packer(
pad_to_multiple_of: int,
tokenizer: PreTrainedTokenizer,
transport_config: TransportConfig,
flops_config=None,
start_step: int = 0,
) -> BasePacker:
multi_run_manager = get_multi_run_manager()
if multi_run_manager.max_runs == 1:
return SinglePacker(
dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, flops_config, start_step
)
return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step)
else:
return MultiPacker(
dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, flops_config, start_step
)
return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step)
1 change: 0 additions & 1 deletion src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
config.model.seq_len,
config.model.cp,
tokenizer,
model.config if config.data.balance_by_flops else None,
config.rollout_transport,
)

Expand Down
Loading
Loading