Monolithic fused head-loss kernel for combinable losses + metrics #507

@jlamypoirier

Problem

Each language-model head loss/metric runs its own softmax pass over the full vocab. When losses are combined, the work is duplicated:

  • cross-entropy + z-loss: both compute sum_exp_logits over the same logits.
  • cross-entropy + distillation: both compute the student softmax; distillation additionally computes the teacher softmax.
  • GRPO + extra metrics (PR #494, "grpo: add policy-gradient metrics behind metrics enum"): fused_grpo_loss_forward_backward computes logits_norm, exp_logits, sum_exp_logits, and predicted_logits; compute_policy_gradient_metrics then recomputes all of them on the same logits. With compute_entropy_metric=True, a third softmax pass runs on top.
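
To make the overlap concrete, here is a minimal framework-free sketch for a single token, assuming the usual definitions (cross-entropy as log-sum-exp minus the target logit, z-loss as the squared log-sum-exp with its coefficient left out). The function names are hypothetical and are not the ones in the codebase; the point is only that both losses read the same logits_max and sum_exp_logits.

```python
import math


def shared_softmax_stats(logits):
    """One pass over a single row of logits: max, shifted exps, and their sum."""
    logits_max = max(logits)
    exp_logits = [math.exp(x - logits_max) for x in logits]
    sum_exp_logits = sum(exp_logits)
    return logits_max, exp_logits, sum_exp_logits


def cross_entropy_and_z_loss(logits, target):
    """Hypothetical helper: both losses derived from one set of softmax statistics."""
    logits_max, _, sum_exp_logits = shared_softmax_stats(logits)
    log_z = logits_max + math.log(sum_exp_logits)  # numerically stable log-sum-exp
    cross_entropy = log_z - logits[target]
    z_loss = log_z**2  # assumed definition; any coefficient is applied by the caller
    return cross_entropy, z_loss
```

Computed separately, each loss would repeat the max, exp, and sum steps over the full vocab.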

Each pass also issues its own tensor-parallel all-reduces on logits_max / sum_exp_logits, so the communication cost scales with the number of passes.
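
As a rough illustration of where those all-reduces come from, the sketch below simulates a vocab-sharded softmax for one token in plain Python. Each inner list stands in for one rank's slice of the vocab, and the two global reductions are ordinary max and sum calls standing in for all-reduce(MAX) and all-reduce(SUM). The function name and the sharding layout are assumptions for illustration, not the actual implementation.

```python
import math


def tp_softmax_stats(shards):
    """Simulated tensor-parallel softmax statistics for one token.

    shards: list of per-rank logit slices (hypothetical layout).
    """
    # Reduction 1: global max over the per-rank maxima (stands in for all-reduce MAX).
    logits_max = max(max(shard) for shard in shards)
    # Each rank sums its own shifted exponentials locally.
    local_sums = [sum(math.exp(x - logits_max) for x in shard) for shard in shards]
    # Reduction 2: global sum of the per-rank partial sums (stands in for all-reduce SUM).
    sum_exp_logits = sum(local_sums)
    return logits_max, sum_exp_logits
```

Under this model, every independent softmax pass pays both reductions again, so running the pass once and sharing the results removes the repeated communication.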

@torch.compile does not fuse across separately decorated functions, so the redundancy costs real compute and memory.

Proposed direction

A single "monolithic" head-loss kernel (torch.compile and/or triton) that:

  • Takes a config / flag-set describing which losses and metrics to emit (CE, z-loss, distillation, GRPO clipped objective, GRPO ratio/KL/clamp/advantage stats, entropy, ...).
  • Runs softmax (and TP all-reduce) once over the logits.
  • Emits all requested scalars and the combined gradient in one pass.
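
The three points above can be sketched for a single token as follows. This is a plain-Python illustration of the control flow only, not a proposed kernel: HeadLossConfig and fused_head_loss are hypothetical names, only cross-entropy, z-loss, and an entropy metric are covered, and z-loss is assumed to be the squared log-sum-exp. Entropy is treated as a metric and contributes no gradient.

```python
import math
from dataclasses import dataclass


@dataclass
class HeadLossConfig:
    """Hypothetical flag set selecting which losses and metrics to emit."""
    cross_entropy_weight: float = 1.0
    z_loss_weight: float = 0.0
    compute_entropy_metric: bool = False


def fused_head_loss(logits, target, config):
    """Return (scalars, gradient w.r.t. logits) from one softmax pass."""
    # Shared softmax statistics, computed exactly once.
    logits_max = max(logits)
    exp_logits = [math.exp(x - logits_max) for x in logits]
    sum_exp_logits = sum(exp_logits)
    log_z = logits_max + math.log(sum_exp_logits)
    probs = [e / sum_exp_logits for e in exp_logits]

    outputs = {}
    grad = [0.0] * len(logits)
    if config.cross_entropy_weight:
        outputs["cross_entropy"] = log_z - logits[target]
        for i, p in enumerate(probs):
            one_hot = 1.0 if i == target else 0.0
            grad[i] += config.cross_entropy_weight * (p - one_hot)
    if config.z_loss_weight:
        outputs["z_loss"] = log_z**2
        for i, p in enumerate(probs):
            # d(log_z^2)/d(logit_i) = 2 * log_z * p_i
            grad[i] += config.z_loss_weight * 2.0 * log_z * p
    if config.compute_entropy_metric:
        # H = -sum_i p_i * log(p_i) = log_z - sum_i p_i * logit_i
        outputs["entropy"] = log_z - sum(p * x for p, x in zip(probs, logits))
    return outputs, grad
```

For example, `fused_head_loss([0.0, 0.0], 0, HeadLossConfig(z_loss_weight=0.5, compute_entropy_metric=True))` returns all three scalars and one combined gradient while evaluating the exponentials a single time.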

This subsumes fused_grpo_loss_forward_backward, _fused_cross_entropy_base_from_*, fused_softmax_base, compute_policy_gradient_metrics, and the entropy chunking in PR #494.

Out of scope

  • Implementation details (triton vs torch.compile, kernel layout) — left for the design phase.
  • The current PRs that hit this redundancy (notably #494, "grpo: add policy-gradient metrics behind metrics enum") are not blocked on this issue and should land with their existing structure; this issue tracks the longer-term consolidation.

Motivation

Surfaced during review of #494, but the underlying redundancy predates it.
