Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
[submodule "research-environments"]
path = deps/research-environments
url = git@github.com:PrimeIntellect-ai/research-environments.git
[submodule "configs/private"]
path = configs/private
url = git@github.com:PrimeIntellect-ai/research-configs.git
[submodule "pydantic-config"]
path = deps/pydantic-config
url = https://github.com/PrimeIntellect-ai/pydantic-config
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ source $HOME/.local/bin/env
uv sync --all-extras
```

4.1. On aarch64 hosts: build flash-attn from source for your GPU

> *NOTE*: aarch64 has no prebuilt flash-attn wheel. This step compiles the CUDA extension for your local GPU (~20-30 minutes). Compute capability is auto-detected from `nvidia-smi`; override with `TORCH_CUDA_ARCH_LIST=9.0` (Hopper) / `10.0` (Blackwell) if needed.
> *NOTE*: After this step, you can't run `uv sync --all-extras` or `uv run` as it will uninstall the package, you can avoid it by running `uv sync --inexact` or `uv run --no-sync`.

```bash
bash scripts/docker-arm64-post-install.sh
```

3.1. Optional: Install Flash Attention 3 (on Hopper GPUs only, for flash_attention_3 attention backend)

> *NOTE*: This step will take a while, as it builds the Flash Attention 3 extension from source, as it has no wheels prebuilt.
Expand Down
1 change: 0 additions & 1 deletion configs/private
Submodule private deleted from 70c350
2 changes: 1 addition & 1 deletion deps/research-environments
2 changes: 1 addition & 1 deletion deps/verifiers
Submodule verifiers updated 210 files
16 changes: 16 additions & 0 deletions docs/algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,22 @@ kwargs = { eps = 1e-8 }

`AdvantageInputs.rollouts` is a list of `verifiers.RolloutOutput`, so you have access to the full rollout (turns, tool calls, custom metadata) — not just the reward. Use this for anything reward-shaping-like that needs trajectory context.

### Per-Env Advantage

`advantage` can be set per training environment. Each env inherits the top-level `[orchestrator.advantage]` when it doesn't set its own, so mixed-env runs can give each env its own advantage computation:

```toml
[orchestrator.advantage]
type = "default" # the default every env inherits unless it overrides

[[orchestrator.train.env]]
id = "math-env" # inherits the default above

[[orchestrator.train.env]]
id = "agent-env"
advantage = { type = "custom", import_path = "my_module.normalized_advantage" }
```

## Filters

Filters drop rollouts between scoring and training. Built-ins (composable):
Expand Down
111 changes: 60 additions & 51 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,49 @@ def _deprecate_max_tokens(cls, data: Any) -> Any:
return data


class TokensLengthPenaltyConfig(BaseConfig):
type: Literal["tokens"] = "tokens"

completion_weight: float = Field(1.0, ge=0, allow_inf_nan=False)
"""Weight on model completion tokens. Finite and non-negative."""

tool_response_weight: float = Field(1.0, ge=0, allow_inf_nan=False)
"""Weight on tool-response tokens (read from the rollout's ``*_total_tool_response_tokens`` harness metric; 0 if absent). Finite and non-negative."""


class TurnsLengthPenaltyConfig(BaseConfig):
type: Literal["turns"] = "turns"


LengthPenaltyConfig: TypeAlias = Annotated[
TokensLengthPenaltyConfig | TurnsLengthPenaltyConfig,
Field(discriminator="type"),
]


class DefaultAdvantageConfig(BaseConfig):
type: Literal["default"] = "default"

length_penalty: LengthPenaltyConfig | None = None
"""Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0."""


class CustomAdvantageConfig(BaseConfig):
type: Literal["custom"] = "custom"

import_path: str
"""Import path to the advantage function (e.g. ``my_module.my_advantage``)."""

kwargs: dict[str, Any] = Field(default_factory=dict)
"""Kwargs forwarded to the advantage function."""


AdvantageConfig: TypeAlias = Annotated[
DefaultAdvantageConfig | CustomAdvantageConfig,
Field(discriminator="type"),
]


class EnvConfig(BaseConfig):
id: str = "reverse-text"
"""Registered verifiers environment ID (e.g. ``math-env``, ``primeintellect/math-env``). May include an ``@version`` suffix for installation."""
Expand Down Expand Up @@ -214,6 +257,11 @@ class TrainEnvConfig(EnvConfig):
"""Rollouts generated per example for GRPO group-relative advantages.
Inherits from ``orchestrator.group_size`` when unset."""

advantage: AdvantageConfig | None = None
"""Advantage strategy for this env's GRPO groups. Inherits from the top-level
``orchestrator.advantage`` when unset; set a different ``default``/``custom``
config to give this env its own advantage computation."""


class EvalEnvConfig(EnvConfig):
sampling: EvalSamplingConfig = EvalSamplingConfig()
Expand Down Expand Up @@ -374,49 +422,6 @@ class CheckpointConfig(BaseConfig):
"""Skip loading the progress from checkpoint."""


class TokensLengthPenaltyConfig(BaseConfig):
type: Literal["tokens"] = "tokens"

completion_weight: float = Field(1.0, ge=0, allow_inf_nan=False)
"""Weight on model completion tokens. Finite and non-negative."""

tool_response_weight: float = Field(1.0, ge=0, allow_inf_nan=False)
"""Weight on tool-response tokens (read from the rollout's ``*_total_tool_response_tokens`` harness metric; 0 if absent). Finite and non-negative."""


class TurnsLengthPenaltyConfig(BaseConfig):
type: Literal["turns"] = "turns"


LengthPenaltyConfig: TypeAlias = Annotated[
TokensLengthPenaltyConfig | TurnsLengthPenaltyConfig,
Field(discriminator="type"),
]


class DefaultAdvantageConfig(BaseConfig):
type: Literal["default"] = "default"

length_penalty: LengthPenaltyConfig | None = None
"""Correctness-gated length penalty. ``tokens`` shapes by weighted token cost; ``turns`` shapes by trajectory turn count; None disables shaping. In mixed groups, lower-cost correct rollouts get amplified advantage (up to 2x), higher-cost correct rollouts are unchanged, incorrect untouched. In all-correct groups, below-average-cost rollouts get advantage in [0, 1], others get 0."""


class CustomAdvantageConfig(BaseConfig):
type: Literal["custom"] = "custom"

import_path: str
"""Import path to the advantage function (e.g. ``my_module.my_advantage``)."""

kwargs: dict[str, Any] = Field(default_factory=dict)
"""Kwargs forwarded to the advantage function."""


AdvantageConfig: TypeAlias = Annotated[
DefaultAdvantageConfig | CustomAdvantageConfig,
Field(discriminator="type"),
]


# Flags rare tokens generated at high entropy (Section 5.2, https://arxiv.org/abs/2510.02387).
class GibberishFilterConfig(BaseConfig):
type: Literal["gibberish"] = "gibberish"
Expand Down Expand Up @@ -517,7 +522,7 @@ class OrchestratorConfig(BaseConfig):
"""Typed renderer config (``renderers.RendererConfig`` discriminated
union). Defaults to ``"auto"``, which resolves from
``tokenizer.name_or_path`` via ``MODEL_RENDERER_MAP``. ``None``
opts into MITO (``openai_chat_completions``); SFT mode forces this."""
opts into MITO (``openai_chat_completions``)."""

pool_size: int | None = Field(None, ge=1)
"""Number of renderer slots shared across concurrent rollouts. Bump
Expand Down Expand Up @@ -754,11 +759,10 @@ def validate_unique_filter_types(self):

@model_validator(mode="after")
def _force_no_renderer_for_sft(self):
"""SFT rolls out via the teacher's plain chat-completions endpoint; the
renderer client doesn't apply. Force ``renderer=None`` so the user
doesn't have to remember to set it. Declared before the renderer
validators below so they see the corrected value."""
if self.training_mode == "sft":
"""Teacher-backed SFT rolls out via the teacher's plain chat-completions
endpoint; the renderer client doesn't apply. When no teacher is
configured, SFT uses the student rollout path and keeps the renderer."""
if self.training_mode == "sft" and self.teacher is not None:
self.renderer = None
return self

Expand All @@ -768,8 +772,8 @@ def validate_training_mode(self):
has_teacher = self.teacher is not None
if self.training_mode == "rl" and has_teacher:
raise ValueError("orchestrator.teacher must not be set when training_mode = 'rl'.")
if self.training_mode in ("opd", "sft") and not has_teacher:
raise ValueError(f"orchestrator.teacher must be configured when training_mode = '{self.training_mode}'.")
if self.training_mode == "opd" and not has_teacher:
raise ValueError("orchestrator.teacher must be configured when training_mode = 'opd'.")
return self

@model_validator(mode="after")
Expand Down Expand Up @@ -876,6 +880,11 @@ def resolve_batching(self):
if "group_size" not in env_cfg.model_fields_set:
env_cfg.group_size = self.group_size

# Propagate the top-level ``advantage`` into each train env that didn't set its own.
for env_cfg in self.train.env:
if "advantage" not in env_cfg.model_fields_set:
env_cfg.advantage = self.advantage

# Resolve train env num_workers from max_inflight_rollouts
for env_cfg in self.train.env:
if env_cfg.num_workers == "auto":
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ readme = "README.md"
requires-python = "~=3.12.0"
dependencies = [
"prime-rl-configs",
"prime-pydantic-config",
"beartype>=0.21.0",
"datasets>=4.0.0",
"jaxtyping>=0.3.2",
Expand All @@ -27,6 +28,7 @@ dependencies = [
"aiolimiter>=1.2.1",
"tenacity>=8.2.0",
"openai>=1.106.1",
"orjson>=3.11.0",
"rich>=14.0.0",
"setproctitle>=1.3.0",
"uvloop>=0.21.0",
Expand Down Expand Up @@ -77,7 +79,6 @@ envs = [
"math-python",
"math500",
"mini-swe-agent-plus",
"mini-swe-agent-plus-rlm",
"mmlu-pro",
"opencode-cp",
"opencode-deepdive",
Expand Down Expand Up @@ -150,6 +151,7 @@ override-dependencies = [
"transformers==5.6.2",
"torch>=2.9.0",
"openenv-core",
"verifiers[packages]>=0.1.15.dev150",
]

# ModelExpress 0.3.0 publishes protobuf<6 metadata, but its generated proto is
Expand Down Expand Up @@ -207,6 +209,7 @@ color-codeword = { path = "deps/research-environments/environments/color_codewor
deepdive = { path = "deps/research-environments/environments/deepdive", editable = true }
general-agent = { path = "deps/research-environments/environments/general_agent", editable = true }
gpqa = { path = "deps/research-environments/environments/gpqa", editable = true }
harnesses = { path = "deps/verifiers/packages/harnesses", editable = true }
hle = { path = "deps/research-environments/environments/hle", editable = true }
ifeval = { path = "deps/research-environments/environments/ifeval", editable = true }
livecodebench = { path = "deps/research-environments/environments/livecodebench", editable = true }
Expand All @@ -215,7 +218,6 @@ math-env = { path = "deps/research-environments/environments/math_env", editable
math-python = { path = "deps/verifiers/environments/math_python", editable = true }
math500 = { path = "deps/research-environments/environments/math500", editable = true }
mini-swe-agent-plus = { path = "deps/research-environments/environments/mini_swe_agent_plus", editable = true }
mini-swe-agent-plus-rlm = { path = "deps/research-environments/environments/mini_swe_agent_plus_rlm", editable = true }
mmlu-pro = { path = "deps/research-environments/environments/mmlu_pro", editable = true }
opencode-cp = { path = "deps/research-environments/environments/opencode_cp", editable = true }
opencode-deepdive = { path = "deps/research-environments/environments/opencode_deepdive", editable = true }
Expand All @@ -227,6 +229,7 @@ rlm-swe = { path = "deps/research-environments/environments/rlm_swe", editable =
science-env = { path = "deps/research-environments/environments/science_env", editable = true }
simpleqa-verified = { path = "deps/research-environments/environments/simpleqa_verified", editable = true }
tau2-bench = { path = "deps/research-environments/environments/tau2_bench", editable = true }
tasksets = { path = "deps/verifiers/packages/tasksets", editable = true }
wiki-search = { path = "deps/verifiers/environments/wiki_search", editable = true }
wordle = { path = "deps/verifiers/environments/wordle", editable = true }
torch = { index = "pytorch-cu128" }
Expand Down
47 changes: 37 additions & 10 deletions scripts/docker-arm64-post-install.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
#!/bin/bash
# arm64 post-install fixups for Docker builds.
set -e
# arm64 post-install fixups: rebuild flash-attn from source for the target GPU.
#
# Why this exists: pyproject.toml sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE to keep
# `uv sync` fast; on x86_64 it pins a prebuilt wheel to fill in the binary, but no
# such wheel exists for aarch64. Without this script, `import flash_attn` fails on
# aarch64 with `ModuleNotFoundError: No module named 'flash_attn_2_cuda'`.
#
# Defaults preserve the existing Docker behavior (sm_100 / GB200). On a host with
# `nvidia-smi` available, the compute capability is auto-detected from the local
# GPU. Override via env vars if needed:
# TORCH_CUDA_ARCH_LIST e.g. 9.0 (Hopper), 10.0 (Blackwell)
# VENV_PATH path to the venv (default: $(pwd)/.venv)
# MAX_JOBS parallel nvcc jobs (default: 4)
set -euo pipefail

echo "=== building flash-attn from source (sm_100 / GB200) ==="
# Run from /tmp so uv doesn't read pyproject.toml's [tool.uv.extra-build-variables]
# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and prevents CUDA kernel compilation.
export TORCH_CUDA_ARCH_LIST="10.0"
export MAX_JOBS=4
if [ -z "${TORCH_CUDA_ARCH_LIST:-}" ]; then
# Try to detect from the local GPU. Tolerate any failure mode (binary missing,
# driver not loaded, Docker buildx without --gpus) and fall back to GB200.
TORCH_CUDA_ARCH_LIST="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1 | tr -d ' ' || true)"
: "${TORCH_CUDA_ARCH_LIST:=10.0}"
fi
export TORCH_CUDA_ARCH_LIST

VENV_PATH="${VENV_PATH:-$(pwd)/.venv}"
if [ ! -x "$VENV_PATH/bin/python" ]; then
echo "ERROR: no python at $VENV_PATH/bin/python. Run from the project root or set VENV_PATH." >&2
exit 1
fi

export MAX_JOBS="${MAX_JOBS:-4}"
export FLASH_ATTENTION_FORCE_BUILD=TRUE
export FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE
(cd /tmp && uv pip install --python /app/.venv/bin/python \
"flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache)

echo "=== building flash-attn from source (TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST, MAX_JOBS=$MAX_JOBS) ==="
echo " target venv: $VENV_PATH"
# Run from /tmp so uv ignores the project's [tool.uv.extra-build-variables],
# which sets FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and would prevent kernel compilation.
(cd /tmp && uv pip install --python "$VENV_PATH/bin/python" \
"flash-attn==2.8.3" --no-build-isolation --no-binary flash-attn --no-cache --reinstall-package flash-attn)

echo "=== reinstalling flash-attn-cute (flash-attn overwrites it with a stub) ==="
uv pip install --reinstall --no-deps \
uv pip install --python "$VENV_PATH/bin/python" --reinstall --no-deps \
"flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@96bd151#subdirectory=flash_attn/cute"
28 changes: 11 additions & 17 deletions scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,11 @@ ensure_known_hosts() {
fi
}

# Initialize each submodule independently so that a missing private repo
# (e.g. configs/private when the user lacks access) does not abort the install.
init_submodules() {
if [ ! -f .gitmodules ]; then
return 0
fi
local paths failures
paths=$(git config -f .gitmodules --get-regexp '^submodule\..*\.path$' | awk '{print $2}')
failures=()
for path in $paths; do
log_info "Initializing submodule: ${path}"
if git submodule update --init --recursive -- "$path"; then
:
else
log_warn "Could not initialize submodule '${path}' (likely no access). Continuing without it."
failures+=("$path")
fi
done
if [ "${#failures[@]}" -gt 0 ]; then
log_warn "Skipped submodules: ${failures[*]}"
fi
git submodule update --init --recursive
}

main() {
Expand Down Expand Up @@ -159,6 +143,16 @@ main() {
log_info "Installing pre-commit hooks..."
uv run pre-commit install

# aarch64 has no prebuilt flash-attn wheel; build it from source for the local GPU.
# Without this, `import flash_attn` fails with `ModuleNotFoundError: flash_attn_2_cuda`.
# Run last so no subsequent uv operation (which implicitly syncs against the lockfile)
# rebuilds flash-attn from PyPI with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE and undoes this.
if [ "$(uname -m)" = "aarch64" ]; then
log_info "aarch64 detected: building flash-attn from source (this takes 20-30 minutes)..."
log_warn "Future 'uv sync --all-extras' or 'uv run' will remove this build. Use 'uv sync --inexact' or 'uv run --no-sync' to keep it."
bash scripts/docker-arm64-post-install.sh
fi

log_info "Installation completed!"
}

Expand Down
1 change: 0 additions & 1 deletion skills/configs/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,4 @@ Leave it unset for normal training. When enabled, it exports every sequence from

- `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli`
- `configs/debug/` — minimal debug configs
- `configs/private/` — private configs submodule (internal)
- `examples/` — full example configs
4 changes: 1 addition & 3 deletions skills/install/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ bash scripts/install.sh # clones, inits submodules, installs uv, runs `uv sync
For an existing clone, init submodules explicitly:

```bash
git submodule update --init -- deps/verifiers deps/renderers deps/research-environments deps/pydantic-config
git submodule update --init --recursive
```

Do **not** run `git submodule update --init --recursive` without paths — it tries to clone the private `configs/private` submodule and aborts for users without access. `scripts/install.sh` walks submodules one at a time and skips failures, so it works for everyone.

## Sync

```bash
Expand Down
Loading