Fix full-loss comparison cell in Grokking demo#1378
Merged
jlarson4 merged 1 commit intoJun 11, 2026
Merged
Conversation
The Restricted Loss section called loss_fn(all_logits, labels), but all_logits had been rearranged earlier into a (p, p, d_vocab) grid for the logit periodicity analysis. loss_fn's 3-D branch assumes (batch, pos, d_vocab) and takes logits[:, -1], producing a (p, p) tensor that crashes the gather against the p*p labels (TransformerLensOrg#543). Use original_logits instead, which is recomputed just above and is the same full-dataset loss the cell intends to print. Also clear the stored RuntimeError output from the cell.
Collaborator
|
Excellent work on this @robbiebusinessacc, merging! Thanks for taking this on |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #543
The "Restricted Loss" section of
demos/Grokking_Demo.ipynbcallsloss_fn(all_logits, labels), but by that pointall_logitshas beenrearranged (in the Logit Periodicity section) into a
(p, p, d_vocab)grid for the Fourier analysis.
loss_fn's 3-D branch assumes(batch, pos, d_vocab)and takeslogits[:, -1], producing a(113, 113)tensor whose gather against the 12,769 labels raises the
RuntimeErrorreported in the issue. The crash is deterministic — it happens
regardless of how far training got.
This changes the cell to
loss_fn(original_logits, labels).original_logitsis recomputed viamodel.run_with_cache(dataset)twocells above in the same section and is the full-dataset loss the cell
intends to print, matching the existing
loss_fn(original_logits, labels)usage earlier in the notebook. The cell's storedRuntimeErroroutput is cleared as well.
Verified on CPU with an untrained model built from the notebook's own
config: the old expression reproduces the exact crash from the issue,
and the new expression returns a value identical to
torch.nn.functional.cross_entropyon the final-token logits. Theseparate out-of-memory error mentioned at the end of the issue (Colab
T4) is a resource issue and out of scope here.
Type of change
Checklist:
feature works — this demo notebook is excluded from the CI notebook
matrix, so there is no test harness for it; verification was done by
executing the affected code path directly as described above
affect backward compatibility