Triton reverse_kl loss kernel is ~3x slower than torch.compile (single GPU) #497

@jlamypoirier

Summary

The Triton `reverse_kl` loss kernel (`triton_reverse_kl_forward_backward` in `fast_llm/functional/triton/entropy_loss.py`) is ~2.9× slower than `torch.compile` at vocab=32K on a single GPU. Note: in tensor-parallel training logits are sharded (`vocab / tp_size` per rank) so the gap narrows at higher TP degrees — and at large vocab TP is the only feasible approach (see #triton-tp-benchmark plan).

Benchmark results (H100 SXM, bf16, single GPU, fwd+bwd)

| Shape (tokens × vocab) | pytorch_compiled | fast_llm_triton | Triton memory |
|---|---|---|---|
| 4 Ki × 32 Ki | 485 GB/s, 14.5% BW, Δpeak 0.75 GiB | 169 GB/s, 5.1% BW, Δpeak 0.25 GiB | 3× less |
| 4 Ki × 64 Ki | 481 GB/s, 14.4% BW, Δpeak 1.50 GiB | anomalous timing* | |
| 4 Ki × 128 Ki | 483 GB/s, 14.4% BW, Δpeak 3.00 GiB | 389 GB/s, 11.6% BW, Δpeak 1.00 GiB | 3× less |

*The 64 Ki case recorded a negative backward time (measurement artifact).

Triton uses 3× less activation memory by fusing fwd+bwd into one pass without intermediate tensors — but that doesn't compensate for the throughput deficit at 32K vocab.

Root cause

`block_size = min(next_power_of_2(n_cols), 32768)` → at vocab=32K, block_size=32K. With 512 threads/block, each thread holds 32768 / 512 = 64 fp32 values in registers, so the tile alone takes 32768 registers per block. H100 has 65536 registers/SM → only ~2 blocks fit per SM → ~1024 resident threads out of 2048, i.e. ~50% warp occupancy → DRAM latency is not hidden.
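
The arithmetic above can be reproduced with a small back-of-the-envelope helper. This is only a sketch: `estimate_occupancy` is a hypothetical name, it counts the fp32 tile alone (one 32-bit register per value), and the 2048 threads/SM limit is an assumption added here rather than a number taken from the benchmark. Real kernels need additional registers, so the result is an upper bound.

```python
def estimate_occupancy(block_size, threads_per_block=512,
                       regs_per_sm=65536, max_threads_per_sm=2048):
    """Upper bound on occupancy when the whole tile lives in registers.

    Only the fp32 tile is counted; address, accumulator and softmax-stat
    registers are ignored, so the real occupancy can only be lower.
    """
    values_per_thread = block_size // threads_per_block
    tile_regs_per_block = values_per_thread * threads_per_block
    blocks_by_regs = regs_per_sm // tile_regs_per_block
    blocks_by_threads = max_threads_per_sm // threads_per_block
    blocks_per_sm = min(blocks_by_regs, blocks_by_threads)
    occupancy = blocks_per_sm * threads_per_block / max_threads_per_sm
    return values_per_thread, blocks_per_sm, occupancy


for block_size in (32768, 16384, 8192):
    values, blocks, occ = estimate_occupancy(block_size)
    print(f"block_size={block_size}: {values} values/thread, "
          f"{blocks} blocks/SM, occupancy <= {occ:.0%}")
```

Under these assumptions, a 32K block is register-limited to 2 blocks/SM, while a 16K or 8K block is limited by the thread count instead.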

Three compounding factors vs `cross_entropy_labels` (which achieves 69–96% BW):

  1. Dual softmax in forward — `triton_reverse_kl_forward_from_distribution` runs `triton_fused_softmax_iter_base` on both logits and target simultaneously. Extra live registers (target tile, target_max, target_sum_exp) squeeze occupancy further.
  2. Heavier backward formula — reverse KL grad needs an extra `log` + scalar broadcast vs cross-entropy's `p − q`.
  3. torch.compile advantage — for CE, the unfused backward is slow enough that Triton wins despite low occupancy. For reverse KL, `F.kl_div(target.log_softmax(), logits.softmax())` produces two high-occupancy softmax kernels + pointwise, each well-tuned individually.
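
To make the "extra `log` + scalar broadcast" in factor 2 concrete, here is a pure-Python sketch of the per-row math, checked against finite differences. It is not the Triton kernel: the function names are hypothetical, and it assumes, as the `F.kl_div` expression above suggests, that `target` holds teacher logits. With `p = softmax(logits)` and `q = softmax(target)`, the reverse KL gradient is `p_j * (log p_j - log q_j - loss)`, where `loss` is the per-row scalar that has to be broadcast.

```python
import math


def log_softmax(row):
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def reverse_kl(logits, target_logits):
    """KL(p || q) for one row, p = softmax(logits), q = softmax(target_logits)."""
    log_p = log_softmax(logits)
    log_q = log_softmax(target_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def reverse_kl_grad(logits, target_logits):
    """Analytic gradient w.r.t. logits: p_j * (log p_j - log q_j - loss)."""
    log_p = log_softmax(logits)
    log_q = log_softmax(target_logits)
    loss = reverse_kl(logits, target_logits)  # per-row scalar, broadcast below
    return [math.exp(lp) * (lp - lq - loss) for lp, lq in zip(log_p, log_q)]


def cross_entropy_grad(logits, target_logits):
    """For comparison: the soft-target cross-entropy gradient is just p - q."""
    log_p = log_softmax(logits)
    log_q = log_softmax(target_logits)
    return [math.exp(lp) - math.exp(lq) for lp, lq in zip(log_p, log_q)]


def finite_difference(fn, logits, target_logits, eps=1e-6):
    """Central-difference gradient of fn w.r.t. logits, for checking only."""
    grad = []
    for j in range(len(logits)):
        up = logits[:j] + [logits[j] + eps] + logits[j + 1:]
        down = logits[:j] + [logits[j] - eps] + logits[j + 1:]
        grad.append((fn(up, target_logits) - fn(down, target_logits)) / (2 * eps))
    return grad
```

Compared with `cross_entropy_grad`, the reverse KL gradient needs both log-probabilities and the row loss for every element, which is the extra live state referred to above.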

Fix options

Option A — Lower the block_size cap for distribution kernels (e.g. 8192 instead of 32768): 4× more blocks/SM at the cost of 3–4 extra re-reads of logits+target in the backward. Empirically likely a net win on H100. Try caps of 4096, 8192, 16384 and benchmark.
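
As a rough planning aid for that sweep, the sketch below lists, for each candidate cap, the resulting block size, how many column tiles each row is split into, and how many fp32 values each thread holds per tile. `block_plan` is a hypothetical helper, and it assumes the thread count stays at 512 regardless of block size, which the real launch configuration may not do.

```python
def next_power_of_2(n):
    return 1 << (n - 1).bit_length()


def block_plan(n_cols, cap, threads_per_block=512):
    """Block size, tiles per row, and values held per thread for a given cap."""
    block_size = min(next_power_of_2(n_cols), cap)
    tiles_per_row = -(-n_cols // block_size)  # ceiling division
    values_per_thread = max(block_size // threads_per_block, 1)
    return block_size, tiles_per_row, values_per_thread


for n_cols in (32768, 131072):
    for cap in (4096, 8192, 16384, 32768):
        block_size, tiles, values = block_plan(n_cols, cap)
        print(f"vocab={n_cols} cap={cap}: block_size={block_size}, "
              f"{tiles} tiles/row, {values} values/thread")
```

The tile count is a proxy for how many times each row is streamed per pass; the actual trade-off against occupancy still has to be measured.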

Option B — Two-pass fwd/bwd with cached stats: the `group is not None` (distributed) path already does this split via the `max_logits_ptr is not None` branch — a separate forward kernel stores per-row max/sum stats to DRAM, then the backward kernel reloads them without redoing the softmax. The fix for the non-distributed path is to invoke those same existing forward + backward kernels separately rather than the fused `forward_backward` kernel. Both passes run at their independently optimal block sizes; the backward only reads logits+target once. Any fix should stay consistent with the distributed path or unify them.
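
The data flow of Option B can be sketched for a single row in plain Python. This only illustrates the split and is not the existing kernels: the function names are hypothetical, and passing the per-row loss to the backward alongside the max/sum stats is an assumption of this sketch.

```python
import math


def forward_with_stats(logits, target_logits):
    """Pass 1: compute the row loss and the softmax stats to cache."""
    max_l, max_t = max(logits), max(target_logits)
    sum_l = sum(math.exp(x - max_l) for x in logits)
    sum_t = sum(math.exp(t - max_t) for t in target_logits)
    log_z_l = max_l + math.log(sum_l)
    log_z_t = max_t + math.log(sum_t)
    loss = sum(
        math.exp(x - log_z_l) * ((x - log_z_l) - (t - log_z_t))
        for x, t in zip(logits, target_logits)
    )
    return loss, (max_l, sum_l, max_t, sum_t)


def backward_from_stats(logits, target_logits, loss, stats, tile=2):
    """Pass 2: stream each tile once, reusing cached stats (no re-reduction)."""
    max_l, sum_l, max_t, sum_t = stats
    log_z_l = max_l + math.log(sum_l)
    log_z_t = max_t + math.log(sum_t)
    grad = []
    for start in range(0, len(logits), tile):
        tile_l = logits[start:start + tile]
        tile_t = target_logits[start:start + tile]
        for x, t in zip(tile_l, tile_t):
            log_p = x - log_z_l
            log_q = t - log_z_t
            grad.append(math.exp(log_p) * (log_p - log_q - loss))
    return grad


row_logits = [0.5, -1.0, 2.0, 0.0, 1.5]
row_target = [1.0, 0.0, -0.5, 0.3, 0.2]
row_loss, row_stats = forward_with_stats(row_logits, row_target)
row_grad = backward_from_stats(row_logits, row_target, row_loss, row_stats)
```

Because the backward never recomputes a reduction, its tile size can be chosen independently of the forward's, which is the property Option B relies on.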

Expected outcome (after fix)

| variant | current | estimated after fix |
|---|---|---|
| pytorch_compiled | 1.634 ms, 14.7% BW | n/a (baseline) |
| fast_llm_triton | 4.669 ms, 5.1% BW | ~1.0–1.5 ms, ~20–30% BW |

Triton should match or beat compiled, especially at larger vocab sizes where the fused single-pass advantage (no intermediate softmax tensors written to DRAM) matters more.

Notes

`cross_entropy_logits` (CE with a soft target distribution rather than hard labels) has the same dual-softmax register pressure, and the same fix applies. It is currently only ~1.5× faster than compiled, versus ~3× for CE with labels, which is consistent with the same occupancy issue.
