
Fix full-loss comparison cell in Grokking demo #1378

Merged
jlarson4 merged 1 commit into TransformerLensOrg:dev from robbiebusinessacc:contrib/grokking-demo-full-loss
Jun 11, 2026


@robbiebusinessacc

Contributor

Fixes #543

The "Restricted Loss" section of demos/Grokking_Demo.ipynb calls
loss_fn(all_logits, labels), but by that point all_logits has been
rearranged (in the Logit Periodicity section) into a (p, p, d_vocab)
grid for the Fourier analysis. loss_fn's 3-D branch assumes (batch, pos, d_vocab) and takes logits[:, -1], producing a (113, 113)
tensor whose gather against the 12,769 labels raises the RuntimeError
reported in the issue. The crash is deterministic — it happens
regardless of how far training got.
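
The shape mismatch can be reproduced outside the notebook. The sketch below is a stand-in under stated assumptions: `loss_fn_sketch` is a hypothetical reimplementation written only from the description above (3-D inputs are sliced with `logits[:, -1]`, then the label log-probabilities are gathered), not the notebook's actual code, and `p = 5`, `d_vocab = 6` are toy sizes chosen for speed.

```python
import torch


def loss_fn_sketch(logits, labels):
    # Hypothetical stand-in for the notebook's loss_fn, per the description above.
    if logits.dim() == 3:
        logits = logits[:, -1]  # assumes (batch, pos, d_vocab)
    log_probs = logits.to(torch.float64).log_softmax(dim=-1)
    correct = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct.mean()


p, d_vocab = 5, 6  # toy sizes; the notebook uses p = 113
torch.manual_seed(0)
labels = torch.randint(0, d_vocab, (p * p,))

# Final-token logits regrouped into a grid, as in the periodicity analysis.
grid_logits = torch.randn(p * p, d_vocab).reshape(p, p, d_vocab)

# The 3-D branch treats the second grid axis as "pos" and keeps only p rows.
sliced_shape = tuple(grid_logits[:, -1].shape)

try:
    loss_fn_sketch(grid_logits, labels)
    crashed = False
except RuntimeError:
    # gather rejects an index with p * p rows against an input with p rows
    crashed = True
```

After the slice only `p` rows remain, while `labels` still has `p * p` entries, so `gather` fails before any loss is computed.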

This changes the cell to loss_fn(original_logits, labels).
original_logits is recomputed via model.run_with_cache(dataset) two
cells above in the same section, so passing it yields the full-dataset
loss the cell intends to print, matching the existing
loss_fn(original_logits, labels) usage earlier in the notebook. The
cell's stored RuntimeError output is cleared as well.

Verified on CPU with an untrained model built from the notebook's own
config: the old expression reproduces the exact crash from the issue,
and the new expression returns a value identical to
torch.nn.functional.cross_entropy on the final-token logits. The
separate out-of-memory error mentioned at the end of the issue (Colab
T4) is a resource issue and out of scope here.
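
A check along the lines of the verification above might look like the following. It is only a sketch: `loss_fn_sketch` is a hypothetical reimplementation based on this description rather than the notebook's function, random logits stand in for the output of `model.run_with_cache(dataset)`, and the sizes (`p = 5`, `d_vocab = 6`, `n_pos = 3`) are illustrative.

```python
import torch
import torch.nn.functional as F


def loss_fn_sketch(logits, labels):
    # Hypothetical stand-in for the notebook's loss_fn.
    if logits.dim() == 3:
        logits = logits[:, -1]  # keep the final-token logits
    log_probs = logits.to(torch.float64).log_softmax(dim=-1)
    correct = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct.mean()


p, d_vocab, n_pos = 5, 6, 3  # toy sizes, not the notebook's config
torch.manual_seed(0)

# Random placeholder with the (batch, pos, d_vocab) layout the 3-D branch expects.
original_logits = torch.randn(p * p, n_pos, d_vocab, dtype=torch.float64)
labels = torch.randint(0, d_vocab, (p * p,))

full_loss = loss_fn_sketch(original_logits, labels)
reference = F.cross_entropy(original_logits[:, -1], labels)
matches = torch.allclose(full_loss, reference)
```

Because the batch dimension of `original_logits` lines up with `labels`, the gather succeeds and the result agrees with the mean cross-entropy on the final-token logits.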

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my
    feature works — this demo notebook is excluded from the CI notebook
    matrix, so there is no test harness for it; verification was done by
    executing the affected code path directly as described above
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would
    affect backward compatibility

The Restricted Loss section called loss_fn(all_logits, labels), but
all_logits had been rearranged earlier into a (p, p, d_vocab) grid for
the logit periodicity analysis. loss_fn's 3-D branch assumes
(batch, pos, d_vocab) and takes logits[:, -1], producing a (p, p)
tensor that crashes the gather against the p*p labels (TransformerLensOrg#543).

Use original_logits instead, which is recomputed just above and gives
the full-dataset loss the cell intends to print. Also clear the stored
RuntimeError output from the cell.

@jlarson4

Collaborator

Excellent work on this @robbiebusinessacc, merging! Thanks for taking this on

@jlarson4 jlarson4 merged commit f3a0ce4 into TransformerLensOrg:dev Jun 11, 2026
49 of 50 checks passed
