diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 8ec019285..6404884e8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "utils/test_metric_utils.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 1e2a656c9..18a39c22f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -80,6 +80,7 @@ {'test_file': 'test_placement_group.py', 'num_gpus': 0}, {'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0}, {'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0}, + {'test_file': 'utils/test_metric_utils.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_path_loading_contracts.py', 'num_gpus': 0}, diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 7c12a5a77..ce9376a4e 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -296,6 +296,7 @@ def log_rollout_data( "multimodal_train_inputs", "loss_masks", "sample_indices", + "group_indices", "rollout_ids", "rollout_mask_sums", "rollout_routed_experts", @@ -490,8 +491,11 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) - """ Compute pass@k metrics from `raw_reward` groups and log the results. - `raw_reward` is reshaped to `[group_number, group_size]`, then pass@k is - estimated per problem and averaged. + When every sample carries a `group_index` (packed parallel to `raw_reward` + by `_convert_samples_to_train_data`), rewards are bucketed by their actual + prompt group, so over-sampled / filtered batches whose total is not + `rollout_batch_size * n_samples_per_prompt` still report pass@k. Otherwise + `raw_reward` is reshaped to `[group_number, group_size]` as before. """ if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage(): log_dict = {} @@ -499,11 +503,21 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) - if key != "raw_reward": continue - log_dict |= compute_pass_rate( - flat_rewards=val, - group_size=args.n_samples_per_prompt, - num_groups=args.rollout_batch_size, - ) + group_indices = rollout_data.get("group_indices") + if group_indices is not None and all(idx is not None for idx in group_indices): + log_dict |= compute_pass_rate( + flat_rewards=val, + group_size=args.n_samples_per_prompt, + group_ids=group_indices, + ) + else: + # Custom rollout/convert functions may not tag samples with + # group_index; keep the rigid legacy layout for them. + log_dict |= compute_pass_rate( + flat_rewards=val, + group_size=args.n_samples_per_prompt, + num_groups=args.rollout_batch_size, + ) gather_log_data("passrate", args, rollout_id, log_dict) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 3cbdb7e2c..32552dd27 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -712,6 +712,9 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl "raw_reward": raw_rewards, "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples], "sample_indices": [sample.index for sample in samples], + # Parallel to raw_reward: lets pass-rate logging bucket ragged + # (over-sampled / filtered) batches by their actual prompt group. + "group_indices": [sample.group_index for sample in samples], "rollout_ids": rollout_ids, } @@ -830,7 +833,7 @@ def _split_train_data_by_dp(self, data): continue rollout_data[key] = [data[key][j] for j in partition] # keys that need to be splited at train side - for key in ["raw_reward", "total_lengths"]: + for key in ["raw_reward", "group_indices", "total_lengths"]: if key not in data: continue rollout_data[key] = data[key] @@ -1215,10 +1218,19 @@ def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] truncated = data[key]["truncated"] log_dict[f"eval/{key}-truncated_ratio"] = sum(truncated) / len(truncated) if args.log_passrate: + # The default eval path never sets Sample.group_index (samples are + # copied straight from the eval dataset), so it keeps the rigid + # n_samples_per_eval_prompt layout. Rollout functions that tag + # every sample with group_index get ragged bucketing instead of + # the divisibility assert. + group_ids = None + if samples and all(s.group_index is not None for s in samples): + group_ids = [s.group_index for s in samples] log_dict |= dict_add_prefix( compute_pass_rate( flat_rewards=rewards, group_size=args.n_samples_per_eval_prompt, + group_ids=group_ids, ), f"eval/{key}-", ) diff --git a/slime/utils/metric_utils.py b/slime/utils/metric_utils.py index 46e42d73b..b1b4348ed 100644 --- a/slime/utils/metric_utils.py +++ b/slime/utils/metric_utils.py @@ -15,27 +15,67 @@ def compute_pass_rate( flat_rewards: list[float], group_size: int, num_groups: int | None = None, + group_ids: list | None = None, ): + """Estimate pass@k per prompt-group and average across groups. + + Two regimes: + + * **Rigid** (``group_ids is None``): assume every group has exactly + ``group_size`` samples laid out contiguously, i.e. + ``len(flat_rewards) == num_groups * group_size``. This is the legacy + fixed-size shape and is numerically identical to the prior reshape. + + * **Ragged** (``group_ids`` given): over-sampled batches do *not* keep a + rigid ``num_groups * group_size`` layout — dynamic sampling, group + replacement, and per-group sample drops yield a variable number of + samples per prompt-group (and a total that need not be a multiple of + ``group_size``). We then bucket ``flat_rewards`` by their actual group + id and estimate pass@k over the samples that actually exist for each + group, so the metric never asserts on a ragged batch. + + ``num_groups`` is only used by the rigid path (to validate the reshape); + it is ignored when ``group_ids`` is provided, since the ragged path derives + the group count from the distinct ids in ``group_ids``. + + ``group_size`` only sets which ``pass@{1,2,4,...}`` rungs are reported + (the ladder ``[2**i for i in range(log2(group_size)+1)]``). In the ragged + regime, pass@k for a rung is averaged only over groups that have at least + ``k`` samples — a group with fewer than ``k`` samples cannot define an + unbiased pass@k draw, so it is excluded from that rung's mean rather than + counted as a trivial 1.0. Rungs whose every group is too small are dropped. + """ if group_size == 1: return {} - if num_groups is None: - num_groups = len(flat_rewards) // group_size - pass_rate_name_list = [2**i for i in range(int(math.log2(group_size)) + 1)] - assert len(flat_rewards) == num_groups * group_size, f"{len(flat_rewards)=} {num_groups=} {group_size=}" - rewards_of_group = np.array(flat_rewards).reshape(num_groups, group_size) + if group_ids is None: + if num_groups is None: + num_groups = len(flat_rewards) // group_size + assert len(flat_rewards) == num_groups * group_size, f"{len(flat_rewards)=} {num_groups=} {group_size=}" + rewards_of_group = np.array(flat_rewards).reshape(num_groups, group_size) + num_samples_per_group = np.full(num_groups, group_size) + num_correct_per_group = np.sum(rewards_of_group == 1, axis=1) + else: + # Ragged layout: bucket rewards by their actual group id. Group order + # does not matter — the final metric is an order-independent mean. + assert len(flat_rewards) == len(group_ids), f"{len(flat_rewards)=} {len(group_ids)=}" + grouped: dict = {} + for reward, gid in zip(flat_rewards, group_ids, strict=True): + grouped.setdefault(gid, []).append(reward) + group_rewards = list(grouped.values()) + num_samples_per_group = np.array([len(g) for g in group_rewards]) + num_correct_per_group = np.array([sum(1 for r in g if r == 1) for g in group_rewards]) log_dict = {} for k in pass_rate_name_list: - num_correct = np.sum(rewards_of_group == 1, axis=1) - num_samples = np.full(num_groups, group_size) - - pass_k_estimates = _estimate_pass_at_k(num_samples, num_correct, k) - - pass_k = np.mean(pass_k_estimates) - log_dict[f"pass@{k}"] = pass_k + # A group must have >= k samples to define an unbiased pass@k draw. + eligible = num_samples_per_group >= k + if not np.any(eligible): + continue + pass_k_estimates = _estimate_pass_at_k(num_samples_per_group[eligible], num_correct_per_group[eligible], k) + log_dict[f"pass@{k}"] = np.mean(pass_k_estimates) return log_dict diff --git a/tests/utils/test_metric_utils.py b/tests/utils/test_metric_utils.py new file mode 100644 index 000000000..ea9911265 --- /dev/null +++ b/tests/utils/test_metric_utils.py @@ -0,0 +1,109 @@ +"""Unit tests for ``slime.utils.metric_utils.compute_pass_rate``. + +Pins the two regimes: + +* rigid (legacy fixed-size): ``len(flat_rewards) == num_groups * group_size``. +* ragged (over-sampled): variable per-group sample counts whose total need not + divide ``group_size`` — the case that crashed the base-slime metric assert + (e.g. 51 trainable samples, groups of mixed size, not the rigid 8*4=32). +""" + +import pytest + +from slime.utils.metric_utils import compute_pass_rate + + +def test_group_size_one_returns_empty(): + assert compute_pass_rate([1, 0, 1], group_size=1) == {} + + +def test_rigid_layout_pass_at_k(): + # 2 groups x 4 samples; group A has 2 correct, group B has 0 correct. + flat = [1, 1, 0, 0, 0, 0, 0, 0] + out = compute_pass_rate(flat, group_size=4, num_groups=2) + assert set(out) == {"pass@1", "pass@2", "pass@4"} + # pass@1 = mean correct fraction = (2/4 + 0/4) / 2 = 0.25 + assert out["pass@1"] == pytest.approx(0.25) + # pass@4 over the whole group: group A always has >=1 correct in a draw of 4 + # (n-c=2 < k=4 -> 1.0), group B never -> mean 0.5. + assert out["pass@4"] == pytest.approx(0.5) + + +def test_rigid_layout_matches_legacy_reshape(): + # The group_ids=None path must stay numerically identical to the legacy + # reshape: full-size groups, every rung eligible, no rung dropped. + flat = [1, 0, 1, 1, 0, 0, 1, 1] + out = compute_pass_rate(flat, group_size=4, num_groups=2) + assert set(out) == {"pass@1", "pass@2", "pass@4"} + # group A=[1,0,1,1] (3/4 correct), group B=[0,0,1,1] (2/4 correct). + assert out["pass@1"] == pytest.approx((3 / 4 + 2 / 4) / 2) + + +def test_rigid_layout_asserts_on_bad_count(): + with pytest.raises(AssertionError): + compute_pass_rate([1, 0, 1], group_size=4, num_groups=2) + + +def test_ragged_oversampled_reproduces_crash_and_pins_values(): + # The exact shape that crashed base-slime: 51 samples across ragged groups, + # total not a multiple of group_size (4). Group sizes: twelve groups + # summing to 51, every group filled as [1,0,1,0,...]. + group_sizes = [4, 4, 3, 4, 4, 4, 5, 4, 4, 4, 4, 7] + assert sum(group_sizes) == 51 + flat_rewards = [] + group_ids = [] + for gi, n in enumerate(group_sizes): + for j in range(n): + flat_rewards.append(1 if j % 2 == 0 else 0) + group_ids.append(f"task-{gi}") + + # The rigid path reshapes (num_groups, group_size) and asserts the total + # divides group_size: 51 != 12*4, so it crashes — this is the bug. + with pytest.raises(AssertionError): + compute_pass_rate(flat_rewards, group_size=4, num_groups=12) + + # The ragged path buckets by group id and never asserts. Pin the exact + # pass@k the fix establishes for this input (not just a 0..1 range). + out = compute_pass_rate(flat_rewards, group_size=4, group_ids=group_ids) + assert set(out) == {"pass@1", "pass@2", "pass@4"} + assert out["pass@1"] == pytest.approx(0.5281746031746032) + assert out["pass@2"] == pytest.approx(0.8547619047619048) + # Every group has >= 1 correct in any draw of 4 (n-c < 4 for all), so 1.0. + assert out["pass@4"] == pytest.approx(1.0) + + +def test_ragged_per_group_semantics(): + # Two groups: A = [1,1,0] (3 samples, 2 correct), B = [0,0] (2 samples, 0 correct). + flat = [1, 1, 0, 0, 0] + gids = ["a", "a", "a", "b", "b"] + out = compute_pass_rate(flat, group_size=4, group_ids=gids) + # rungs for group_size=4 -> {1,2,4}; group A has 3 samples, group B has 2. + # pass@1: mean correct frac = (2/3 + 0/2)/2 = 1/3. + assert out["pass@1"] == pytest.approx(1 / 3) + # pass@2: both groups have >=2 samples (eligible). + # A: n=3,c=2,k=2 -> n-c=1 < k=2 -> 1.0 + # B: n=2,c=0,k=2 -> n-c=2 >= k=2 -> 0.0 + # mean -> 0.5 + assert out["pass@2"] == pytest.approx(0.5) + # pass@4: only groups with >=4 samples are eligible; neither qualifies -> rung dropped. + assert "pass@4" not in out + + +def test_ragged_all_groups_too_small_drops_high_rungs(): + # All groups have a single sample -> only pass@1 survives (pass@2/4 dropped). + flat = [1, 0, 1] + gids = ["a", "b", "c"] + out = compute_pass_rate(flat, group_size=4, group_ids=gids) + assert "pass@1" in out + assert "pass@2" not in out + assert "pass@4" not in out + assert out["pass@1"] == pytest.approx(2 / 3) + + +def test_ragged_length_mismatch_asserts(): + with pytest.raises(AssertionError): + compute_pass_rate([1, 0, 1], group_size=4, group_ids=["a", "b"]) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__]))