Draft
30 commits
a0c3541
ci: fix artifact upload conflict and update coverage downloader
g-husam Mar 24, 2026
e1e07cf
feat(adapter/nemo_rl): add NeMo RL adapter and wrapper util
g-husam Mar 13, 2026
2de73a1
Apply gemini suggestions from code review
g-husam Mar 19, 2026
a7404ab
replace monkey patching with extension class approach, with tests
g-husam Mar 25, 2026
e97a0f5
Find a way to run nemo_rl tests only on python 3.12 and cleanup test …
g-husam Mar 25, 2026
b240805
Refactor: use pyproject.toml markers for conditional dependencies and…
g-husam Mar 25, 2026
92095cf
add comment
g-husam Mar 25, 2026
d02bbab
Address PR feedback: validate save_strategy, fix user guide examples,…
g-husam Mar 25, 2026
a482189
Rename nemo_rl/wrapper_util.py to wrapper_util_rl.py and fix usages
g-husam Mar 25, 2026
054d980
Auto format tests/adapter/nemo_rl/test_checkpoint_manager.py using ru…
g-husam Mar 25, 2026
2c17326
Relax torch version requirement to >=2.8.0 to resolve dependency conf…
g-husam Mar 25, 2026
0d06d6a
use explicit pip-version 24.0.1 in build job
g-husam Mar 26, 2026
45cfed8
set pip-version to 24.0 (which actually exists)
g-husam Mar 26, 2026
d2ae870
chore(build): rebase on bifurcated profiles and fix dev-nemo-rl depen…
g-husam Mar 26, 2026
1bbf9f8
ci: run all tests on python 3.12 and exclude nemo_rl tests on 3.10
g-husam Mar 27, 2026
ef24d26
ci: bifurcate tests and coverage between python 3.10 and 3.12
g-husam Mar 27, 2026
173e8e2
ci: recursively clone and install nemo_rl for python 3.12
g-husam Mar 27, 2026
7731522
remove local dir
g-husam Mar 27, 2026
8e0363e
ci: recursively clone and install nemo_rl in editable mode for python…
g-husam Mar 27, 2026
328423d
ci: use NeMo RL docker container for python 3.12 build
g-husam Mar 27, 2026
22c35dc
rebase and add comment
g-husam Mar 27, 2026
55a1581
ci: use published NeMo RL image for builds and unify build jobs
g-husam Mar 28, 2026
ea58b7f
ci: add debug step to diagnose NeMo RL environment
g-husam Mar 28, 2026
23ed295
ci: set PYTHONPATH for NeMo RL build to include Megatron-LM
g-husam Mar 28, 2026
cc8d648
ci: expand debug step to find megatron.bridge
g-husam Mar 28, 2026
f8dee0e
ci: include Megatron-Bridge in PYTHONPATH for NeMo RL build
g-husam Mar 28, 2026
3c4d8c1
ci: expand debug step to find modelopt and requirements
g-husam Mar 28, 2026
7ce7004
ci: very broad debug to find modelopt
g-husam Mar 28, 2026
3e0873d
ci: fix missing dependencies in NeMo RL container build
g-husam Mar 28, 2026
6a36d62
ci: install 3rdparty components as editable packages in NeMo RL build
g-husam Mar 28, 2026
72 changes: 56 additions & 16 deletions .github/workflows/build-and-test.yml
@@ -24,23 +24,46 @@ jobs:
build:
# Use larger machine with more disk space
runs-on: ubuntu-24.04-32core
container:
# Use a container image if specified in the matrix, otherwise run on the runner host.
image: ${{ matrix.container }}
options: --user root
strategy:
fail-fast: false
matrix:
include:
- python-version: "3.10"
- name: "standard"
python-version: "3.10"
profile: "dev-nemo"
- python-version: "3.12"
test-target: "."
test-filter: "-m 'not nemo_rl'"
coverage-filter: "--omit='src/ml_flashpoint/adapter/nemo_rl/*'"
artifact-name: "coverage-reports"
run-cpp-coverage: true
- name: "nemo-rl"
# NeMo RL image already has Python 3.12 and its dependencies pre-installed.
container: "nvcr.io/nvidia/nemo-rl:v0.5.0"
profile: "dev-nemo-rl"
test-target: "tests/adapter/nemo_rl"
coverage-filter: "--include='src/ml_flashpoint/adapter/nemo_rl/*'"
artifact-name: "coverage-reports-nemo-rl"
run-cpp-coverage: false
# Use a login shell to ensure container environment profiles are correctly loaded.
shell: "bash -l {0}"
env:
PYTHON_FAIL_UNDER: 90
CPP_FAIL_UNDER: 80
permissions:
contents: read # Required for actions/checkout
defaults:
run:
shell: ${{ matrix.shell || 'bash {0}' }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6

- name: Set up Python ${{ matrix.python-version }}
# Only needed for standard builds; NeMo RL image already has a pre-configured Python environment.
if: ${{ !matrix.container }}
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
@@ -50,29 +73,47 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Clean up apt cache to save space
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
# Clean up apt cache to save space (only on host)
if [ -z "${{ matrix.container }}" ]; then
sudo apt-get clean
sudo rm -rf /var/lib/apt/lists/*
fi
# Install missing dependencies if in NeMo RL container
if [ -n "${{ matrix.container }}" ]; then
# Install known missing dependencies
pip install nvidia-modelopt

# Install 3rdparty components as editable packages
# This is more robust than PYTHONPATH as it handles namespace packages and dependencies.
pip install -e /opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/
pip install -e /opt/nemo-rl/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge/
pip install -e /opt/nemo-rl/3rdparty/Automodel-workspace/Automodel/
fi
df -h
# Install python dependencies (with coverage enabled)
echo -e "\n##### Running pip install #####"
pip install -e '.[${{ matrix.profile }}]' --config-settings=cmake.args="-DENABLE_COVERAGE=ON"

# Verify installation
if [ -n "${{ matrix.container }}" ]; then
pip list
fi

- run: df -h

- name: Test with pytest with coverage enabled
run: |
# Run all tests with coverage (Python and C++)
# Run tests based on python version target and filter
echo -e "\n##### Running Python tests with coverage #####"
coverage run --source=src/ml_flashpoint --branch -m pytest -v -s
python -m coverage run --source=src/ml_flashpoint --branch -m pytest -v -s ${{ matrix.test-filter }} ${{ matrix.test-target }}

- name: Check Python test coverage
run: |
# Verify python coverage thresholds
# Verify python coverage thresholds with specific filter
echo -e "\n##### Generating Python coverage XML #####"
coverage xml -o python-coverage.xml
python -m coverage xml -o python-coverage.xml ${{ matrix.coverage-filter }}
echo -e "\n##### Verifying Python coverage thresholds #####"
coverage report --fail-under=${{ env.PYTHON_FAIL_UNDER }}
python -m coverage report --fail-under=${{ env.PYTHON_FAIL_UNDER }} ${{ matrix.coverage-filter }}

- name: Python Coverage Summary
uses: irongut/CodeCoverageSummary@51cc3a756ddcd398d447c044c02cb6aa83fdae95 # ratchet:irongut/CodeCoverageSummary@v1.3.0
@@ -104,6 +145,7 @@ jobs:
path: python-code-coverage-results.md

- name: Check C++ test coverage
if: matrix.run-cpp-coverage
run: |
# Run C++ coverage check
echo -e "\n##### Running C++ coverage check #####"
@@ -125,8 +167,8 @@ jobs:
--fail-under-line=${{ env.CPP_FAIL_UNDER }}

- name: C++ Coverage Summary
if: matrix.run-cpp-coverage
uses: irongut/CodeCoverageSummary@51cc3a756ddcd398d447c044c02cb6aa83fdae95 # ratchet:irongut/CodeCoverageSummary@v1.3.0
if: always() # Run even if threshold check above fails
with:
filename: cxx-coverage.xml
badge: true
@@ -139,15 +181,15 @@ jobs:
thresholds: '${{ env.CPP_FAIL_UNDER }} 40'

- name: Add C++ Coverage Title
if: always()
if: always() && matrix.run-cpp-coverage
run: |
if [ -f code-coverage-results.md ]; then
echo '### C++ Code Coverage Summary' | cat - code-coverage-results.md > temp && mv temp cpp-code-coverage-results.md
fi

- name: Add C++ Coverage PR Comment
if: false && matrix.run-cpp-coverage # TODO: remove when new workflow confirmed to work
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # ratchet:marocchino/sticky-pull-request-comment@v2
if: false # TODO: remove when new workflow confirmed to work
with:
header: cpp-coverage
recreate: true
@@ -163,9 +205,7 @@ jobs:
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
if: always()
with:
# Use 'coverage-reports' for 3.10 to maintain compatibility with the post-coverage workflow on main.
# Other versions get a unique suffix to avoid upload conflicts in the matrix.
name: ${{ matrix.python-version == '3.10' && 'coverage-reports' || format('coverage-reports-{0}', matrix.python-version) }}
name: ${{ matrix.artifact-name }}
if-no-files-found: warn # Default, but setting explicitly for awareness as non-PRs won't have pr_number.txt
path: |
htmlcov/
1 change: 1 addition & 0 deletions .github/workflows/post-coverage-comment.yml
@@ -45,6 +45,7 @@ jobs:
name: coverage-reports
github-token: ${{ secrets.GITHUB_TOKEN }}
run-id: ${{ github.event.workflow_run.id }}
merge-multiple: true

- name: Check for coverage files
# We use an explicit step to check for file existence and set outputs.
3 changes: 3 additions & 0 deletions .gitignore
@@ -80,3 +80,6 @@ cmake-build-debug/

# Linters
.ruff_cache

# AI Agents
.gemini
84 changes: 84 additions & 0 deletions docs/user-guide.md
@@ -111,6 +111,90 @@ New jobs however should have an independent job ID, so as not to conflict with p
1. It is recommended to supply the `MLFlashpointCheckpointCallback` with the standard checkpoint strategy's interval (its `every_n_steps` configuration), so ML Flashpoint can skip its own saves when the standard strategy will save.
This reduces blocking time by avoiding duplicate work, at the cost of having a longer write time for that step.

### NeMo RL

NeMo RL does not use PyTorch Lightning; it uses its own `CheckpointManager` and policy workers. ML Flashpoint provides a wrapper adapter that adds fast checkpointing to your training loop without requiring changes to the loop itself.

#### Imports

Import the NeMo RL wrapper provided by the ML Flashpoint adapter:

```python
from ml_flashpoint.adapter.nemo_rl.wrapper_util_rl import wrap_rl_components_with_mlflashpoint
from ml_flashpoint.core.checkpoint_loader import MLFlashpointCheckpointLoader
```

Additionally, you will need to instantiate your preferred save strategy to tell the manager how to commit ML Flashpoint blobs:

```python
from ml_flashpoint.adapter.megatron import MLFlashpointMegatronAsyncSaveStrategy
```

#### Recipe Changes

NeMo RL organizes its training entry points into Python scripts (like `examples/run_grpo.py`), which orchestrate the initialization steps and are driven heavily by configurations in YAML files.

Instead of modifying the upstream framework loops themselves (such as `async_grpo_train` in `nemo_rl/algorithms/grpo.py`), you should wrap the checkpointer instantiation within these NeMo RL script entry points.

```python
# 1. Your original NeMo RL initializers
# Typically instantiated via setup() or directly:
policy = Policy(cluster=train_cluster, config=policy_config, ...)
checkpointer = CheckpointManager(
    checkpoint_dir=args.checkpoint_dir,
    metric_name=args.metric_name, ...
)

# 2. Add the ML Flashpoint dual manager
flashpoint_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(...)
checkpointer = wrap_rl_components_with_mlflashpoint(
    checkpointer=checkpointer,
    # Some tmpfs path for this job like /tmp/mlf/job-12345
    flashpoint_base_container=_get_my_mlf_base_path(),
    standard_save_period=1000,  # Dictates when standard saves execute
    save_strategy=flashpoint_save_strategy,
    checkpoint_loader=MLFlashpointCheckpointLoader(...),
)

# 3. Pass the wrapper back in as if it were the standard checkpointer
# For example, within GRPO:
async_grpo_train(
    policy=policy,
    checkpointer=checkpointer,  # Dual checkpointer takes over routing
    ...
)
```

#### Limitations / Requisites

1. **Standard `save_period` override:** The two save periods must be coordinated. The `save_period` in your NeMo RL configuration (typically in the YAML config under `checkpointing: save_period: ...`; [see an example here](https://github.com/NVIDIA-NeMo/RL/blob/main/examples/configs/grpo_math_1B.yaml)) now controls how frequently *ML Flashpoint* saves, so set it low (e.g. `1` or `10`).
1. **`standard_save_period`:** This wrapper argument controls how frequently the standard, long-term checkpoint is actually written. For instance, with `save_period: 10` in the NeMo RL YAML and `standard_save_period=1000` passed to the wrapper, ML Flashpoint saves every 10 steps and standard checkpoints are written every 1000 steps.
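
The interaction between the two periods can be sketched as below. This is a minimal illustration rather than the adapter's actual routing code: `route_save` is a hypothetical helper, and it assumes that a step which is due for both kinds of save is handled by the standard checkpoint alone.

```python
def route_save(step: int, save_period: int, standard_save_period: int) -> str:
    """Return which checkpoint path would handle `step` (hypothetical helper)."""
    if save_period <= 0 or standard_save_period <= 0:
        raise ValueError("both periods must be positive")
    if step % save_period != 0:
        # NeMo RL only asks the checkpointer to save every `save_period` steps.
        return "none"
    if step % standard_save_period == 0:
        # Assumption: the standard save takes precedence on shared steps.
        return "standard"
    return "ml_flashpoint"


# save_period: 10 in the NeMo RL YAML, standard_save_period=1000 in the wrapper.
for step in (5, 10, 990, 1000):
    print(step, route_save(step, save_period=10, standard_save_period=1000))
```

Under this assumption, choosing `standard_save_period` as a multiple of `save_period` keeps standard saves on their intended schedule, because the wrapper is only consulted on steps where NeMo RL requests a save.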

#### NeMo RL Configuration (Worker Side)

The custom worker extension (`MLFlashpointMegatronPolicyWorker`) reads its configuration from `self.cfg` (the `PolicyConfig` TypedDict passed during initialization).

You can define the `ml_flashpoint` configuration block in your recipe or config file. It should be a dictionary nested within the policy configuration.

##### Configuration Schema

| Field | Type | Required | Description |
| :--- | :--- | :--- | :--- |
| `enabled` | `bool` | No (default `True`) | Enable/disable ML Flashpoint on the worker. |
| `base_container` | `str` | **Yes** | The base directory (typically in `tmpfs`) for ML Flashpoint checkpoints. |
| `write_thread_count` | `int` | No (default `1`) | Number of threads for asynchronous writing. |
| `buffer_size_bytes` | `int` | No (default `17179869184`, i.e. 16 GiB) | Size of the shared memory buffers in bytes. |

Example configuration in a YAML or dict:

```yaml
policy:
  ml_flashpoint:
    # enabled: true # default
    base_container: "/tmp/mlf-checkpoints/job-12345"
    # buffer_size_bytes: 17179869184 # default (16 GiB)
```
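
To make the schema concrete, the sketch below shows one way the `ml_flashpoint` block could be validated and filled with the defaults from the table above. It is an illustration only: `resolve_ml_flashpoint_config` is a hypothetical helper, not part of the ML Flashpoint API, and rejecting unknown keys is an assumption rather than documented behavior.

```python
from collections.abc import Mapping
from typing import Any

# Defaults from the schema table; 16 * 1024**3 == 17179869184 bytes.
_DEFAULTS: dict[str, Any] = {
    "enabled": True,
    "write_thread_count": 1,
    "buffer_size_bytes": 16 * 1024**3,
}
_FIELD_TYPES: dict[str, type] = {
    "enabled": bool,
    "base_container": str,
    "write_thread_count": int,
    "buffer_size_bytes": int,
}


def resolve_ml_flashpoint_config(policy_cfg: Mapping[str, Any]) -> dict[str, Any]:
    """Validate policy_cfg["ml_flashpoint"] and apply defaults (hypothetical helper)."""
    block = policy_cfg.get("ml_flashpoint")
    if not isinstance(block, Mapping):
        raise ValueError("policy config needs an 'ml_flashpoint' dictionary")
    unknown = sorted(set(block) - set(_FIELD_TYPES))
    if unknown:  # assumption: unknown keys are treated as errors
        raise ValueError(f"unknown ml_flashpoint fields: {unknown}")
    if "base_container" not in block:
        raise ValueError("'base_container' is required")
    resolved = {**_DEFAULTS, **block}
    for name, expected in _FIELD_TYPES.items():
        value = resolved[name]
        is_bool = isinstance(value, bool)
        # bool is a subclass of int, so True/False must not pass as an int field.
        ok = is_bool if expected is bool else (isinstance(value, expected) and not is_bool)
        if not ok:
            raise TypeError(f"'{name}' must be of type {expected.__name__}")
    return resolved


policy = {"ml_flashpoint": {"base_container": "/tmp/mlf-checkpoints/job-12345"}}
print(resolve_ml_flashpoint_config(policy))
```

Keeping the defaults in one table-like mapping means the documentation and the validation logic can be compared field by field.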

### Megatron-LM

Code: See the [`ml_flashpoint.adapter.megatron`](https://github.com/google/ml-flashpoint/tree/main/src/ml_flashpoint/adapter/megatron) package.
15 changes: 11 additions & 4 deletions pyproject.toml
@@ -56,7 +56,7 @@ dependencies = [

# An extra for users who want to use this library with just PyTorch.
# Installed via: `pip install ml-flashpoint[pytorch]`
pytorch = ["torch==2.8.0"]
pytorch = ["torch>=2.8.0"]
# An extra for users who want to use this library with its Megatron-LM adapter.
# Installed via: `pip install ml-flashpoint[megatron]`
megatron = ["megatron_core==0.13.1"]
@@ -67,6 +67,13 @@ nemo = [
"ml_flashpoint[pytorch]",
"nemo_toolkit[all]==2.4.0",
]
# An extra for users who want to use this library with NeMo RL.
# This is kept empty because the NeMo RL environment (including its complex dependencies
# like PyTorch and Megatron-Core) is provided by the base container image in CI.
# This prevents pip from attempting to resolve or update these dependencies, which
# can cause version conflicts and break the environment.
# Installed via: `pip install ml-flashpoint[nemo-rl]`
nemo-rl = []

# An extra for generating the documentation site.
# Installed via: `pip install ml-flashpoint[docs]`
@@ -106,9 +113,7 @@ dev-nemo = [
# Defines a "dev-nemo-rl" extra for NeMo RL development (typically uses Python 3.12+).
# Installed via: `pip install -e .[dev-nemo-rl]`
dev-nemo-rl = [
"ml-flashpoint[dev-nemo]",
# TODO: uncomment below and remove line above when nemo-rl profile is added
#"ml-flashpoint[dev-base,nemo-rl]",
"ml-flashpoint[dev-base,nemo-rl,docs]",
]

# Defines a "dev" extra for setting up a development environment.
@@ -204,13 +209,15 @@ exclude = [
[tool.pytest.ini_options]
norecursedirs = [
".git",
".gemini",
"build/**/_deps",
".gemini",
".worktrees",
]
markers = [
"e2e: marks tests as end-to-end",
"smoke: quick subset of tests",
"nemo_rl: marks tests for NeMo RL adapter",
]

# ===================================================================
17 changes: 17 additions & 0 deletions src/ml_flashpoint/adapter/nemo_rl/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .checkpoint_manager import MLFlashpointRLCheckpointManager

__all__ = ["MLFlashpointRLCheckpointManager"]