diff --git a/conf/finetune/gspo.yaml b/conf/finetune/gspo.yaml
new file mode 100644
index 00000000..5b270d04
--- /dev/null
+++ b/conf/finetune/gspo.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - base
+ - _self_
+
+attempts: 8
+rl:
+ policy_loss: gspo
+ epsilon_high: 4e-4
+ epsilon_low: 3e-4
\ No newline at end of file
diff --git a/conf/swe.yaml b/conf/swe.yaml
new file mode 100644
index 00000000..6f6ba48d
--- /dev/null
+++ b/conf/swe.yaml
@@ -0,0 +1,53 @@
+defaults:
+ - base
+ - _self_
+ - override finetune: gspo
+
+model_path: Qwen/Qwen3-8B
+
+actor:
+ rollout_policy: pipelinerl.domains.swe.rollouts.generate_swe_rollout
+ success_threshold: 0.8
+
+environments: null
+
+dataset_loader: pipelinerl.domains.swe.load_datasets.load_local_swe_dataset
+dataset_loader_params:
+ seed: ${seed}
+ # max_samples: 1000 # uncomment to cap the number of loaded samples (applies to both train and test)
+
+# HuggingFace Hub dataset IDs (or local disk paths).
+# Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train
+# NOTE: SWE-smith and SWE-bench_Verified do not ship with gold_file_contents —
+# run the preprocessor once per dataset before training, e.g.:
+# python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess \
+# hf_dataset_name=SWE-bench/SWE-bench_Verified hf_split_name=test \
+# dataset_path=/your/output/path repo_path=/your/repos/cache
+# then point the entry below at the resulting dataset_path on disk.
+train_dataset_names:
+ - SWE-bench/SWE-smith
+test_dataset_names:
+ - SWE-bench/SWE-bench_Verified
+
+finetune:
+ seq_length: 24000
+ rl:
+ filter_zero_advantage_groups: false
+
+vllm_config:
+ vllm_kwargs:
+ max_model_len: 24000
+
+llm:
+ parameters:
+ max_tokens: 4096
+ temperature: 1.0
+ chat_template_kwargs:
+ enable_thinking: false
+
+test_llm:
+ parameters:
+ max_tokens: 4096
+ temperature: 0.0
+ chat_template_kwargs:
+ enable_thinking: false
diff --git a/conf/swe/preprocess.yaml b/conf/swe/preprocess.yaml
new file mode 100644
index 00000000..841a79b2
--- /dev/null
+++ b/conf/swe/preprocess.yaml
@@ -0,0 +1,15 @@
+# Config for swe_preprocessor.py.
+# Clones repos, extracts gold_file_contents at base_commit, applies token
+# filtering, and saves a training-ready HuggingFace disk dataset.
+#
+# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess
+
+hf_dataset_name: SWE-bench/SWE-smith-py
+hf_split_name: train
+repo_path: /path/to/repos
+dataset_path: /path/to/output_ds
+tokenizer_model: Qwen/Qwen3-8B
+min_token_threshold: null # set to an int to filter out very short examples
+max_token_threshold: 16000 # set to null to disable
+num_map_processes: 32
+force_reprocess: false
diff --git a/pipelinerl/domains/swe/load_datasets.py b/pipelinerl/domains/swe/load_datasets.py
new file mode 100644
index 00000000..e24bfe3f
--- /dev/null
+++ b/pipelinerl/domains/swe/load_datasets.py
@@ -0,0 +1,127 @@
+# Supported datasets
+# ──────────────────────────────────────────────────────────────────────────────
+# Ready to use (have gold_file_contents pre-extracted):
+# SWE-bench/SWE-smith-py local preprocessed disk dataset or Hub ID
+# SWE-bench/SWE-smith-java "
+# SWE-bench/SWE-smith-rs "
+# SWE-bench/SWE-smith-go "
+#
+# Require preprocessing first (clone repos, extract file contents at base_commit):
+# princeton-nlp/SWE-bench
+# princeton-nlp/SWE-bench_Lite
+# princeton-nlp/SWE-bench_Verified
+# SWE-bench/SWE-Pro (if/when released publicly)
+#
+# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess
+# ──────────────────────────────────────────────────────────────────────────────
+
+import json
+import logging
+import os
+import random
+from typing import Any, Dict, List, Optional
+
+from datasets import load_dataset, load_from_disk
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_file_contents(raw: Any) -> Dict[str, str]:
+ if isinstance(raw, dict):
+ return {str(k): str(v) for k, v in raw.items()}
+ if isinstance(raw, str):
+ try:
+ parsed = json.loads(raw)
+ except (json.JSONDecodeError, TypeError):
+ return {}
+ if isinstance(parsed, dict):
+ return {str(k): str(v) for k, v in parsed.items()}
+ return {}
+
+
+def _load_single_dataset(path: str) -> List[Dict]:
+ """Load a dataset from a local disk path or a HuggingFace Hub ID.
+
+ Local path: /path/to/ds_train
+ Hub ID: SWE-bench/SWE-smith-py (all splits concatenated)
+ Hub ID+split: SWE-bench/SWE-smith-py:train
+ """
+ if os.path.exists(path):
+ logger.info("Loading from disk: %s", path)
+ dataset = load_from_disk(path)
+ else:
+ # Hub ID, optionally with ":split" suffix
+ if ":" in path:
+ hub_id, split = path.rsplit(":", 1)
+ else:
+ hub_id, split = path, None
+
+ logger.info("Loading from HuggingFace Hub: %s (split=%s)", hub_id, split or "all")
+ loaded = load_dataset(hub_id, split=split)
+
+ if split is None:
+ # DatasetDict — concatenate all splits
+ from datasets import concatenate_datasets
+ dataset = concatenate_datasets(list(loaded.values()))
+ else:
+ dataset = loaded
+
+ logger.info("Loaded %d rows from %s", len(dataset), path)
+
+ samples = []
+ for row in dataset:
+ item = dict(row)
+ try:
+ file_contents = _parse_file_contents(item.get("gold_file_contents", "{}"))
+ if not file_contents:
+ continue
+ samples.append({
+ "id": item.get("id", "") or item.get("instance_id", "") or item.get("issue_id", ""),
+ "dataset": item.get("dataset", "") or path,
+ "repo": item.get("repo", ""),
+ "base_commit": item.get("base_commit", ""),
+ "problem_statement": item.get("problem_statement", ""),
+ "patch": item.get("patch", ""),
+ "file_contents": file_contents,
+ })
+ except Exception as e:
+ logger.warning("Skipping malformed item: %s", e)
+
+ return samples
+
+
+def load_local_swe_dataset(
+ dataset_paths: List[str],
+ seed: int = 42,
+ max_samples: Optional[int] = None,
+) -> List[Dict]:
+ """
+ Load one or more SWE-style datasets from disk and return a combined, shuffled list.
+
+ Args:
+ dataset_paths: Passed via cfg.train_dataset_names / cfg.test_dataset_names.
+ Each entry is a filesystem path to a HuggingFace disk dataset.
+ Add multiple paths to mix datasets (e.g. swe-smith + swe-bench).
+ seed: Random seed for shuffling (inherit from cfg.seed via dataset_loader_params).
+ max_samples: Optional cap on the total number of returned samples.
+ """
+ if not dataset_paths:
+ logger.error("No dataset paths provided")
+ return []
+
+ all_samples: List[Dict] = []
+ for path in dataset_paths:
+ try:
+ all_samples.extend(_load_single_dataset(path))
+ except Exception as e:
+ logger.error("Failed to load dataset from %s: %s", path, e, exc_info=True)
+
+ random.Random(seed).shuffle(all_samples)
+ logger.info("Shuffled %d samples (seed=%d)", len(all_samples), seed)
+
+ if max_samples and len(all_samples) > max_samples:
+ all_samples = all_samples[:max_samples]
+ logger.info("Trimmed to max_samples=%d", max_samples)
+
+ logger.info("Returning %d samples total", len(all_samples))
+ return all_samples
diff --git a/pipelinerl/domains/swe/repair.py b/pipelinerl/domains/swe/repair.py
new file mode 100644
index 00000000..688f89f2
--- /dev/null
+++ b/pipelinerl/domains/swe/repair.py
@@ -0,0 +1,116 @@
+import logging
+from typing import Dict, List
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = "You are a helpful coding assistant that analyzes code and fixes bugs."
+
+USER_PROMPT_TEMPLATE = (
+ "Analyze the following code to find and fix bugs. Use this format:\n\n"
+ "\n"
+ "[Your analysis process - be as detailed as you want until you're confident in your solution]\n"
+ "\n\n"
+ "\n"
+ "[Your SEARCH/REPLACE edits using this format:]\n\n"
+ "```\n"
+ "### filename.py\n"
+ "<<<<<<< SEARCH\n"
+ "[exact code to find]\n"
+ "=======\n"
+ "[replacement code]\n"
+ ">>>>>>> REPLACE\n"
+ "```\n"
+ "\n\n"
+ "IMPORTANT REQUIREMENTS:\n"
+ "- Every SEARCH/REPLACE edit must use the exact format above\n"
+ "- The SEARCH block must contain a contiguous chunk of lines that exist in the source code\n"
+ "- PROPER INDENTATION IS CRITICAL - if you want to add ' print(x)', you must include all those spaces\n"
+ "- Wrap each SEARCH/REPLACE edit in a code block\n"
+ "- Use separate code blocks for multiple edits\n\n"
+ "Example:\n"
+ "```python\n"
+ "### mathweb/flask/app.py\n"
+ "<<<<<<< SEARCH\n"
+ "from flask import Flask\n"
+ "=======\n"
+ "import math\n"
+ "from flask import Flask\n"
+ ">>>>>>> REPLACE\n"
+ "```\n\n"
+ "Here is the issue:\n"
+ "--- BEGIN ISSUE ---\n"
+ "{problem_statement}\n"
+ "--- END ISSUE ---\n\n"
+ "Below are the code files that may contain bugs:\n"
+ "{file_contents}"
+)
+
+
+def build_messages(problem_statement: str, file_contents: Dict[str, str]) -> List[dict]:
+ """Build the chat messages for a single-turn repair prompt."""
+ formatted_files = "".join(
+ f"### {path}\n```\n{content}\n```\n\n"
+ for path, content in file_contents.items()
+ )
+ user_content = USER_PROMPT_TEMPLATE.format(
+ problem_statement=problem_statement,
+ file_contents=formatted_files,
+ )
+ return [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": user_content},
+ ]
+
+
+def parse_edits(completion: str) -> List[dict]:
+ """
+ Parse SEARCH/REPLACE blocks from a model completion.
+
+ Each block is a '### filepath' line followed by a
+ <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. Triple-backtick code
+ fences around the block are accepted but not required.
+ Returns a list of {'file_path', 'search', 'replace'} dicts.
+ """
+ edits: List[dict] = []
+ lines = completion.split('\n')
+ n = len(lines)
+ i = 0
+ while i < n:
+ if '<<<<<<< SEARCH' not in lines[i]:
+ i += 1
+ continue
+
+ # Walk back to the most recent '### filepath' line, but don't cross a
+ # previous '>>>>>>> REPLACE' marker (that path belongs to an earlier edit).
+ file_path = None
+ for j in range(i - 1, -1, -1):
+ if '>>>>>>> REPLACE' in lines[j]:
+ break
+ stripped = lines[j].strip()
+ if stripped.startswith('###'):
+ file_path = stripped[3:].strip()
+ break
+ if not file_path:
+ i += 1
+ continue
+
+ search_start = i + 1
+ sep = replace_end = None
+ for k in range(search_start, n):
+ if sep is None and '=======' in lines[k]:
+ sep = k
+ elif sep is not None and '>>>>>>> REPLACE' in lines[k]:
+ replace_end = k
+ break
+
+ if sep is None or replace_end is None:
+ i += 1
+ continue
+
+ edits.append({
+ 'file_path': file_path,
+ 'search': '\n'.join(lines[search_start:sep]),
+ 'replace': '\n'.join(lines[sep + 1:replace_end]),
+ })
+ i = replace_end + 1
+ return edits
diff --git a/pipelinerl/domains/swe/reward.py b/pipelinerl/domains/swe/reward.py
new file mode 100644
index 00000000..22cc611b
--- /dev/null
+++ b/pipelinerl/domains/swe/reward.py
@@ -0,0 +1,163 @@
+import difflib
+import logging
+import re
+from typing import Dict, List, Tuple, TypedDict
+
+from unidiff import PatchSet
+from unidiff.errors import UnidiffParseError
+
+logger = logging.getLogger(__name__)
+
+
+class FormatError(Exception):
+ pass
+
+
+class ChangeSimilarity(TypedDict):
+ path: str
+ pred_change: str
+ oracle_change: str
+ similarity: float
+
+
+def parse_patch_for_gold_files(patch_text: str) -> List[str]:
+ """Extract modified file paths from a unified diff patch."""
+ if not patch_text:
+ return []
+ return re.findall(r'^--- a/(.+)$', patch_text, re.MULTILINE)
+
+
+def generate_unified_diff(old_code: str, new_code: str, n_context: int = 3) -> str:
+ diff = difflib.unified_diff(
+ old_code.splitlines(),
+ new_code.splitlines(),
+ fromfile="old",
+ tofile="new",
+ lineterm="",
+ n=n_context,
+ )
+ try:
+ next(diff)
+ next(diff)
+ return "\n".join(diff)
+ except StopIteration:
+ return ""
+
+
+def apply_edits_to_files(
+ file_contents: Dict[str, str],
+ edits: List[Dict],
+ silent: bool = False,
+) -> Dict[str, str]:
+ new_content_dict = dict(file_contents)
+ for edit in edits:
+ file_path = edit.get('file_path', '')
+ search_text = edit.get('search', '')
+ replace_text = edit.get('replace', '')
+
+ if not silent and search_text == replace_text:
+ raise FormatError("Search and replace blocks are identical")
+
+ if file_path not in new_content_dict:
+ if not silent:
+ raise FormatError(f"File {file_path} not found in file_contents")
+ logger.warning("File %s not found in file_contents", file_path)
+ continue
+
+ current_content = new_content_dict[file_path]
+ if search_text not in current_content:
+ if not silent:
+ raise FormatError(f"Search text not found in {file_path}: {search_text}")
+ logger.warning("Search text not found in %s", file_path)
+ continue
+
+ new_content_dict[file_path] = current_content.replace(search_text, replace_text, 1)
+
+ return new_content_dict
+
+
+def get_normalized_patch(
+ code_context: Dict[str, str],
+ new_content_dict: Dict[str, str],
+) -> Dict[str, str]:
+ patch_dict = {}
+ for path, new_content in new_content_dict.items():
+ old_content = code_context.get(path, "")
+ patch = generate_unified_diff(old_content, new_content)
+ if patch:
+ patch_dict[path] = patch
+ return patch_dict
+
+
+def get_filelevel_diff(patch_text: str) -> Dict[str, str]:
+ try:
+ patch = PatchSet(patch_text)
+ except UnidiffParseError:
+ return {}
+ except Exception as e:
+ logger.warning("Unexpected unidiff parsing error: %s", e)
+ return {}
+
+ result = {}
+ for patchfile in patch:
+ body = "\n".join(str(hunk).strip() for hunk in patchfile)
+ result[patchfile.path] = body.strip()
+ return result
+
+
+def compute_change_similarities(
+ pred_patch: Dict[str, str],
+ oracle_patch: Dict[str, str],
+) -> List[ChangeSimilarity]:
+ all_file_paths = set(oracle_patch) | set(pred_patch)
+ similarities = []
+ for path in all_file_paths:
+ pred_change = pred_patch.get(path, "")
+ oracle_change = oracle_patch.get(path, "")
+ if not oracle_change or not pred_change:
+ change_similarity = 0.0
+ else:
+ change_similarity = difflib.SequenceMatcher(
+ None, pred_change, oracle_change, autojunk=False
+ ).ratio()
+ similarities.append(ChangeSimilarity(
+ path=path,
+ pred_change=pred_change,
+ oracle_change=oracle_change,
+ similarity=change_similarity,
+ ))
+ return similarities
+
+
+def calculate_precise_reward(
+ file_contents: Dict[str, str],
+ oracle_patch_text: str,
+ predicted_edits: List[Dict],
+) -> Tuple[float, Dict]:
+ try:
+ if not predicted_edits:
+ raise FormatError("No valid search blocks found")
+
+ oracle_patch = get_filelevel_diff(oracle_patch_text)
+ pred_new_content = apply_edits_to_files(file_contents, predicted_edits)
+ pred_patch = get_normalized_patch(file_contents, pred_new_content)
+ similarities = compute_change_similarities(pred_patch, oracle_patch)
+
+ if not similarities:
+ assert not oracle_patch and not pred_patch
+ return 1.0, {"similarities": []}
+
+ reward = sum(s["similarity"] for s in similarities) / len(similarities)
+ return reward, {
+ "similarities": similarities,
+ "num_files_changed": len(similarities),
+ "oracle_files": list(oracle_patch.keys()),
+ "predicted_files": list(pred_patch.keys()),
+ }
+
+ except FormatError as e:
+ # logger.warning("Format error in reward calculation: %s", str(e))
+ return 0.0, {"format_error": True, "error_message": str(e)}
+ except Exception as e:
+ logger.error("Unexpected error in reward calculation: %s", e)
+ return 0.0, {"error": str(e)}
diff --git a/pipelinerl/domains/swe/rollouts.py b/pipelinerl/domains/swe/rollouts.py
new file mode 100644
index 00000000..c724266c
--- /dev/null
+++ b/pipelinerl/domains/swe/rollouts.py
@@ -0,0 +1,71 @@
+import logging
+import math
+import time
+
+import aiohttp
+from omegaconf import DictConfig
+
+from pipelinerl.async_llm import llm_async_generate, make_training_text
+from pipelinerl.llm import Prompt, TrainableLLM
+from pipelinerl.rollouts import BaseMetrics, RolloutResult
+
+from pipelinerl.domains.swe.repair import build_messages, parse_edits
+from pipelinerl.domains.swe.reward import calculate_precise_reward
+
+logger = logging.getLogger(__name__)
+
+
+class SWEMetrics(BaseMetrics):
+ format_error: bool = False
+
+
+async def generate_swe_rollout(
+ cfg: DictConfig,
+ llm: TrainableLLM,
+ problem: dict,
+ session: aiohttp.ClientSession,
+) -> RolloutResult:
+ file_contents = problem.get("file_contents", {})
+ problem_statement = problem["problem_statement"]
+ gold_patch = problem.get("patch", "")
+
+ messages = build_messages(problem_statement, file_contents)
+ prompt = Prompt(messages=messages)
+
+ time_start = time.time()
+ llm_call = await llm_async_generate(llm, prompt, session)
+ latency = time.time() - time_start
+
+ raw_output = llm_call.output.content or ""
+ edits = parse_edits(raw_output)
+
+ reward, reward_meta = calculate_precise_reward(file_contents, gold_patch, edits)
+
+ if hasattr(cfg.actor, 'discount_factor'):
+ reward *= cfg.actor.discount_factor ** llm_call.output_length_tokens
+
+ trace = make_training_text(llm, llm_call)
+ trace.reward = reward if not math.isnan(reward) else 0.0
+ trace.metadata["stage"] = "repair"
+ trace.metadata["dataset"] = problem.get("dataset")
+ trace.metadata["problem_id"] = problem.get("id") or problem.get("instance_id")
+
+ format_error = bool(reward_meta.get("format_error", False))
+ success_threshold = getattr(cfg.actor, 'success_threshold', 0.8)
+ success = (not format_error) and reward >= success_threshold
+
+ metrics = SWEMetrics(
+ reward=trace.reward,
+ success=success,
+ no_error=not format_error,
+ no_answer=len(edits) == 0,
+ format_error=format_error,
+ )
+
+ return RolloutResult(
+ training_texts=[trace],
+ metrics=metrics,
+ latency=latency,
+ dataset_name=problem.get("dataset"),
+ domain="swe",
+ )
diff --git a/pipelinerl/domains/swe/swe_preprocessor.py b/pipelinerl/domains/swe/swe_preprocessor.py
new file mode 100644
index 00000000..253dcf38
--- /dev/null
+++ b/pipelinerl/domains/swe/swe_preprocessor.py
@@ -0,0 +1,287 @@
+"""
+SWE dataset preprocessing utility.
+
+Clones repos, extracts gold_file_contents at base_commit, applies token-count
+filtering, and optionally computes per-repo file stats for BM25 retrieval.
+Run this once per dataset before training; output is saved as a HuggingFace disk dataset.
+
+Usage:
+ python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess
+"""
+
+import json
+import logging
+import math
+import os
+import re
+import shutil
+from collections import Counter
+from pathlib import Path
+from typing import Dict, List, Optional, Set
+
+import git
+import hydra
+from datasets import Dataset, load_dataset, load_from_disk
+from omegaconf import DictConfig
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class RepoManager:
+ """Clones or updates GitHub repositories to a local directory."""
+
+ def __init__(self, repos_base_dir: str):
+ self.repos_base_dir = Path(repos_base_dir)
+ os.makedirs(self.repos_base_dir, exist_ok=True)
+
+ def clone_or_update_repo(self, repo_name: str) -> Path:
+ repo_url = f"https://github.com/{repo_name}.git"
+ local_path = self.repos_base_dir / repo_name.replace("/", "_")
+
+ try:
+ if local_path.exists() and (local_path / ".git").exists():
+ repo = git.Repo(local_path)
+ repo.remotes.origin.fetch()
+ else:
+ if local_path.exists():
+ logger.warning("Removing broken directory %s before fresh clone", local_path)
+ shutil.rmtree(local_path)
+ git.Repo.clone_from(repo_url, local_path)
+ return local_path
+ except Exception as e:
+ logger.error("Failed to process repo %s: %s", repo_name, e)
+ raise
+
+ def clone_or_update_repos(self, repo_names: List[str]) -> Dict[str, Path]:
+ results: Dict[str, Path] = {}
+ failed: List[str] = []
+ for name in tqdm(repo_names, desc="Cloning/updating repos"):
+ try:
+ path = self.clone_or_update_repo(name)
+ if (path / ".git").exists():
+ results[name] = path
+ else:
+ failed.append(name)
+ except Exception:
+ failed.append(name)
+ if failed:
+ logger.warning("Failed to clone/update %d repos: %s", len(set(failed)), list(set(failed)))
+ return results
+
+
+class SwePreprocessor:
+ """Processes a SWE-style HuggingFace dataset into a training-ready disk dataset."""
+
+ SOURCE_EXTENSIONS = {
+ ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", ".cs", ".php",
+ ".rb", ".go", ".rs", ".kt", ".scala", ".swift", ".m", ".mm", ".sh", ".bash",
+ ".zsh", ".fish", ".pl", ".r", ".R", ".sql", ".html", ".css", ".scss", ".sass",
+ ".less", ".vue", ".jsx", ".tsx", ".json", ".yaml", ".yml", ".xml", ".toml",
+ ".ini", ".cfg", ".conf", ".properties", ".gradle", ".cmake", ".make",
+ ".dockerfile", ".md", ".rst", ".txt", ".lock", ".requirements", ".pyx", ".ipynb",
+ ".pxd", ".pyi", ".pxi.in",
+ }
+ SKIP_DIRS = {
+ "test", "tests", "__pycache__", ".git", ".svn", ".hg", "node_modules",
+ ".pytest_cache", ".tox", "venv", ".env", "dist", ".idea", ".vscode",
+ "target", "out", "bin", "obj", ".gradle", "coverage", ".coverage",
+ ".nyc_output", "htmlcov",
+ }
+
+ def __init__(self, cfg: DictConfig):
+ self.cfg = cfg
+ self.repos_base_dir = Path(cfg.repo_path)
+ self.dataset_path = Path(cfg.dataset_path)
+ self.min_token_threshold = cfg.min_token_threshold
+ self.max_token_threshold = cfg.max_token_threshold
+ self.num_map_processes = cfg.num_map_processes
+ self.tokenizer_model = cfg.tokenizer_model
+ self.repo_manager = RepoManager(self.repos_base_dir)
+ self.file_stats_cache: Dict = {}
+ self.tokenizer = self._init_tokenizer()
+
+ def _init_tokenizer(self):
+ try:
+ tok = AutoTokenizer.from_pretrained(self.tokenizer_model)
+ logger.info("Tokenizer loaded: %s", self.tokenizer_model)
+ return tok
+ except Exception as e:
+ logger.warning("Could not load tokenizer (%s): %s — token filtering disabled", self.tokenizer_model, e)
+ return None
+
+ # ── File helpers ─────────────────────────────────────────────────────────
+
+ def _is_source_file(self, filepath: str) -> bool:
+ path = Path(filepath)
+ if any(p.lower() in self.SKIP_DIRS for p in path.parts):
+ return False
+ if path.suffix.lower() in self.SOURCE_EXTENSIONS:
+ return True
+ return not path.suffix and path.name.lower() in {
+ "makefile", "dockerfile", "readme", "license", "changelog",
+ "requirements", "pipfile", "gemfile", "rakefile",
+ }
+
+ def _get_file_content(self, repo_path: Path, commit: str, filepath: str) -> Optional[str]:
+ if not (repo_path / ".git").exists():
+ return None
+ try:
+ return git.Repo(repo_path).git.show(f"{commit}:{filepath}")
+ except Exception:
+ return None
+
+ def _parse_patch(self, patch: str) -> List[str]:
+ return re.findall(r"^--- a/(.+)$", patch or "", re.MULTILINE)
+
+ # ── Token filtering ───────────────────────────────────────────────────────
+
+ def _filter_by_token_count(self, example: Dict) -> bool:
+ if not self.tokenizer:
+ return True
+ try:
+ contents = json.loads(example.get("gold_file_contents", "{}"))
+ if not contents:
+ return False
+ text = " ".join(contents.values()) + " " + (example.get("problem_statement") or "")
+ n = len(self.tokenizer.encode(text, add_special_tokens=False))
+ if self.min_token_threshold is not None and n < self.min_token_threshold:
+ return False
+ if self.max_token_threshold is not None and n > self.max_token_threshold:
+ return False
+ return True
+ except Exception as e:
+ logger.error("Token filter error: %s", e)
+ return False
+
+ # ── File stats (for BM25, optional) ──────────────────────────────────────
+
+ def _tokenize_content(self, content: str) -> Counter:
+ tokens = re.findall(r"[a-zA-Z0-9_]+", content.lower())
+ return Counter(t for t in tokens if 2 <= len(t) <= 50 and not (t.isdigit() and len(t) > 4))
+
+ def _get_all_file_stats(self, repo_path: Path, commit: str) -> Dict[str, Dict]:
+ key = (str(repo_path), commit)
+ if key in self.file_stats_cache:
+ return self.file_stats_cache[key]
+
+ stats: Dict[str, Dict] = {}
+ if not (repo_path / ".git").exists():
+ self.file_stats_cache[key] = stats
+ return stats
+
+ try:
+ repo = git.Repo(repo_path)
+ files = repo.git.execute(["git", "ls-tree", "-r", "--name-only", commit])
+ for filepath in (files.strip().split("\n") if files.strip() else []):
+ if not filepath or not self._is_source_file(filepath):
+ continue
+ try:
+ content = repo.git.show(f"{commit}:{filepath}")
+ content.encode("utf-8") # skip binary
+ term_counts = self._tokenize_content(content)
+ stats[filepath] = {"path": filepath, "length": len(content), "term_counts": dict(term_counts)}
+ except Exception:
+ continue
+ except Exception as e:
+ logger.error("File stats error for %s@%s: %s", repo_path, commit, e)
+
+ self.file_stats_cache[key] = stats
+ return stats
+
+ # ── Per-example processing passes ────────────────────────────────────────
+
+ def _extract_gold_file_contents_only(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict:
+ repo = example.get("repo")
+ commit = example.get("base_commit")
+ patch = example.get("patch")
+ contents: Dict[str, str] = {}
+
+ repo_path = repo_paths.get(repo) if repo else None
+ if repo_path and commit and patch:
+ for fp in self._parse_patch(patch):
+ c = self._get_file_content(repo_path, commit, fp)
+ if c is not None:
+ contents[fp] = c
+
+ example["gold_file_contents"] = json.dumps(contents)
+ return example
+
+ def _add_file_stats_to_example(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict:
+ example["all_file_stats"] = json.dumps({})
+ example["_invalid_example"] = True
+
+ repo = example.get("repo")
+ commit = example.get("base_commit")
+ patch = example.get("patch")
+ repo_path = repo_paths.get(repo) if repo else None
+
+ if not (repo_path and commit and patch):
+ return example
+
+ gold_files = self._parse_patch(patch)
+ if any(not self._is_source_file(f) for f in gold_files):
+ return example
+
+ all_stats = self._get_all_file_stats(repo_path, commit)
+ example["all_file_stats"] = json.dumps(all_stats)
+ example["_invalid_example"] = any(f not in all_stats for f in gold_files)
+ return example
+
+ # ── Main entry point ─────────────────────────────────────────────────────
+
+ def process(self) -> Set[str]:
+ if self.dataset_path.exists() and not self.cfg.force_reprocess:
+ logger.info("Found existing dataset at %s, loading from disk", self.dataset_path)
+ try:
+ ds = load_from_disk(str(self.dataset_path))
+ return set(ds["repo"])
+ except Exception as e:
+ logger.warning("Could not load existing dataset (%s), reprocessing", e)
+
+ logger.info("Loading %s (split=%s) from Hub", self.cfg.hf_dataset_name, self.cfg.hf_split_name)
+ dataset = load_dataset(self.cfg.hf_dataset_name, split=self.cfg.hf_split_name)
+ logger.info("Loaded %d examples", len(dataset))
+
+ unique_repos = set(dataset["repo"])
+ repo_paths = self.repo_manager.clone_or_update_repos(list(unique_repos))
+
+ # Pass 1: extract gold file contents (fast)
+ dataset = dataset.map(
+ lambda ex: self._extract_gold_file_contents_only(ex, repo_paths),
+ batched=False, num_proc=self.num_map_processes,
+ load_from_cache_file=False, desc="Extracting gold file contents",
+ )
+
+ # Token-count filtering (before expensive file stats)
+ if self.tokenizer:
+ before = len(dataset)
+ dataset = dataset.filter(self._filter_by_token_count, desc="Token-count filtering")
+ logger.info("Token filtering: %d → %d examples", before, len(dataset))
+
+ # Pass 2: compute file stats (expensive, only for surviving examples)
+ dataset = dataset.map(
+ lambda ex: self._add_file_stats_to_example(ex, repo_paths),
+ batched=False, num_proc=self.num_map_processes,
+ load_from_cache_file=False, desc="Computing file stats",
+ )
+
+ before = len(dataset)
+ dataset = dataset.filter(lambda ex: not ex.get("_invalid_example", True), desc="Filtering invalid examples")
+ dataset = dataset.remove_columns(["_invalid_example"])
+ logger.info("Validity filtering: %d → %d examples", before, len(dataset))
+
+ logger.info("Saving to %s", self.dataset_path)
+ dataset.save_to_disk(str(self.dataset_path))
+ logger.info("Done. %d examples saved.", len(dataset))
+ return unique_repos
+
+
+@hydra.main(config_path="../../../../conf", config_name="swe/preprocess", version_base="1.3.2")
+def main(cfg: DictConfig):
+ SwePreprocessor(cfg).process()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py
index 2ba28c5e..e61dc730 100644
--- a/pipelinerl/finetune/rl/__init__.py
+++ b/pipelinerl/finetune/rl/__init__.py
@@ -186,7 +186,6 @@ def rl_step(
model_inputs = {
"input_ids": batch.input_ids,
"attention_mask": batch.attention_mask,
- "labels": batch.labels,
}
if batch.is_packed:
model_inputs["position_ids"] = batch.position_ids
diff --git a/pipelinerl/finetune/value_model.py b/pipelinerl/finetune/value_model.py
index 7c8f1136..a14214f7 100644
--- a/pipelinerl/finetune/value_model.py
+++ b/pipelinerl/finetune/value_model.py
@@ -15,8 +15,6 @@ class CausalLMOutputWithValue(ModelOutput):
Output type for causal language models with an additional value head.
Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
- Language modeling loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
Prediction scores of the language modeling head.
value (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
@@ -29,7 +27,6 @@ class CausalLMOutputWithValue(ModelOutput):
Attention weights after the attention softmax.
"""
- loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
value: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -76,7 +73,6 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -93,7 +89,6 @@ def forward(
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
- labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=True,
@@ -107,7 +102,6 @@ def forward(
values = self.value_head(hidden_states)
return CausalLMOutputWithValue(
- loss=outputs.loss,
logits=outputs.logits,
value=values,
past_key_values=outputs.past_key_values,
@@ -120,7 +114,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.pretrained_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs
)
-
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -132,14 +126,14 @@ def save_pretrained(
):
"""Save model and value head separately."""
import os
-
+
if state_dict is None:
state_dict = self.state_dict()
-
+
# Extract pretrained model and value head state dicts
pretrained_model_state_dict = {}
value_head_state_dict = {}
-
+
for key, value in state_dict.items():
if key.startswith("value_head."):
# Remove the "value_head." prefix
@@ -154,7 +148,7 @@ def save_pretrained(
f"Unexpected key in state dict: {key}. "
"Expected keys should start with 'value_head.' or 'pretrained_model.'."
)
-
+
# Save the pretrained model which can be easily loaded by vllm, etc.
self.pretrained_model.save_pretrained(
save_directory,
@@ -164,7 +158,7 @@ def save_pretrained(
safe_serialization=safe_serialization,
**kwargs,
)
-
+
# Save value head separately
if is_main_process:
value_head_path = os.path.join(save_directory, "value_head.pt")
diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py
index 632b760e..bd49ba04 100644
--- a/pipelinerl/streams.py
+++ b/pipelinerl/streams.py
@@ -110,7 +110,7 @@ def connect_to_redis(config: RedisConfig):
logger.debug(f"Trying to connect to Redis server at {config.host}:{config.port}")
client = redis.Redis(host=config.host, port=config.port)
client.ping()
- logger.info(f"Connected to Redis server")
+ logger.info("Connected to Redis server")
return client
except (redis.exceptions.TimeoutError, redis.ConnectionError) as e:
logger.warning(f"Waiting for Redis server ({type(e)}). Retrying in 5 seconds.")
diff --git a/pyproject.toml b/pyproject.toml
index 81ecbd42..568fc65f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,6 +64,9 @@ coding = [
fn_calling = [
"bfcl-eval>=2025.6.8",
]
+swe = [
+ "unidiff>=0.7.5",
+]
logic = [
# i3-logic verification code is vendored in pipelinerl/domains/logic/i3_logic/
# (source: https://github.com/PrimeIntellect/i3-logic)