Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "utils/test_metric_utils.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
{'test_file': 'test_placement_group.py', 'num_gpus': 0},
{'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0},
{'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0},
{'test_file': 'utils/test_metric_utils.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_path_loading_contracts.py', 'num_gpus': 0},
Expand Down
28 changes: 21 additions & 7 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def log_rollout_data(
"multimodal_train_inputs",
"loss_masks",
"sample_indices",
"group_indices",
"rollout_ids",
"rollout_mask_sums",
"rollout_routed_experts",
Expand Down Expand Up @@ -490,20 +491,33 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -
"""
Compute pass@k metrics from `raw_reward` groups and log the results.

`raw_reward` is reshaped to `[group_number, group_size]`, then pass@k is
estimated per problem and averaged.
When every sample carries a `group_index` (packed parallel to `raw_reward`
by `_convert_samples_to_train_data`), rewards are bucketed by their actual
prompt group, so over-sampled / filtered batches whose total is not
`rollout_batch_size * n_samples_per_prompt` still report pass@k. Otherwise
`raw_reward` is reshaped to `[group_number, group_size]` as before.
"""
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
log_dict = {}
for key, val in rollout_data.items():
if key != "raw_reward":
continue

log_dict |= compute_pass_rate(
flat_rewards=val,
group_size=args.n_samples_per_prompt,
num_groups=args.rollout_batch_size,
)
group_indices = rollout_data.get("group_indices")
if group_indices is not None and all(idx is not None for idx in group_indices):
log_dict |= compute_pass_rate(
flat_rewards=val,
group_size=args.n_samples_per_prompt,
group_ids=group_indices,
)
else:
# Custom rollout/convert functions may not tag samples with
# group_index; keep the rigid legacy layout for them.
log_dict |= compute_pass_rate(
flat_rewards=val,
group_size=args.n_samples_per_prompt,
num_groups=args.rollout_batch_size,
)

gather_log_data("passrate", args, rollout_id, log_dict)

Expand Down
14 changes: 13 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
"raw_reward": raw_rewards,
"truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
"sample_indices": [sample.index for sample in samples],
# Parallel to raw_reward: lets pass-rate logging bucket ragged
# (over-sampled / filtered) batches by their actual prompt group.
"group_indices": [sample.group_index for sample in samples],
"rollout_ids": rollout_ids,
}

Expand Down Expand Up @@ -830,7 +833,7 @@ def _split_train_data_by_dp(self, data):
continue
rollout_data[key] = [data[key][j] for j in partition]
# keys that need to be splited at train side
for key in ["raw_reward", "total_lengths"]:
for key in ["raw_reward", "group_indices", "total_lengths"]:
if key not in data:
continue
rollout_data[key] = data[key]
Expand Down Expand Up @@ -1215,10 +1218,19 @@ def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any]
truncated = data[key]["truncated"]
log_dict[f"eval/{key}-truncated_ratio"] = sum(truncated) / len(truncated)
if args.log_passrate:
# The default eval path never sets Sample.group_index (samples are
# copied straight from the eval dataset), so it keeps the rigid
# n_samples_per_eval_prompt layout. Rollout functions that tag
# every sample with group_index get ragged bucketing instead of
# the divisibility assert.
group_ids = None
if samples and all(s.group_index is not None for s in samples):
group_ids = [s.group_index for s in samples]
log_dict |= dict_add_prefix(
compute_pass_rate(
flat_rewards=rewards,
group_size=args.n_samples_per_eval_prompt,
group_ids=group_ids,
),
f"eval/{key}-",
)
Expand Down
64 changes: 52 additions & 12 deletions slime/utils/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,67 @@ def compute_pass_rate(
flat_rewards: list[float],
group_size: int,
num_groups: int | None = None,
group_ids: list | None = None,
):
"""Estimate pass@k per prompt-group and average across groups.

Two regimes:

* **Rigid** (``group_ids is None``): assume every group has exactly
``group_size`` samples laid out contiguously, i.e.
``len(flat_rewards) == num_groups * group_size``. This is the legacy
fixed-size shape and is numerically identical to the prior reshape.

* **Ragged** (``group_ids`` given): over-sampled batches do *not* keep a
rigid ``num_groups * group_size`` layout — dynamic sampling, group
replacement, and per-group sample drops yield a variable number of
samples per prompt-group (and a total that need not be a multiple of
``group_size``). We then bucket ``flat_rewards`` by their actual group
id and estimate pass@k over the samples that actually exist for each
group, so the metric never asserts on a ragged batch.

``num_groups`` is only used by the rigid path (to validate the reshape);
it is ignored when ``group_ids`` is provided, since the ragged path derives
the group count from the distinct ids in ``group_ids``.

``group_size`` only sets which ``pass@{1,2,4,...}`` rungs are reported
(the ladder ``[2**i for i in range(log2(group_size)+1)]``). In the ragged
regime, pass@k for a rung is averaged only over groups that have at least
``k`` samples — a group with fewer than ``k`` samples cannot define an
unbiased pass@k draw, so it is excluded from that rung's mean rather than
counted as a trivial 1.0. Rungs whose every group is too small are dropped.
"""
if group_size == 1:
return {}

if num_groups is None:
num_groups = len(flat_rewards) // group_size

pass_rate_name_list = [2**i for i in range(int(math.log2(group_size)) + 1)]

assert len(flat_rewards) == num_groups * group_size, f"{len(flat_rewards)=} {num_groups=} {group_size=}"
rewards_of_group = np.array(flat_rewards).reshape(num_groups, group_size)
if group_ids is None:
if num_groups is None:
num_groups = len(flat_rewards) // group_size
assert len(flat_rewards) == num_groups * group_size, f"{len(flat_rewards)=} {num_groups=} {group_size=}"
rewards_of_group = np.array(flat_rewards).reshape(num_groups, group_size)
num_samples_per_group = np.full(num_groups, group_size)
num_correct_per_group = np.sum(rewards_of_group == 1, axis=1)
else:
# Ragged layout: bucket rewards by their actual group id. Group order
# does not matter — the final metric is an order-independent mean.
assert len(flat_rewards) == len(group_ids), f"{len(flat_rewards)=} {len(group_ids)=}"
grouped: dict = {}
for reward, gid in zip(flat_rewards, group_ids, strict=True):
grouped.setdefault(gid, []).append(reward)
group_rewards = list(grouped.values())
num_samples_per_group = np.array([len(g) for g in group_rewards])
num_correct_per_group = np.array([sum(1 for r in g if r == 1) for g in group_rewards])

log_dict = {}
for k in pass_rate_name_list:
num_correct = np.sum(rewards_of_group == 1, axis=1)
num_samples = np.full(num_groups, group_size)

pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k)

pass_k = np.mean(pass_k_estimates)
log_dict[f"pass@{k}"] = pass_k
# A group must have >= k samples to define an unbiased pass@k draw.
eligible = num_samples_per_group >= k
if not np.any(eligible):
continue
pass_k_estimates = _estimate_pass_at_k(num_samples_per_group[eligible], num_correct_per_group[eligible], k)
log_dict[f"pass@{k}"] = np.mean(pass_k_estimates)

return log_dict

Expand Down
109 changes: 109 additions & 0 deletions tests/utils/test_metric_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Unit tests for ``slime.utils.metric_utils.compute_pass_rate``.

Pins the two regimes:

* rigid (legacy fixed-size): ``len(flat_rewards) == num_groups * group_size``.
* ragged (over-sampled): variable per-group sample counts whose total need not
divide ``group_size`` — the case that crashed the base-slime metric assert
(e.g. 51 trainable samples, groups of mixed size, not the rigid 8*4=32).
"""

import pytest

from slime.utils.metric_utils import compute_pass_rate


def test_group_size_one_returns_empty():
assert compute_pass_rate([1, 0, 1], group_size=1) == {}


def test_rigid_layout_pass_at_k():
# 2 groups x 4 samples; group A has 2 correct, group B has 0 correct.
flat = [1, 1, 0, 0, 0, 0, 0, 0]
out = compute_pass_rate(flat, group_size=4, num_groups=2)
assert set(out) == {"pass@1", "pass@2", "pass@4"}
# pass@1 = mean correct fraction = (2/4 + 0/4) / 2 = 0.25
assert out["pass@1"] == pytest.approx(0.25)
# pass@4 over the whole group: group A always has >=1 correct in a draw of 4
# (n-c=2 < k=4 -> 1.0), group B never -> mean 0.5.
assert out["pass@4"] == pytest.approx(0.5)


def test_rigid_layout_matches_legacy_reshape():
# The group_ids=None path must stay numerically identical to the legacy
# reshape: full-size groups, every rung eligible, no rung dropped.
flat = [1, 0, 1, 1, 0, 0, 1, 1]
out = compute_pass_rate(flat, group_size=4, num_groups=2)
assert set(out) == {"pass@1", "pass@2", "pass@4"}
# group A=[1,0,1,1] (3/4 correct), group B=[0,0,1,1] (2/4 correct).
assert out["pass@1"] == pytest.approx((3 / 4 + 2 / 4) / 2)


def test_rigid_layout_asserts_on_bad_count():
with pytest.raises(AssertionError):
compute_pass_rate([1, 0, 1], group_size=4, num_groups=2)


def test_ragged_oversampled_reproduces_crash_and_pins_values():
# The exact shape that crashed base-slime: 51 samples across ragged groups,
# total not a multiple of group_size (4). Group sizes: twelve groups
# summing to 51, every group filled as [1,0,1,0,...].
group_sizes = [4, 4, 3, 4, 4, 4, 5, 4, 4, 4, 4, 7]
assert sum(group_sizes) == 51
flat_rewards = []
group_ids = []
for gi, n in enumerate(group_sizes):
for j in range(n):
flat_rewards.append(1 if j % 2 == 0 else 0)
group_ids.append(f"task-{gi}")

# The rigid path reshapes (num_groups, group_size) and asserts the total
# divides group_size: 51 != 12*4, so it crashes — this is the bug.
with pytest.raises(AssertionError):
compute_pass_rate(flat_rewards, group_size=4, num_groups=12)

# The ragged path buckets by group id and never asserts. Pin the exact
# pass@k the fix establishes for this input (not just a 0..1 range).
out = compute_pass_rate(flat_rewards, group_size=4, group_ids=group_ids)
assert set(out) == {"pass@1", "pass@2", "pass@4"}
assert out["pass@1"] == pytest.approx(0.5281746031746032)
assert out["pass@2"] == pytest.approx(0.8547619047619048)
# Every group has >= 1 correct in any draw of 4 (n-c < 4 for all), so 1.0.
assert out["pass@4"] == pytest.approx(1.0)


def test_ragged_per_group_semantics():
# Two groups: A = [1,1,0] (3 samples, 2 correct), B = [0,0] (2 samples, 0 correct).
flat = [1, 1, 0, 0, 0]
gids = ["a", "a", "a", "b", "b"]
out = compute_pass_rate(flat, group_size=4, group_ids=gids)
# rungs for group_size=4 -> {1,2,4}; group A has 3 samples, group B has 2.
# pass@1: mean correct frac = (2/3 + 0/2)/2 = 1/3.
assert out["pass@1"] == pytest.approx(1 / 3)
# pass@2: both groups have >=2 samples (eligible).
# A: n=3,c=2,k=2 -> n-c=1 < k=2 -> 1.0
# B: n=2,c=0,k=2 -> n-c=2 >= k=2 -> 0.0
# mean -> 0.5
assert out["pass@2"] == pytest.approx(0.5)
# pass@4: only groups with >=4 samples are eligible; neither qualifies -> rung dropped.
assert "pass@4" not in out


def test_ragged_all_groups_too_small_drops_high_rungs():
# All groups have a single sample -> only pass@1 survives (pass@2/4 dropped).
flat = [1, 0, 1]
gids = ["a", "b", "c"]
out = compute_pass_rate(flat, group_size=4, group_ids=gids)
assert "pass@1" in out
assert "pass@2" not in out
assert "pass@4" not in out
assert out["pass@1"] == pytest.approx(2 / 3)


def test_ragged_length_mismatch_asserts():
with pytest.raises(AssertionError):
compute_pass_rate([1, 0, 1], group_size=4, group_ids=["a", "b"])


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
Loading