Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions slime/utils/dp_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
single first-fit pass (dynamic batch) or fixed-size chunking
(static batch).
3. Adjust ``K`` to a multiple of ``dp_size * (mb_group if vpp>1 else 1)``
by splitting the largest multi-sample bins (dynamic only).
by splitting the largest multi-sample bins (dynamic only); if up-rounding
the mbs count exceeds the step's sample count, first drop whole trailing
rollouts until the aligned target is reachable.
4. Distribute the ``K`` mbs across ``dp_size`` ranks, ``K / dp_size``
each, with either a strided round-robin or a Karmarkar-Karp pass on
estimated mbs FLOPs.
Expand Down Expand Up @@ -106,8 +108,8 @@ def build_dp_schedule(

Returns:
``(partitions, micro_batch_indices, num_microbatches, global_batch_sizes)``.
``global_batch_sizes[s]`` = rollout count for step s (constant
``global_batch_size`` for every step).
``global_batch_sizes[s]`` = kept rollout count for step s (may be
``< global_batch_size`` when trailing rollouts are dropped).
"""
dp_size = train_parallel_config["dp_size"]
cp_size = train_parallel_config["cp_size"]
Expand Down Expand Up @@ -143,37 +145,79 @@ def build_dp_schedule(
num_microbatches: list[int] = []
global_batch_sizes: list[int] = []

def _collect_step_samples(step_rollouts: list[int]) -> tuple[list[int], list[int]]:
indices = [pos for rid in step_rollouts for pos in rollout_id_to_samples[rid]]
return indices, [total_lengths[i] for i in indices]

def _pack(step_lengths: list[int]) -> list[list[int]]:
return _pack_step_into_mbs(
step_lengths,
args=args,
use_dynamic_batch_size=args.use_dynamic_batch_size,
max_per_bin=max_per_bin,
micro_batch_size=getattr(args, "micro_batch_size", None),
balance_by_flops=args.balance_by_flops,
)

def _aligned_target(num_mbs: int) -> int:
"""mbs count rounded up to the next multiple of ``align_to`` (>= align_to)."""
return max(((num_mbs + align_to - 1) // align_to) * align_to, align_to)

for step_i in range(num_steps):
step_rollouts = rollout_ids[step_i * global_batch_size : (step_i + 1) * global_batch_size]
sample_indices = [pos for rid in step_rollouts for pos in rollout_id_to_samples[rid]]
step_lengths = [total_lengths[i] for i in sample_indices]
global_batch_sizes.append(global_batch_size)
sample_indices, step_lengths = _collect_step_samples(step_rollouts)
assert len(sample_indices) >= dp_size, (
f"step {step_i}: {len(sample_indices)} samples < dp_size {dp_size}; "
f"each step needs at least one sample per rank."
)

# 1. Pack samples in this step into mbs with one global pass.
# ``step_mbs`` indices are LOCAL into ``sample_indices``.
step_mbs = _pack_step_into_mbs(
step_lengths,
args=args,
use_dynamic_batch_size=args.use_dynamic_batch_size,
max_per_bin=max_per_bin,
micro_batch_size=getattr(args, "micro_batch_size", None),
balance_by_flops=args.balance_by_flops,
)
step_mbs = _pack(step_lengths)

if args.use_dynamic_batch_size and align_to > 1:
dropped_rollouts = 0
while (
_aligned_target(len(step_mbs)) > len(sample_indices)
and len(sample_indices) - len(rollout_id_to_samples[step_rollouts[-1]]) >= align_to
):
step_rollouts.pop()
dropped_rollouts += 1
sample_indices, step_lengths = _collect_step_samples(step_rollouts)
step_mbs = _pack(step_lengths)
if dropped_rollouts:
logger.warning(
"[dp_schedule] step %d: dropped %d trailing rollout(s) (%d kept, %d samples) so the "
"aligned micro-batch target stays reachable (dp_size=%d, align_to=%d).",
step_i,
dropped_rollouts,
len(step_rollouts),
len(sample_indices),
dp_size,
align_to,
)

global_batch_sizes.append(len(step_rollouts))

# 2. Align mbs count to a multiple of ``align_to``.
target_K = max(((len(step_mbs) + align_to - 1) // align_to) * align_to, align_to)
target_K = _aligned_target(len(step_mbs))
if target_K != len(step_mbs):
if args.use_dynamic_batch_size:
expand_bins_by_splitting(step_mbs, target_K, step_lengths)
assert len(step_mbs) == target_K, (
f"dynamic path: could only produce {len(step_mbs)} mbs after maximal splitting; "
f"need {target_K}. step {step_i} has {len(sample_indices)} samples, below the "
f"alignment threshold ({align_to})."
)
if len(step_mbs) != target_K:
# Rollout atomicity means no kept prefix may land on a multiple
# of align_to; raise with an actionable message so the operator
# can retune global_batch_size / n_samples_per_prompt.
raise ValueError(
f"dp_schedule step {step_i}: cannot align micro-batches to a multiple of "
f"align_to={align_to} (dp_size={dp_size}). After dropping trailing rollouts the "
f"step has {len(sample_indices)} samples packed into {len(step_mbs)} singleton "
f"micro-batches, but the aligned target is {target_K} and singleton bins cannot "
f"split further. This happens with ragged rollout sizes where every long sample "
f"fills its own micro-batch. Adjust global_batch_size / n_samples_per_prompt "
f"(or max_tokens_per_gpu) so each step's kept-sample count can reach a multiple "
f"of align_to."
)
else:
raise AssertionError(
f"static path: num_mbs ({len(step_mbs)}) is not a multiple of "
Expand Down
131 changes: 131 additions & 0 deletions tests/test_dp_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,136 @@ def test_rejects_when_fewer_rollouts_than_gbs():
build_dp_schedule(args, tp, [3] * 6, global_batch_size=4, rollout_indices=[0, 0, 1, 1, 2, 2])


@pytest.mark.unit
def test_dynamic_long_trajectories_ragged_samples_drop_trailing_rollouts():
"""DP4, all-long samples (1 per bin) with a sample count that does not tile.

512 rollouts where two fan out to 2 samples each -> 514 samples. Packing
yields ~514 singleton bins; the pre-fix align-up target (516) exceeded the
sample count and the splitter assert fired. The fix
drops the minimum number of whole trailing rollouts so the target stays
reachable, keeps GRPO groups atomic, and records the kept rollout count in
``global_batch_sizes``.
"""
dp_size = 4
num_rollouts = 512
rollout_indices = []
total_lengths = []
for g in range(num_rollouts):
fan_out = 2 if g < 2 else 1 # 514 samples total
for _ in range(fan_out):
rollout_indices.append(g)
total_lengths.append(30_000) # ~1 sample per 32k bin

args = make_args(use_dynamic_batch_size=True, max_tokens_per_gpu=32_768)
tp = make_tp(dp_size=dp_size)

partitions, micro_batch_indices, num_microbatches, global_batch_sizes = build_dp_schedule(
args,
tp,
total_lengths,
global_batch_size=num_rollouts,
rollout_indices=rollout_indices,
)

# 514 samples -> align-up target 516 is unreachable; dropping the 2 trailing
# singleton rollouts lands on 510 kept rollouts / 512 samples (a multiple of dp_size).
assert global_batch_sizes == [510]
kept_rollouts = global_batch_sizes[0]

assert len(set(len(mbi) for mbi in micro_batch_indices)) == 1
covered = sorted(i for part in partitions for i in part)
assert len(covered) == 512
covered_rollouts = {rollout_indices[i] for i in covered}
assert covered_rollouts == set(range(kept_rollouts)) # trailing rollouts dropped, no holes
for g in covered_rollouts:
expected = [i for i, gid in enumerate(rollout_indices) if gid == g]
assert [i for i in covered if rollout_indices[i] == g] == expected # rollouts stay whole

total_mbs = sum(len(mbi) for mbi in micro_batch_indices)
assert total_mbs % dp_size == 0
assert num_microbatches[0] == total_mbs // dp_size


@pytest.mark.unit
def test_dynamic_ragged_rollout_sizes_drop_below_dp_size_floor():
"""DP4, ragged rollout sizes where the aligned target is only reachable after
dropping more trailing rollouts than the old ``len > dp_size`` floor allowed.

Rollout sizes [3, 2, 3, 1, 2] (all long -> 1 sample per bin) total 11 samples.
No prefix of >4 rollouts has a sample count that is a multiple of 4, so the
old rollout-count floor exited with the target still unreachable and the
splitter assert fired. The sample-count floor keeps dropping (here down to 3
rollouts / 8 samples) so alignment is reached without a crash.
"""
dp_size = 4
sizes = [3, 2, 3, 1, 2]
rollout_indices = []
total_lengths = []
for g, n in enumerate(sizes):
for _ in range(n):
rollout_indices.append(g)
total_lengths.append(30_000)

args = make_args(use_dynamic_batch_size=True, max_tokens_per_gpu=32_768)
tp = make_tp(dp_size=dp_size)

partitions, micro_batch_indices, num_microbatches, global_batch_sizes = build_dp_schedule(
args,
tp,
total_lengths,
global_batch_size=len(sizes),
rollout_indices=rollout_indices,
)

# Dropped past the old dp_size floor (which would have kept 4 rollouts / 9
# samples, still unaligned) down to the 3-rollout / 8-sample prefix that tiles
# by align_to.
assert global_batch_sizes == [3]
kept_rollouts = global_batch_sizes[0]
kept_samples = sum(sizes[:kept_rollouts])
assert kept_samples == 8
assert kept_samples % dp_size == 0

total_mbs = sum(len(mbi) for mbi in micro_batch_indices)
assert total_mbs % dp_size == 0
assert len(set(len(mbi) for mbi in micro_batch_indices)) == 1

covered = sorted(i for part in partitions for i in part)
assert {rollout_indices[i] for i in covered} == set(range(kept_rollouts))


@pytest.mark.unit
def test_dynamic_ragged_unreachable_alignment_raises_actionable_error():
"""DP4, ragged rollout sizes where NO prefix of whole rollouts has a sample
count that is a multiple of align_to (rollout atomicity makes alignment
impossible). The scheduler must fail loud with an actionable ValueError
naming the step / sample count / align_to, not a bare AssertionError.

Rollout sizes [1, 1, 4, 1, 3, 1] (all long) have prefix sample counts
1, 2, 6, 7, 10, 11 — none divisible by 4.
"""
dp_size = 4
sizes = [1, 1, 4, 1, 3, 1]
rollout_indices = []
total_lengths = []
for g, n in enumerate(sizes):
for _ in range(n):
rollout_indices.append(g)
total_lengths.append(30_000)

args = make_args(use_dynamic_batch_size=True, max_tokens_per_gpu=32_768)
tp = make_tp(dp_size=dp_size)

with pytest.raises(ValueError, match="cannot align micro-batches"):
build_dp_schedule(
args,
tp,
total_lengths,
global_batch_size=len(sizes),
rollout_indices=rollout_indices,
)


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
Loading