Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions conf/finetune/gspo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Finetune config group that selects the GSPO policy loss.
# conf/swe.yaml opts into it with `override finetune: gspo`.
defaults:
- base
- _self_

attempts: 8
rl:
policy_loss: gspo
# NOTE(review): presumably the upper/lower clipping bounds used by the policy
# loss — confirm against the rl loss implementation.
epsilon_high: 4e-4
epsilon_low: 3e-4
47 changes: 47 additions & 0 deletions conf/swe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Top-level config for the SWE (single-turn code repair) domain.
defaults:
- base
- _self_
- override finetune: gspo

model_path: Qwen/Qwen3-8B

actor:
rollout_policy: pipelinerl.domains.swe.rollouts.generate_swe_rollout
success_threshold: 0.8

environments: null

# Loader defined in pipelinerl/domains/swe/load_datasets.py; the keys under
# dataset_loader_params correspond to its `seed` and `max_samples` keyword arguments.
dataset_loader: pipelinerl.domains.swe.load_datasets.load_local_swe_dataset
dataset_loader_params:
seed: ${seed}
# max_samples: 1000 # uncomment to cap the number of loaded samples (applies to both train and test)

# HuggingFace Hub dataset IDs (or local disk paths).
# Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train
# NOTE(review): train and test currently name the same dataset and share the same
# loader and seed, so both lists would contain the same samples — confirm whether a
# held-out split or a separate test dataset is intended.
train_dataset_names:
- SWE-bench/SWE-smith-py
test_dataset_names:
- SWE-bench/SWE-smith-py
Comment thread
ehsk marked this conversation as resolved.
Outdated

# NOTE(review): finetune.seq_length and vllm_kwargs.max_model_len are both 24000;
# presumably they must be kept in sync — confirm before changing one of them.
finetune:
seq_length: 24000
rl:
filter_zero_advantage_groups: false

vllm_config:
vllm_kwargs:
max_model_len: 24000

# Sampling settings for training rollouts (temperature 1.0, thinking disabled).
llm:
parameters:
max_tokens: 4096
temperature: 1.0
chat_template_kwargs:
enable_thinking: false

# Sampling settings for evaluation: same token budget, temperature 0.0.
test_llm:
parameters:
max_tokens: 4096
temperature: 0.0
chat_template_kwargs:
enable_thinking: false
15 changes: 15 additions & 0 deletions conf/swe/preprocess.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Config for swe_preprocessor.py.
# Clones repos, extracts gold_file_contents at base_commit, applies token
# filtering, and saves a training-ready HuggingFace disk dataset.
#
# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess

hf_dataset_name: SWE-bench/SWE-smith-py
hf_split_name: train
# NOTE(review): repo_path and dataset_path look like placeholders — set them to real
# locations before running.
repo_path: /path/to/repos
dataset_path: /path/to/output_ds
# Same model as model_path in conf/swe.yaml.
tokenizer_model: Qwen/Qwen3-8B
min_token_threshold: null # set to an int to filter out very short examples
max_token_threshold: 16000 # set to null to disable
num_map_processes: 32
force_reprocess: false
127 changes: 127 additions & 0 deletions pipelinerl/domains/swe/load_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Supported datasets
# ──────────────────────────────────────────────────────────────────────────────
# Ready to use (have gold_file_contents pre-extracted):
# SWE-bench/SWE-smith-py local preprocessed disk dataset or Hub ID
# SWE-bench/SWE-smith-java "
# SWE-bench/SWE-smith-rs "
# SWE-bench/SWE-smith-go "
#
# Require preprocessing first (clone repos, extract file contents at base_commit):
# princeton-nlp/SWE-bench
# princeton-nlp/SWE-bench_Lite
# princeton-nlp/SWE-bench_Verified
# SWE-bench/SWE-Pro (if/when released publicly)
#
# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess
# ──────────────────────────────────────────────────────────────────────────────

import json
import logging
import os
import random
from typing import Any, Dict, List, Optional

from datasets import load_dataset, load_from_disk

logger = logging.getLogger(__name__)


def _parse_file_contents(raw: Any) -> Dict[str, str]:
if isinstance(raw, dict):
return {str(k): str(v) for k, v in raw.items()}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except (json.JSONDecodeError, TypeError):
return {}
if isinstance(parsed, dict):
return {str(k): str(v) for k, v in parsed.items()}
return {}


def _load_single_dataset(path: str) -> List[Dict]:
    """Load a dataset from a local disk path or a HuggingFace Hub ID.

    Accepted forms of ``path``:
        Local path:    /path/to/ds_train
        Hub ID:        SWE-bench/SWE-smith-py        (all splits concatenated)
        Hub ID+split:  SWE-bench/SWE-smith-py:train

    Rows whose ``gold_file_contents`` is missing, empty or unparsable are
    skipped; the number of skipped rows is logged so that a dataset that still
    needs preprocessing does not silently produce zero samples.

    Args:
        path: Filesystem path or Hub ID (optionally with a ``:split`` suffix).

    Returns:
        A list of sample dicts with the keys ``id``, ``dataset``, ``repo``,
        ``base_commit``, ``problem_statement``, ``patch`` and ``file_contents``.

    Raises:
        Whatever ``load_from_disk`` / ``load_dataset`` / ``concatenate_datasets``
        raise; the caller is expected to handle loading failures.
    """
    if os.path.exists(path):
        logger.info("Loading from disk: %s", path)
        loaded = load_from_disk(path)
    else:
        # Hub ID, optionally with ":split" suffix
        if ":" in path:
            hub_id, split = path.rsplit(":", 1)
            # A trailing ":" carries no split name; treat it like "all splits".
            split = split or None
        else:
            hub_id, split = path, None

        logger.info("Loading from HuggingFace Hub: %s (split=%s)", hub_id, split or "all")
        loaded = load_dataset(hub_id, split=split)

    if isinstance(loaded, dict):
        # DatasetDict (a dict subclass) — returned by load_dataset without a split
        # and by load_from_disk when a whole DatasetDict was saved. Iterating it
        # would yield split names rather than rows, so merge all splits first.
        from datasets import concatenate_datasets
        dataset = concatenate_datasets(list(loaded.values()))
    else:
        dataset = loaded

    logger.info("Loaded %d rows from %s", len(dataset), path)

    samples: List[Dict] = []
    missing_contents = 0
    for row in dataset:
        item = dict(row)
        try:
            file_contents = _parse_file_contents(item.get("gold_file_contents", "{}"))
            if not file_contents:
                missing_contents += 1
                continue
            samples.append({
                "id": item.get("id", "") or item.get("instance_id", "") or item.get("issue_id", ""),
                "dataset": item.get("dataset", "") or path,
                "repo": item.get("repo", ""),
                "base_commit": item.get("base_commit", ""),
                "problem_statement": item.get("problem_statement", ""),
                "patch": item.get("patch", ""),
                "file_contents": file_contents,
            })
        except Exception as e:
            # Best-effort: one bad row must not abort loading the whole dataset.
            logger.warning("Skipping malformed item: %s", e)

    if missing_contents:
        logger.warning(
            "Skipped %d of %d rows from %s without usable gold_file_contents "
            "(does this dataset need preprocessing?)",
            missing_contents, len(dataset), path,
        )

    return samples


def load_local_swe_dataset(
    dataset_paths: List[str],
    seed: int = 42,
    max_samples: Optional[int] = None,
) -> List[Dict]:
    """
    Load one or more SWE-style datasets and return a combined, shuffled list.

    Args:
        dataset_paths: Passed via cfg.train_dataset_names / cfg.test_dataset_names.
            Each entry is either a filesystem path to a HuggingFace disk dataset
            or a HuggingFace Hub ID, optionally suffixed with ":split".
            Add multiple entries to mix datasets (e.g. swe-smith + swe-bench).
        seed: Random seed for shuffling (inherit from cfg.seed via dataset_loader_params).
        max_samples: Optional cap on the total number of returned samples.
            ``None`` and ``0`` disable the cap; a negative value is ignored
            with a warning.

    Returns:
        The shuffled samples of every dataset that loaded successfully. A
        dataset that fails to load is logged and skipped, so the result may be
        empty.
    """
    if not dataset_paths:
        logger.error("No dataset paths provided")
        return []

    all_samples: List[Dict] = []
    for path in dataset_paths:
        try:
            all_samples.extend(_load_single_dataset(path))
        except Exception as e:
            # Keep going so one broken entry does not discard the other datasets.
            logger.error("Failed to load dataset from %s: %s", path, e, exc_info=True)

    # A private Random instance keeps the shuffle reproducible without touching
    # the global random state.
    random.Random(seed).shuffle(all_samples)
    logger.info("Shuffled %d samples (seed=%d)", len(all_samples), seed)

    if max_samples is not None and max_samples < 0:
        # A negative value would otherwise slice samples off the tail of the list.
        logger.warning("Ignoring negative max_samples=%d", max_samples)
        max_samples = None

    if max_samples and len(all_samples) > max_samples:
        all_samples = all_samples[:max_samples]
        logger.info("Trimmed to max_samples=%d", max_samples)

    logger.info("Returning %d samples total", len(all_samples))
    return all_samples
129 changes: 129 additions & 0 deletions pipelinerl/domains/swe/repair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

# System message placed first in the chat built by build_messages().
SYSTEM_PROMPT = "You are a helpful coding assistant that analyzes code and fixes bugs."

# Single-turn user prompt. build_messages() renders it with str.format(), filling the
# {problem_statement} and {file_contents} placeholders; any literal brace added to
# this text must therefore be doubled ({{ / }}).
# NOTE(review): the <solution> sketch below shows an edit without a code fence, while
# the requirements and the example ask for fenced edits. parse_edits() only reads
# fenced blocks, so confirm the unfenced sketch does not mislead the model.
USER_PROMPT_TEMPLATE = (
"Analyze the following code to find and fix bugs. Use this format:\n\n"
"<think>\n"
"[Your analysis process - be as detailed as you want until you're confident in your solution]\n"
"</think>\n\n"
"<solution>\n"
"[Your SEARCH/REPLACE edits using this format:]\n\n"
"### filename.py\n"
"<<<<<<< SEARCH\n"
"[exact code to find]\n"
"=======\n"
"[replacement code]\n"
">>>>>>> REPLACE\n"
"</solution>\n\n"
"IMPORTANT REQUIREMENTS:\n"
"- Every SEARCH/REPLACE edit must use the exact format above\n"
"- The SEARCH block must contain a contiguous chunk of lines that exist in the source code\n"
"- PROPER INDENTATION IS CRITICAL - if you want to add ' print(x)', you must include all those spaces\n"
"- Wrap each SEARCH/REPLACE edit in a code block\n"
"- Use separate code blocks for multiple edits\n\n"
"Example:\n"
"```python\n"
"### mathweb/flask/app.py\n"
"<<<<<<< SEARCH\n"
"from flask import Flask\n"
"=======\n"
"import math\n"
"from flask import Flask\n"
">>>>>>> REPLACE\n"
"```\n\n"
"Here is the issue:\n"
"--- BEGIN ISSUE ---\n"
"{problem_statement}\n"
"--- END ISSUE ---\n\n"
"Below are the code files that may contain bugs:\n"
"{file_contents}"
)


def build_messages(problem_statement: str, file_contents: Dict[str, str]) -> List[dict]:
    """Assemble the system + user chat messages for one repair prompt.

    Args:
        problem_statement: Issue text inserted into the user prompt.
        file_contents: Mapping of file path to file source. Each entry is
            rendered as a ``### path`` header followed by the source inside a
            fenced block, in the mapping's iteration order.

    Returns:
        A two-element list: the system message, then the user message.
    """
    sections: List[str] = []
    for file_path, source in file_contents.items():
        sections.append(f"### {file_path}\n```\n{source}\n```\n\n")

    prompt = USER_PROMPT_TEMPLATE.format(
        file_contents="".join(sections),
        problem_statement=problem_statement,
    )

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": prompt}
    return [system_message, user_message]


def parse_edits(completion: str) -> List[dict]:
    """
    Collect the SEARCH/REPLACE edits found in a model completion.

    Only text inside fenced code blocks is considered. Each fenced block is
    handed to ``_parse_single_block``; a block needs a '### filepath' header and
    a <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple to be accepted, and
    blocks that are rejected are dropped.

    Returns a list of {'file_path', 'search', 'replace'} dicts in the order the
    blocks appear in the completion.
    """
    candidates = (
        _parse_single_block(chunk) for chunk in _extract_code_blocks(completion)
    )
    return [edit for edit in candidates if edit is not None]


def _extract_code_blocks(text: str) -> List[str]:
blocks = []
in_block = False
current: List[str] = []
for line in text.split('\n'):
if line.strip().startswith('```'):
if in_block:
blocks.append('\n'.join(current))
current = []
in_block = not in_block
elif in_block:
current.append(line)
return blocks


def _parse_single_block(block: str) -> dict | None:
lines = block.split('\n')

file_path = None
start_idx = 0
for i, line in enumerate(lines):
if line.strip().startswith('###'):
file_path = line.strip()[3:].strip()
start_idx = i + 1
break

if not file_path:
return None

search_start = search_end = replace_start = replace_end = None
for i, line in enumerate(lines[start_idx:], start=start_idx):
if '<<<<<<< SEARCH' in line:
search_start = i + 1
elif '=======' in line and search_start is not None and search_end is None:
search_end = i
replace_start = i + 1
elif '>>>>>>> REPLACE' in line and replace_start is not None:
replace_end = i
break

if None in (search_start, search_end, replace_start, replace_end):
return None

return {
'file_path': file_path,
'search': '\n'.join(lines[search_start:search_end]),
'replace': '\n'.join(lines[replace_start:replace_end]),
}
Loading