Skip to content

Adds initial Keras Orbax checkpointer V2 implementation.#74

Merged
copybara-service[bot] merged 1 commit into
mainfrom
test_774946771
Jun 11, 2026
Merged

Adds initial Keras Orbax checkpointer V2 implementation.#74
copybara-service[bot] merged 1 commit into
mainfrom
test_774946771

Conversation

@copybara-service

Copy link
Copy Markdown

Adds initial Keras Orbax checkpointer V2 implementation.

First step in creating a memory efficient Keras + Jax checkpointer that uses nested PyTrees instead of flat tuples to enable model surgery.

  • Checkpoints the serialized model config as metadata.
  • Upgrades the checkpointing logic to the new Orbax API.
  • Writes checkpoints as a dict instead of a tuple.
  • Removes unnecessary expensive jax_state_sync calls.

Reverts changelist 793734230

Reverts changelist 793734230

PiperOrigin-RevId: 930775912
@copybara-service copybara-service Bot merged commit 8489b48 into main Jun 11, 2026
89 of 115 checks passed
@copybara-service copybara-service Bot deleted the test_774946771 branch June 11, 2026 22:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants