Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
{'test_file': 'test_metric_report.py', 'num_gpus': 0},
{'test_file': 'test_metric_report_dist.py', 'num_gpus': 0},
{'test_file': 'test_loss_cp_invariance.py', 'num_gpus': 0},
{'test_file': 'test_logprob_entropy_fused.py', 'num_gpus': 0},
{'test_file': 'test_value_temperature.py', 'num_gpus': 0},
{'test_file': 'test_cispo_loss.py', 'num_gpus': 0},
{'test_file': 'test_rm_f1.py', 'num_gpus': 0},
Expand Down
7 changes: 6 additions & 1 deletion examples/retool/retool_qwen3_4b_rl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ PERF_ARGS=(
# --micro-batch-size 1
--use-dynamic-batch-size
--max-tokens-per-gpu 9216
# Bound the fused cross-entropy [tokens, vocab] transient that can OOM on long retool traces.
--log-probs-chunk-size 1024
# Gather only response tokens before the cross-entropy: retool traces have long prompts, so this
# shrinks the [T, vocab] tensor to [T_response, vocab] and stacks with the chunking above.
--log-probs-response-only
)

GRPO_ARGS=(
Expand Down Expand Up @@ -153,4 +158,4 @@ ray job submit --address="http://127.0.0.1:8265" \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]} \
${CUSTOM_ARGS[@]}
${CUSTOM_ARGS[@]}
132 changes: 123 additions & 9 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,81 @@ def _build_shifted_tokens(
return full_tokens


def _response_keep_index(
    total_lengths: list[int],
    response_lengths: list[int],
    qkv_format: str,
    max_seq_lens: list[int] | None,
    allgather_cp: bool,
    device: torch.device,
    T: int,
) -> torch.Tensor:
    """Build the flat index of logits rows that belong to response windows.

    For every sample one (or, under zigzag context parallelism, two) half-open
    ``[start, end)`` spans are collected, expanded with ``torch.arange`` and
    concatenated in collection order. The spans are meant to mirror, branch for
    branch, the slices that ``_extract_per_sample`` reads, so that a caller can
    gather those rows before the cross-entropy and scatter the results back
    without changing the extraction step.

    Args:
        total_lengths: Per-sample prompt + response lengths.
        response_lengths: Per-sample response lengths.
        qkv_format: ``"thd"`` selects the packed layout; any other value is
            handled as ``bshd`` in the non-CP branch.
        max_seq_lens: Per-sample sequence stride. Indexed unconditionally in
            the non-CP ``bshd`` branch; forwarded (or ``None``) to
            ``get_logits_and_tokens_offset_with_cp`` in the zigzag branch.
        allgather_cp: Selects the all-gather context-parallel layout.
        device: Device on which the index tensor is created.
        T: Row count used to locate this rank's window
            ``[cp_rank * T, cp_rank * T + T)``; only read when ``allgather_cp``
            is set.

    Returns:
        A 1-D ``torch.long`` tensor of row positions (empty when no span was
        collected).
    """
    cp_world = mpu.get_context_parallel_world_size()
    spans: list[tuple[int, int]] = []
    samples = zip(total_lengths, response_lengths, strict=False)

    if allgather_cp:
        # all-gather CP: clip each sample's global response span to this rank's
        # window, then re-express it relative to the window start.
        window_lo = mpu.get_context_parallel_rank() * T
        window_hi = window_lo + T
        seq_base = 0
        for total_len, resp_len in samples:
            first = max(seq_base + (total_len - resp_len) - 1, window_lo)
            last = min(seq_base + total_len - 1, window_hi)
            if first < last:
                spans.append((first - window_lo, last - window_lo))
            seq_base += total_len

    elif cp_world > 1:
        # zigzag CP: every sample contributes two local chunks laid out back to
        # back, each with its own (possibly empty) response span.
        cursor = 0
        for idx, (total_len, resp_len) in enumerate(samples):
            seq_cap = None if max_seq_lens is None else max_seq_lens[idx]
            cp_chunk, chunk_bounds, logit_bounds, _ = get_logits_and_tokens_offset_with_cp(
                total_len, resp_len, qkv_format, seq_cap
            )
            for half in (0, 1):
                # Translate the helper's offsets into the local layout: subtract
                # the chunk's own start, add where that chunk sits locally.
                shift = cursor + half * cp_chunk - chunk_bounds[half][0]
                spans.append((shift + logit_bounds[half][0], shift + logit_bounds[half][1]))
            cursor += 2 * cp_chunk

    elif qkv_format == "thd":
        # no CP, packed layout: samples are concatenated end to end. Spans are
        # shifted back by one row (presumably because the logit at position p
        # scores token p + 1 -- confirm against _extract_per_sample).
        seq_end = 0
        for total_len, resp_len in samples:
            seq_end += total_len
            spans.append((seq_end - resp_len - 1, seq_end - 1))

    else:
        # no CP, bshd layout: sample i starts at max_seq_lens[i] * i.
        for idx, (total_len, resp_len) in enumerate(samples):
            seq_end = max_seq_lens[idx] * idx + total_len
            spans.append((seq_end - resp_len - 1, seq_end - 1))

    if not spans:
        return torch.zeros((0,), dtype=torch.long, device=device)
    pieces = [torch.arange(lo, hi, device=device, dtype=torch.long) for lo, hi in spans]
    return torch.cat(pieces)


def _extract_per_sample(
log_prob_full: torch.Tensor,
entropy_full: torch.Tensor | None,
Expand Down Expand Up @@ -394,13 +469,22 @@ def get_log_probs_and_entropy(
with_entropy: bool = False,
non_loss_data: bool = True,
max_seq_lens: list[int] | None = None,
full_loss_mask: torch.Tensor | None = None,
) -> dict[str, list[torch.Tensor]]:
"""Compute per-token log-probabilities (and optionally entropy) on responses.

Computes on the **full** logits ``[T, V]`` tensor at once (instead of
per-sample slicing) so backward traverses ``[T, V]`` only once, then
extracts per-sample response portions.

With ``--log-probs-response-only`` the CE runs only on the response-window
rows (gathered out of ``[T, V]`` before CE and scattered back after), so the
dominant tensor shrinks from ``T`` to the number of response tokens ``T'``.
With ``--log-probs-loss-mask-only`` (and a ``full_loss_mask`` aligned to the
logits layout) it shrinks further to the ``loss_mask == 1`` rows; positions
dropped this way return a log-prob/entropy of 0 and so are only valid where
the downstream loss masks them out (policy-loss path).

When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()``
to avoid retaining the computation graph and to skip cloning.
"""
Expand Down Expand Up @@ -432,15 +516,44 @@ def get_log_probs_and_entropy(
T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp
)

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
logits,
full_tokens,
tp_group,
with_entropy=with_entropy,
chunk_size=chunk_size,
)
log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T]
# --- compute CE, optionally on a gathered subset of rows ---
if getattr(args, "log_probs_response_only", False):
# Only the response windows survive _extract_per_sample; gather them so CE
# runs on [T', V] instead of [T, V] (autograd's index_select backward
# scatters grads back to the dropped rows as zeros, which is exactly right).
keep_index = _response_keep_index(
total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp, device, T
)
if getattr(args, "log_probs_loss_mask_only", False) and full_loss_mask is not None:
mask_kept = full_loss_mask.reshape(-1).index_select(0, keep_index).to(torch.bool)
keep_index = keep_index[mask_kept]

logits_kept = logits.index_select(0, keep_index)
tokens_kept = full_tokens.index_select(0, keep_index)
lp_kept, ent_kept = calculate_log_probs_and_entropy(
logits_kept,
tokens_kept,
tp_group,
with_entropy=with_entropy,
chunk_size=chunk_size,
)
lp_kept = lp_kept.squeeze(-1) # [T', 1] -> [T']

# scatter back to full length so _extract_per_sample is unchanged
log_prob_full = lp_kept.new_zeros(T).index_copy(0, keep_index, lp_kept)
entropy_full = None
if with_entropy:
entropy_full = ent_kept.new_zeros(T).index_copy(0, keep_index, ent_kept)
else:
# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
logits,
full_tokens,
tp_group,
with_entropy=with_entropy,
chunk_size=chunk_size,
)
log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T]

# --- extract per-sample response portions ---
log_probs_list, entropy_list = _extract_per_sample(
Expand Down Expand Up @@ -481,6 +594,7 @@ def get_values(
with_entropy: bool = False,
non_loss_data: bool = True,
max_seq_lens: list[int] | None = None,
full_loss_mask: torch.Tensor | None = None, # unused; accepted so the shared forward_only partial fits
) -> dict[str, list[torch.Tensor]]:
"""Extract per-token value predictions over response tokens.

Expand Down
27 changes: 27 additions & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@
logger = logging.getLogger(__name__)


def _mem_probe_enabled() -> bool:
    """Report whether per-step memory probing is switched on.

    Probing is active only when the ``SLIME_MEM_PROBE`` environment variable is
    exactly ``"1"`` and ``torch.cuda.is_available()`` is true.
    """
    if os.environ.get("SLIME_MEM_PROBE", "0") != "1":
        return False
    return torch.cuda.is_available()


def _log_train_step_mem_probe(rollout_id: int, step_id: int, start_allocated: int) -> None:
    """Log the CUDA allocation peak observed for one training step.

    Reads ``torch.cuda.max_memory_allocated()`` and emits a single INFO record
    with the starting allocation, the peak, and their difference.

    Args:
        rollout_id: Rollout identifier echoed into the log record.
        step_id: Step identifier echoed into the log record.
        start_allocated: Allocation figure captured by the caller at the start
            of the step; subtracted from the peak to get the delta.
    """
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_delta = peak_allocated - start_allocated
    logger.info(
        "SLIME_MEM_PROBE train_one_step rollout_id=%s step_id=%s "
        "allocated_start=%s allocated_peak=%s allocated_peak_delta=%s",
        rollout_id,
        step_id,
        start_allocated,
        peak_allocated,
        peak_delta,
    )


def _disable_tqdm_for_non_main_rank() -> bool:
return not (
mpu.get_data_parallel_rank(with_context_parallel=True) == 0
Expand Down Expand Up @@ -352,6 +369,7 @@ def forward_step(
response_lengths=response_lengths,
with_entropy=args.use_rollout_entropy,
max_seq_lens=batch.get("max_seq_lens", None),
full_loss_mask=batch["full_loss_masks"],
)

# Turn on evaluation mode which disables dropout.
Expand Down Expand Up @@ -455,6 +473,11 @@ def train_one_step(
and gradient norm for logging.
"""
args = get_args()
mem_probe = _mem_probe_enabled()
mem_probe_start_allocated = 0
if mem_probe:
torch.cuda.reset_peak_memory_stats()
mem_probe_start_allocated = torch.cuda.memory_allocated()

# Set grad to zero.
for model_chunk in model:
Expand Down Expand Up @@ -601,7 +624,11 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
cp_size=mpu.get_context_parallel_world_size(),
dp_with_cp_group=mpu.get_data_parallel_group(with_context_parallel=True),
)
if mem_probe:
_log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated)
return loss_reduced, grad_norm
if mem_probe:
_log_train_step_mem_probe(rollout_id, step_id, mem_probe_start_allocated)
return {}, grad_norm


Expand Down
15 changes: 15 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,18 @@ def add_train_arguments(parser):
parser.add_argument(
"--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory"
)
parser.add_argument(
"--log-probs-response-only",
action="store_true",
help="Gather only the response-window rows before the log-prob/entropy cross-entropy, "
"shrinking the [T, V] logits tensor to [T', V] (T' = response tokens). Results are identical.",
)
parser.add_argument(
"--log-probs-loss-mask-only",
action="store_true",
help="Further restrict the log-prob/entropy cross-entropy to loss_mask==1 rows. Requires "
"--log-probs-response-only; only valid on the policy-loss path (masked positions return 0).",
)
parser.add_argument(
"--only-train-params-name-list",
type=str,
Expand Down Expand Up @@ -1851,6 +1863,9 @@ def slime_validate_args(args):
assert args.use_dynamic_batch_size, "--balance-by-flops requires --use-dynamic-batch-size"
args.balance_data = True

if getattr(args, "log_probs_loss_mask_only", False):
assert args.log_probs_response_only, "--log-probs-loss-mask-only requires --log-probs-response-only"

if args.eps_clip_high is None:
args.eps_clip_high = args.eps_clip

Expand Down
Loading
Loading