From 355b8637ffbbcf2955a6bbd9bbbc3461db12ab12 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 22 Apr 2026 16:45:56 +0000 Subject: [PATCH 1/5] initial commit for adding swe tasks as a new domain --- conf/finetune/gspo.yaml | 9 + conf/swe.yaml | 47 ++++ conf/swe/preprocess.yaml | 15 ++ pipelinerl/domains/swe/load_datasets.py | 127 +++++++++ pipelinerl/domains/swe/repair.py | 129 +++++++++ pipelinerl/domains/swe/reward.py | 163 ++++++++++++ pipelinerl/domains/swe/rollouts.py | 71 +++++ pipelinerl/domains/swe/swe_preprocessor.py | 287 +++++++++++++++++++++ pyproject.toml | 3 + 9 files changed, 851 insertions(+) create mode 100644 conf/finetune/gspo.yaml create mode 100644 conf/swe.yaml create mode 100644 conf/swe/preprocess.yaml create mode 100644 pipelinerl/domains/swe/load_datasets.py create mode 100644 pipelinerl/domains/swe/repair.py create mode 100644 pipelinerl/domains/swe/reward.py create mode 100644 pipelinerl/domains/swe/rollouts.py create mode 100644 pipelinerl/domains/swe/swe_preprocessor.py diff --git a/conf/finetune/gspo.yaml b/conf/finetune/gspo.yaml new file mode 100644 index 00000000..5b270d04 --- /dev/null +++ b/conf/finetune/gspo.yaml @@ -0,0 +1,9 @@ +defaults: + - base + - _self_ + +attempts: 8 +rl: + policy_loss: gspo + epsilon_high: 4e-4 + epsilon_low: 3e-4 \ No newline at end of file diff --git a/conf/swe.yaml b/conf/swe.yaml new file mode 100644 index 00000000..f2a748ed --- /dev/null +++ b/conf/swe.yaml @@ -0,0 +1,47 @@ +defaults: + - base + - _self_ + - override finetune: gspo + +model_path: Qwen/Qwen3-8B + +actor: + rollout_policy: pipelinerl.domains.swe.rollouts.generate_swe_rollout + success_threshold: 0.8 + +environments: null + +dataset_loader: pipelinerl.domains.swe.load_datasets.load_local_swe_dataset +dataset_loader_params: + seed: ${seed} + # max_samples: 1000 # uncomment to cap the number of loaded samples (applies to both train and test) + +# HuggingFace Hub dataset IDs (or local disk paths). +# Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train +train_dataset_names: + - SWE-bench/SWE-smith-py +test_dataset_names: + - SWE-bench/SWE-smith-py + +finetune: + seq_length: 24000 + rl: + filter_zero_advantage_groups: false + +vllm_config: + vllm_kwargs: + max_model_len: 24000 + +llm: + parameters: + max_tokens: 4096 + temperature: 1.0 + chat_template_kwargs: + enable_thinking: false + +test_llm: + parameters: + max_tokens: 4096 + temperature: 0.0 + chat_template_kwargs: + enable_thinking: false diff --git a/conf/swe/preprocess.yaml b/conf/swe/preprocess.yaml new file mode 100644 index 00000000..841a79b2 --- /dev/null +++ b/conf/swe/preprocess.yaml @@ -0,0 +1,15 @@ +# Config for swe_preprocessor.py. +# Clones repos, extracts gold_file_contents at base_commit, applies token +# filtering, and saves a training-ready HuggingFace disk dataset. +# +# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess + +hf_dataset_name: SWE-bench/SWE-smith-py +hf_split_name: train +repo_path: /path/to/repos +dataset_path: /path/to/output_ds +tokenizer_model: Qwen/Qwen3-8B +min_token_threshold: null # set to an int to filter out very short examples +max_token_threshold: 16000 # set to null to disable +num_map_processes: 32 +force_reprocess: false diff --git a/pipelinerl/domains/swe/load_datasets.py b/pipelinerl/domains/swe/load_datasets.py new file mode 100644 index 00000000..e24bfe3f --- /dev/null +++ b/pipelinerl/domains/swe/load_datasets.py @@ -0,0 +1,127 @@ +# Supported datasets +# ────────────────────────────────────────────────────────────────────────────── +# Ready to use (have gold_file_contents pre-extracted): +# SWE-bench/SWE-smith-py local preprocessed disk dataset or Hub ID +# SWE-bench/SWE-smith-java " +# SWE-bench/SWE-smith-rs " +# SWE-bench/SWE-smith-go " +# +# Require preprocessing first (clone repos, extract file contents at base_commit): +# princeton-nlp/SWE-bench +# princeton-nlp/SWE-bench_Lite +# princeton-nlp/SWE-bench_Verified +# SWE-bench/SWE-Pro (if/when released publicly) +# +# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess +# ────────────────────────────────────────────────────────────────────────────── + +import json +import logging +import os +import random +from typing import Any, Dict, List, Optional + +from datasets import load_dataset, load_from_disk + +logger = logging.getLogger(__name__) + + +def _parse_file_contents(raw: Any) -> Dict[str, str]: + if isinstance(raw, dict): + return {str(k): str(v) for k, v in raw.items()} + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return {} + if isinstance(parsed, dict): + return {str(k): str(v) for k, v in parsed.items()} + return {} + + +def _load_single_dataset(path: str) -> List[Dict]: + """Load a dataset from a local disk path or a HuggingFace Hub ID. + + Local path: /path/to/ds_train + Hub ID: SWE-bench/SWE-smith-py (all splits concatenated) + Hub ID+split: SWE-bench/SWE-smith-py:train + """ + if os.path.exists(path): + logger.info("Loading from disk: %s", path) + dataset = load_from_disk(path) + else: + # Hub ID, optionally with ":split" suffix + if ":" in path: + hub_id, split = path.rsplit(":", 1) + else: + hub_id, split = path, None + + logger.info("Loading from HuggingFace Hub: %s (split=%s)", hub_id, split or "all") + loaded = load_dataset(hub_id, split=split) + + if split is None: + # DatasetDict — concatenate all splits + from datasets import concatenate_datasets + dataset = concatenate_datasets(list(loaded.values())) + else: + dataset = loaded + + logger.info("Loaded %d rows from %s", len(dataset), path) + + samples = [] + for row in dataset: + item = dict(row) + try: + file_contents = _parse_file_contents(item.get("gold_file_contents", "{}")) + if not file_contents: + continue + samples.append({ + "id": item.get("id", "") or item.get("instance_id", "") or item.get("issue_id", ""), + "dataset": item.get("dataset", "") or path, + "repo": item.get("repo", ""), + "base_commit": item.get("base_commit", ""), + "problem_statement": item.get("problem_statement", ""), + "patch": item.get("patch", ""), + "file_contents": file_contents, + }) + except Exception as e: + logger.warning("Skipping malformed item: %s", e) + + return samples + + +def load_local_swe_dataset( + dataset_paths: List[str], + seed: int = 42, + max_samples: Optional[int] = None, +) -> List[Dict]: + """ + Load one or more SWE-style datasets from disk and return a combined, shuffled list. + + Args: + dataset_paths: Passed via cfg.train_dataset_names / cfg.test_dataset_names. + Each entry is a filesystem path to a HuggingFace disk dataset. + Add multiple paths to mix datasets (e.g. swe-smith + swe-bench). + seed: Random seed for shuffling (inherit from cfg.seed via dataset_loader_params). + max_samples: Optional cap on the total number of returned samples. + """ + if not dataset_paths: + logger.error("No dataset paths provided") + return [] + + all_samples: List[Dict] = [] + for path in dataset_paths: + try: + all_samples.extend(_load_single_dataset(path)) + except Exception as e: + logger.error("Failed to load dataset from %s: %s", path, e, exc_info=True) + + random.Random(seed).shuffle(all_samples) + logger.info("Shuffled %d samples (seed=%d)", len(all_samples), seed) + + if max_samples and len(all_samples) > max_samples: + all_samples = all_samples[:max_samples] + logger.info("Trimmed to max_samples=%d", max_samples) + + logger.info("Returning %d samples total", len(all_samples)) + return all_samples diff --git a/pipelinerl/domains/swe/repair.py b/pipelinerl/domains/swe/repair.py new file mode 100644 index 00000000..be045ff5 --- /dev/null +++ b/pipelinerl/domains/swe/repair.py @@ -0,0 +1,129 @@ +import logging +from typing import Dict, List + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT = "You are a helpful coding assistant that analyzes code and fixes bugs." + +USER_PROMPT_TEMPLATE = ( + "Analyze the following code to find and fix bugs. Use this format:\n\n" + "\n" + "[Your analysis process - be as detailed as you want until you're confident in your solution]\n" + "\n\n" + "\n" + "[Your SEARCH/REPLACE edits using this format:]\n\n" + "### filename.py\n" + "<<<<<<< SEARCH\n" + "[exact code to find]\n" + "=======\n" + "[replacement code]\n" + ">>>>>>> REPLACE\n" + "\n\n" + "IMPORTANT REQUIREMENTS:\n" + "- Every SEARCH/REPLACE edit must use the exact format above\n" + "- The SEARCH block must contain a contiguous chunk of lines that exist in the source code\n" + "- PROPER INDENTATION IS CRITICAL - if you want to add ' print(x)', you must include all those spaces\n" + "- Wrap each SEARCH/REPLACE edit in a code block\n" + "- Use separate code blocks for multiple edits\n\n" + "Example:\n" + "```python\n" + "### mathweb/flask/app.py\n" + "<<<<<<< SEARCH\n" + "from flask import Flask\n" + "=======\n" + "import math\n" + "from flask import Flask\n" + ">>>>>>> REPLACE\n" + "```\n\n" + "Here is the issue:\n" + "--- BEGIN ISSUE ---\n" + "{problem_statement}\n" + "--- END ISSUE ---\n\n" + "Below are the code files that may contain bugs:\n" + "{file_contents}" +) + + +def build_messages(problem_statement: str, file_contents: Dict[str, str]) -> List[dict]: + """Build the chat messages for a single-turn repair prompt.""" + formatted_files = "".join( + f"### {path}\n```\n{content}\n```\n\n" + for path, content in file_contents.items() + ) + user_content = USER_PROMPT_TEMPLATE.format( + problem_statement=problem_statement, + file_contents=formatted_files, + ) + return [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + +def parse_edits(completion: str) -> List[dict]: + """ + Parse SEARCH/REPLACE blocks from a model completion. + + Each code block must start with '### filepath' and contain exactly one + <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. + Returns a list of {'file_path', 'search', 'replace'} dicts. + """ + edits = [] + code_blocks = _extract_code_blocks(completion) + + for block in code_blocks: + edit = _parse_single_block(block) + if edit is not None: + edits.append(edit) + + return edits + + +def _extract_code_blocks(text: str) -> List[str]: + blocks = [] + in_block = False + current: List[str] = [] + for line in text.split('\n'): + if line.strip().startswith('```'): + if in_block: + blocks.append('\n'.join(current)) + current = [] + in_block = not in_block + elif in_block: + current.append(line) + return blocks + + +def _parse_single_block(block: str) -> dict | None: + lines = block.split('\n') + + file_path = None + start_idx = 0 + for i, line in enumerate(lines): + if line.strip().startswith('###'): + file_path = line.strip()[3:].strip() + start_idx = i + 1 + break + + if not file_path: + return None + + search_start = search_end = replace_start = replace_end = None + for i, line in enumerate(lines[start_idx:], start=start_idx): + if '<<<<<<< SEARCH' in line: + search_start = i + 1 + elif '=======' in line and search_start is not None and search_end is None: + search_end = i + replace_start = i + 1 + elif '>>>>>>> REPLACE' in line and replace_start is not None: + replace_end = i + break + + if None in (search_start, search_end, replace_start, replace_end): + return None + + return { + 'file_path': file_path, + 'search': '\n'.join(lines[search_start:search_end]), + 'replace': '\n'.join(lines[replace_start:replace_end]), + } diff --git a/pipelinerl/domains/swe/reward.py b/pipelinerl/domains/swe/reward.py new file mode 100644 index 00000000..22cc611b --- /dev/null +++ b/pipelinerl/domains/swe/reward.py @@ -0,0 +1,163 @@ +import difflib +import logging +import re +from typing import Dict, List, Tuple, TypedDict + +from unidiff import PatchSet +from unidiff.errors import UnidiffParseError + +logger = logging.getLogger(__name__) + + +class FormatError(Exception): + pass + + +class ChangeSimilarity(TypedDict): + path: str + pred_change: str + oracle_change: str + similarity: float + + +def parse_patch_for_gold_files(patch_text: str) -> List[str]: + """Extract modified file paths from a unified diff patch.""" + if not patch_text: + return [] + return re.findall(r'^--- a/(.+)$', patch_text, re.MULTILINE) + + +def generate_unified_diff(old_code: str, new_code: str, n_context: int = 3) -> str: + diff = difflib.unified_diff( + old_code.splitlines(), + new_code.splitlines(), + fromfile="old", + tofile="new", + lineterm="", + n=n_context, + ) + try: + next(diff) + next(diff) + return "\n".join(diff) + except StopIteration: + return "" + + +def apply_edits_to_files( + file_contents: Dict[str, str], + edits: List[Dict], + silent: bool = False, +) -> Dict[str, str]: + new_content_dict = dict(file_contents) + for edit in edits: + file_path = edit.get('file_path', '') + search_text = edit.get('search', '') + replace_text = edit.get('replace', '') + + if not silent and search_text == replace_text: + raise FormatError("Search and replace blocks are identical") + + if file_path not in new_content_dict: + if not silent: + raise FormatError(f"File {file_path} not found in file_contents") + logger.warning("File %s not found in file_contents", file_path) + continue + + current_content = new_content_dict[file_path] + if search_text not in current_content: + if not silent: + raise FormatError(f"Search text not found in {file_path}: {search_text}") + logger.warning("Search text not found in %s", file_path) + continue + + new_content_dict[file_path] = current_content.replace(search_text, replace_text, 1) + + return new_content_dict + + +def get_normalized_patch( + code_context: Dict[str, str], + new_content_dict: Dict[str, str], +) -> Dict[str, str]: + patch_dict = {} + for path, new_content in new_content_dict.items(): + old_content = code_context.get(path, "") + patch = generate_unified_diff(old_content, new_content) + if patch: + patch_dict[path] = patch + return patch_dict + + +def get_filelevel_diff(patch_text: str) -> Dict[str, str]: + try: + patch = PatchSet(patch_text) + except UnidiffParseError: + return {} + except Exception as e: + logger.warning("Unexpected unidiff parsing error: %s", e) + return {} + + result = {} + for patchfile in patch: + body = "\n".join(str(hunk).strip() for hunk in patchfile) + result[patchfile.path] = body.strip() + return result + + +def compute_change_similarities( + pred_patch: Dict[str, str], + oracle_patch: Dict[str, str], +) -> List[ChangeSimilarity]: + all_file_paths = set(oracle_patch) | set(pred_patch) + similarities = [] + for path in all_file_paths: + pred_change = pred_patch.get(path, "") + oracle_change = oracle_patch.get(path, "") + if not oracle_change or not pred_change: + change_similarity = 0.0 + else: + change_similarity = difflib.SequenceMatcher( + None, pred_change, oracle_change, autojunk=False + ).ratio() + similarities.append(ChangeSimilarity( + path=path, + pred_change=pred_change, + oracle_change=oracle_change, + similarity=change_similarity, + )) + return similarities + + +def calculate_precise_reward( + file_contents: Dict[str, str], + oracle_patch_text: str, + predicted_edits: List[Dict], +) -> Tuple[float, Dict]: + try: + if not predicted_edits: + raise FormatError("No valid search blocks found") + + oracle_patch = get_filelevel_diff(oracle_patch_text) + pred_new_content = apply_edits_to_files(file_contents, predicted_edits) + pred_patch = get_normalized_patch(file_contents, pred_new_content) + similarities = compute_change_similarities(pred_patch, oracle_patch) + + if not similarities: + assert not oracle_patch and not pred_patch + return 1.0, {"similarities": []} + + reward = sum(s["similarity"] for s in similarities) / len(similarities) + return reward, { + "similarities": similarities, + "num_files_changed": len(similarities), + "oracle_files": list(oracle_patch.keys()), + "predicted_files": list(pred_patch.keys()), + } + + except FormatError as e: + # logger.warning("Format error in reward calculation: %s", str(e)) + return 0.0, {"format_error": True, "error_message": str(e)} + except Exception as e: + logger.error("Unexpected error in reward calculation: %s", e) + return 0.0, {"error": str(e)} diff --git a/pipelinerl/domains/swe/rollouts.py b/pipelinerl/domains/swe/rollouts.py new file mode 100644 index 00000000..c724266c --- /dev/null +++ b/pipelinerl/domains/swe/rollouts.py @@ -0,0 +1,71 @@ +import logging +import math +import time + +import aiohttp +from omegaconf import DictConfig + +from pipelinerl.async_llm import llm_async_generate, make_training_text +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.rollouts import BaseMetrics, RolloutResult + +from pipelinerl.domains.swe.repair import build_messages, parse_edits +from pipelinerl.domains.swe.reward import calculate_precise_reward + +logger = logging.getLogger(__name__) + + +class SWEMetrics(BaseMetrics): + format_error: bool = False + + +async def generate_swe_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + file_contents = problem.get("file_contents", {}) + problem_statement = problem["problem_statement"] + gold_patch = problem.get("patch", "") + + messages = build_messages(problem_statement, file_contents) + prompt = Prompt(messages=messages) + + time_start = time.time() + llm_call = await llm_async_generate(llm, prompt, session) + latency = time.time() - time_start + + raw_output = llm_call.output.content or "" + edits = parse_edits(raw_output) + + reward, reward_meta = calculate_precise_reward(file_contents, gold_patch, edits) + + if hasattr(cfg.actor, 'discount_factor'): + reward *= cfg.actor.discount_factor ** llm_call.output_length_tokens + + trace = make_training_text(llm, llm_call) + trace.reward = reward if not math.isnan(reward) else 0.0 + trace.metadata["stage"] = "repair" + trace.metadata["dataset"] = problem.get("dataset") + trace.metadata["problem_id"] = problem.get("id") or problem.get("instance_id") + + format_error = bool(reward_meta.get("format_error", False)) + success_threshold = getattr(cfg.actor, 'success_threshold', 0.8) + success = (not format_error) and reward >= success_threshold + + metrics = SWEMetrics( + reward=trace.reward, + success=success, + no_error=not format_error, + no_answer=len(edits) == 0, + format_error=format_error, + ) + + return RolloutResult( + training_texts=[trace], + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + domain="swe", + ) diff --git a/pipelinerl/domains/swe/swe_preprocessor.py b/pipelinerl/domains/swe/swe_preprocessor.py new file mode 100644 index 00000000..253dcf38 --- /dev/null +++ b/pipelinerl/domains/swe/swe_preprocessor.py @@ -0,0 +1,287 @@ +""" +SWE dataset preprocessing utility. + +Clones repos, extracts gold_file_contents at base_commit, applies token-count +filtering, and optionally computes per-repo file stats for BM25 retrieval. +Run this once per dataset before training; output is saved as a HuggingFace disk dataset. + +Usage: + python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess +""" + +import json +import logging +import math +import os +import re +import shutil +from collections import Counter +from pathlib import Path +from typing import Dict, List, Optional, Set + +import git +import hydra +from datasets import Dataset, load_dataset, load_from_disk +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import AutoTokenizer + +logger = logging.getLogger(__name__) + + +class RepoManager: + """Clones or updates GitHub repositories to a local directory.""" + + def __init__(self, repos_base_dir: str): + self.repos_base_dir = Path(repos_base_dir) + os.makedirs(self.repos_base_dir, exist_ok=True) + + def clone_or_update_repo(self, repo_name: str) -> Path: + repo_url = f"https://github.com/{repo_name}.git" + local_path = self.repos_base_dir / repo_name.replace("/", "_") + + try: + if local_path.exists() and (local_path / ".git").exists(): + repo = git.Repo(local_path) + repo.remotes.origin.fetch() + else: + if local_path.exists(): + logger.warning("Removing broken directory %s before fresh clone", local_path) + shutil.rmtree(local_path) + git.Repo.clone_from(repo_url, local_path) + return local_path + except Exception as e: + logger.error("Failed to process repo %s: %s", repo_name, e) + raise + + def clone_or_update_repos(self, repo_names: List[str]) -> Dict[str, Path]: + results: Dict[str, Path] = {} + failed: List[str] = [] + for name in tqdm(repo_names, desc="Cloning/updating repos"): + try: + path = self.clone_or_update_repo(name) + if (path / ".git").exists(): + results[name] = path + else: + failed.append(name) + except Exception: + failed.append(name) + if failed: + logger.warning("Failed to clone/update %d repos: %s", len(set(failed)), list(set(failed))) + return results + + +class SwePreprocessor: + """Processes a SWE-style HuggingFace dataset into a training-ready disk dataset.""" + + SOURCE_EXTENSIONS = { + ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", ".cs", ".php", + ".rb", ".go", ".rs", ".kt", ".scala", ".swift", ".m", ".mm", ".sh", ".bash", + ".zsh", ".fish", ".pl", ".r", ".R", ".sql", ".html", ".css", ".scss", ".sass", + ".less", ".vue", ".jsx", ".tsx", ".json", ".yaml", ".yml", ".xml", ".toml", + ".ini", ".cfg", ".conf", ".properties", ".gradle", ".cmake", ".make", + ".dockerfile", ".md", ".rst", ".txt", ".lock", ".requirements", ".pyx", ".ipynb", + ".pxd", ".pyi", ".pxi.in", + } + SKIP_DIRS = { + "test", "tests", "__pycache__", ".git", ".svn", ".hg", "node_modules", + ".pytest_cache", ".tox", "venv", ".env", "dist", ".idea", ".vscode", + "target", "out", "bin", "obj", ".gradle", "coverage", ".coverage", + ".nyc_output", "htmlcov", + } + + def __init__(self, cfg: DictConfig): + self.cfg = cfg + self.repos_base_dir = Path(cfg.repo_path) + self.dataset_path = Path(cfg.dataset_path) + self.min_token_threshold = cfg.min_token_threshold + self.max_token_threshold = cfg.max_token_threshold + self.num_map_processes = cfg.num_map_processes + self.tokenizer_model = cfg.tokenizer_model + self.repo_manager = RepoManager(self.repos_base_dir) + self.file_stats_cache: Dict = {} + self.tokenizer = self._init_tokenizer() + + def _init_tokenizer(self): + try: + tok = AutoTokenizer.from_pretrained(self.tokenizer_model) + logger.info("Tokenizer loaded: %s", self.tokenizer_model) + return tok + except Exception as e: + logger.warning("Could not load tokenizer (%s): %s — token filtering disabled", self.tokenizer_model, e) + return None + + # ── File helpers ───────────────────────────────────────────────────────── + + def _is_source_file(self, filepath: str) -> bool: + path = Path(filepath) + if any(p.lower() in self.SKIP_DIRS for p in path.parts): + return False + if path.suffix.lower() in self.SOURCE_EXTENSIONS: + return True + return not path.suffix and path.name.lower() in { + "makefile", "dockerfile", "readme", "license", "changelog", + "requirements", "pipfile", "gemfile", "rakefile", + } + + def _get_file_content(self, repo_path: Path, commit: str, filepath: str) -> Optional[str]: + if not (repo_path / ".git").exists(): + return None + try: + return git.Repo(repo_path).git.show(f"{commit}:{filepath}") + except Exception: + return None + + def _parse_patch(self, patch: str) -> List[str]: + return re.findall(r"^--- a/(.+)$", patch or "", re.MULTILINE) + + # ── Token filtering ─────────────────────────────────────────────────────── + + def _filter_by_token_count(self, example: Dict) -> bool: + if not self.tokenizer: + return True + try: + contents = json.loads(example.get("gold_file_contents", "{}")) + if not contents: + return False + text = " ".join(contents.values()) + " " + (example.get("problem_statement") or "") + n = len(self.tokenizer.encode(text, add_special_tokens=False)) + if self.min_token_threshold is not None and n < self.min_token_threshold: + return False + if self.max_token_threshold is not None and n > self.max_token_threshold: + return False + return True + except Exception as e: + logger.error("Token filter error: %s", e) + return False + + # ── File stats (for BM25, optional) ────────────────────────────────────── + + def _tokenize_content(self, content: str) -> Counter: + tokens = re.findall(r"[a-zA-Z0-9_]+", content.lower()) + return Counter(t for t in tokens if 2 <= len(t) <= 50 and not (t.isdigit() and len(t) > 4)) + + def _get_all_file_stats(self, repo_path: Path, commit: str) -> Dict[str, Dict]: + key = (str(repo_path), commit) + if key in self.file_stats_cache: + return self.file_stats_cache[key] + + stats: Dict[str, Dict] = {} + if not (repo_path / ".git").exists(): + self.file_stats_cache[key] = stats + return stats + + try: + repo = git.Repo(repo_path) + files = repo.git.execute(["git", "ls-tree", "-r", "--name-only", commit]) + for filepath in (files.strip().split("\n") if files.strip() else []): + if not filepath or not self._is_source_file(filepath): + continue + try: + content = repo.git.show(f"{commit}:{filepath}") + content.encode("utf-8") # skip binary + term_counts = self._tokenize_content(content) + stats[filepath] = {"path": filepath, "length": len(content), "term_counts": dict(term_counts)} + except Exception: + continue + except Exception as e: + logger.error("File stats error for %s@%s: %s", repo_path, commit, e) + + self.file_stats_cache[key] = stats + return stats + + # ── Per-example processing passes ──────────────────────────────────────── + + def _extract_gold_file_contents_only(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict: + repo = example.get("repo") + commit = example.get("base_commit") + patch = example.get("patch") + contents: Dict[str, str] = {} + + repo_path = repo_paths.get(repo) if repo else None + if repo_path and commit and patch: + for fp in self._parse_patch(patch): + c = self._get_file_content(repo_path, commit, fp) + if c is not None: + contents[fp] = c + + example["gold_file_contents"] = json.dumps(contents) + return example + + def _add_file_stats_to_example(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict: + example["all_file_stats"] = json.dumps({}) + example["_invalid_example"] = True + + repo = example.get("repo") + commit = example.get("base_commit") + patch = example.get("patch") + repo_path = repo_paths.get(repo) if repo else None + + if not (repo_path and commit and patch): + return example + + gold_files = self._parse_patch(patch) + if any(not self._is_source_file(f) for f in gold_files): + return example + + all_stats = self._get_all_file_stats(repo_path, commit) + example["all_file_stats"] = json.dumps(all_stats) + example["_invalid_example"] = any(f not in all_stats for f in gold_files) + return example + + # ── Main entry point ───────────────────────────────────────────────────── + + def process(self) -> Set[str]: + if self.dataset_path.exists() and not self.cfg.force_reprocess: + logger.info("Found existing dataset at %s, loading from disk", self.dataset_path) + try: + ds = load_from_disk(str(self.dataset_path)) + return set(ds["repo"]) + except Exception as e: + logger.warning("Could not load existing dataset (%s), reprocessing", e) + + logger.info("Loading %s (split=%s) from Hub", self.cfg.hf_dataset_name, self.cfg.hf_split_name) + dataset = load_dataset(self.cfg.hf_dataset_name, split=self.cfg.hf_split_name) + logger.info("Loaded %d examples", len(dataset)) + + unique_repos = set(dataset["repo"]) + repo_paths = self.repo_manager.clone_or_update_repos(list(unique_repos)) + + # Pass 1: extract gold file contents (fast) + dataset = dataset.map( + lambda ex: self._extract_gold_file_contents_only(ex, repo_paths), + batched=False, num_proc=self.num_map_processes, + load_from_cache_file=False, desc="Extracting gold file contents", + ) + + # Token-count filtering (before expensive file stats) + if self.tokenizer: + before = len(dataset) + dataset = dataset.filter(self._filter_by_token_count, desc="Token-count filtering") + logger.info("Token filtering: %d → %d examples", before, len(dataset)) + + # Pass 2: compute file stats (expensive, only for surviving examples) + dataset = dataset.map( + lambda ex: self._add_file_stats_to_example(ex, repo_paths), + batched=False, num_proc=self.num_map_processes, + load_from_cache_file=False, desc="Computing file stats", + ) + + before = len(dataset) + dataset = dataset.filter(lambda ex: not ex.get("_invalid_example", True), desc="Filtering invalid examples") + dataset = dataset.remove_columns(["_invalid_example"]) + logger.info("Validity filtering: %d → %d examples", before, len(dataset)) + + logger.info("Saving to %s", self.dataset_path) + dataset.save_to_disk(str(self.dataset_path)) + logger.info("Done. %d examples saved.", len(dataset)) + return unique_repos + + +@hydra.main(config_path="../../../../conf", config_name="swe/preprocess", version_base="1.3.2") +def main(cfg: DictConfig): + SwePreprocessor(cfg).process() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 5f646c3a..3702fe05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,9 @@ coding = [ fn_calling = [ "bfcl-eval>=2025.6.8", ] +swe = [ + "unidiff>=0.7.5", +] logic = [ # i3-logic verification code is vendored in pipelinerl/domains/logic/i3_logic/ # (source: https://github.com/PrimeIntellect/i3-logic) From 997ebecc3920b40a5b635f6351df2ab37ca061a7 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 22 Apr 2026 16:46:35 +0000 Subject: [PATCH 2/5] loss never used from model output so better not to pass labels --- pipelinerl/finetune/rl/__init__.py | 1 - pipelinerl/finetune/value_model.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 075e1d1e..c140ef7c 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -186,7 +186,6 @@ def rl_step( model_inputs = { "input_ids": batch.input_ids, "attention_mask": batch.attention_mask, - "labels": batch.labels, } if batch.is_packed: model_inputs["position_ids"] = batch.position_ids diff --git a/pipelinerl/finetune/value_model.py b/pipelinerl/finetune/value_model.py index 7c8f1136..d258bfc3 100644 --- a/pipelinerl/finetune/value_model.py +++ b/pipelinerl/finetune/value_model.py @@ -93,7 +93,6 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=True, @@ -107,7 +106,6 @@ def forward( values = self.value_head(hidden_states) return CausalLMOutputWithValue( - loss=outputs.loss, logits=outputs.logits, value=values, past_key_values=outputs.past_key_values, From d35249ccbdbc90da3edf01caec842fb91fdf216f Mon Sep 17 00:00:00 2001 From: ehsk Date: Mon, 4 May 2026 13:27:05 +0000 Subject: [PATCH 3/5] edit parser improved to support "```" blocks as optional --- pipelinerl/domains/swe/repair.py | 107 ++++++++++++++----------------- pipelinerl/streams.py | 2 +- 2 files changed, 48 insertions(+), 61 deletions(-) diff --git a/pipelinerl/domains/swe/repair.py b/pipelinerl/domains/swe/repair.py index be045ff5..688f89f2 100644 --- a/pipelinerl/domains/swe/repair.py +++ b/pipelinerl/domains/swe/repair.py @@ -12,12 +12,14 @@ "\n\n" "\n" "[Your SEARCH/REPLACE edits using this format:]\n\n" + "```\n" "### filename.py\n" "<<<<<<< SEARCH\n" "[exact code to find]\n" "=======\n" "[replacement code]\n" ">>>>>>> REPLACE\n" + "```\n" "\n\n" "IMPORTANT REQUIREMENTS:\n" "- Every SEARCH/REPLACE edit must use the exact format above\n" @@ -64,66 +66,51 @@ def parse_edits(completion: str) -> List[dict]: """ Parse SEARCH/REPLACE blocks from a model completion. - Each code block must start with '### filepath' and contain exactly one - <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. + Each block is a '### filepath' line followed by a + <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. Triple-backtick code + fences around the block are accepted but not required. Returns a list of {'file_path', 'search', 'replace'} dicts. """ - edits = [] - code_blocks = _extract_code_blocks(completion) - - for block in code_blocks: - edit = _parse_single_block(block) - if edit is not None: - edits.append(edit) - + edits: List[dict] = [] + lines = completion.split('\n') + n = len(lines) + i = 0 + while i < n: + if '<<<<<<< SEARCH' not in lines[i]: + i += 1 + continue + + # Walk back to the most recent '### filepath' line, but don't cross a + # previous '>>>>>>> REPLACE' marker (that path belongs to an earlier edit). + file_path = None + for j in range(i - 1, -1, -1): + if '>>>>>>> REPLACE' in lines[j]: + break + stripped = lines[j].strip() + if stripped.startswith('###'): + file_path = stripped[3:].strip() + break + if not file_path: + i += 1 + continue + + search_start = i + 1 + sep = replace_end = None + for k in range(search_start, n): + if sep is None and '=======' in lines[k]: + sep = k + elif sep is not None and '>>>>>>> REPLACE' in lines[k]: + replace_end = k + break + + if sep is None or replace_end is None: + i += 1 + continue + + edits.append({ + 'file_path': file_path, + 'search': '\n'.join(lines[search_start:sep]), + 'replace': '\n'.join(lines[sep + 1:replace_end]), + }) + i = replace_end + 1 return edits - - -def _extract_code_blocks(text: str) -> List[str]: - blocks = [] - in_block = False - current: List[str] = [] - for line in text.split('\n'): - if line.strip().startswith('```'): - if in_block: - blocks.append('\n'.join(current)) - current = [] - in_block = not in_block - elif in_block: - current.append(line) - return blocks - - -def _parse_single_block(block: str) -> dict | None: - lines = block.split('\n') - - file_path = None - start_idx = 0 - for i, line in enumerate(lines): - if line.strip().startswith('###'): - file_path = line.strip()[3:].strip() - start_idx = i + 1 - break - - if not file_path: - return None - - search_start = search_end = replace_start = replace_end = None - for i, line in enumerate(lines[start_idx:], start=start_idx): - if '<<<<<<< SEARCH' in line: - search_start = i + 1 - elif '=======' in line and search_start is not None and search_end is None: - search_end = i - replace_start = i + 1 - elif '>>>>>>> REPLACE' in line and replace_start is not None: - replace_end = i - break - - if None in (search_start, search_end, replace_start, replace_end): - return None - - return { - 'file_path': file_path, - 'search': '\n'.join(lines[search_start:search_end]), - 'replace': '\n'.join(lines[replace_start:replace_end]), - } diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py index 632b760e..bd49ba04 100644 --- a/pipelinerl/streams.py +++ b/pipelinerl/streams.py @@ -110,7 +110,7 @@ def connect_to_redis(config: RedisConfig): logger.debug(f"Trying to connect to Redis server at {config.host}:{config.port}") client = redis.Redis(host=config.host, port=config.port) client.ping() - logger.info(f"Connected to Redis server") + logger.info("Connected to Redis server") return client except (redis.exceptions.TimeoutError, redis.ConnectionError) as e: logger.warning(f"Waiting for Redis server ({type(e)}). Retrying in 5 seconds.") From 3aa4103f2634def674c36ba4f73d075ede815148 Mon Sep 17 00:00:00 2001 From: ehsk Date: Tue, 12 May 2026 20:26:19 +0000 Subject: [PATCH 4/5] default datasets updated --- conf/swe.yaml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/conf/swe.yaml b/conf/swe.yaml index f2a748ed..6f6ba48d 100644 --- a/conf/swe.yaml +++ b/conf/swe.yaml @@ -18,10 +18,16 @@ dataset_loader_params: # HuggingFace Hub dataset IDs (or local disk paths). # Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train +# NOTE: SWE-smith and SWE-bench_Verified do not ship with gold_file_contents — +# run the preprocessor once per dataset before training, e.g.: +# python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess \ +# hf_dataset_name=SWE-bench/SWE-bench_Verified hf_split_name=test \ +# dataset_path=/your/output/path repo_path=/your/repos/cache +# then point the entry below at the resulting dataset_path on disk. train_dataset_names: - - SWE-bench/SWE-smith-py + - SWE-bench/SWE-smith test_dataset_names: - - SWE-bench/SWE-smith-py + - SWE-bench/SWE-bench_Verified finetune: seq_length: 24000 From 7f1e8a80d1591e7d46f8876c24fd0429cbbb84d4 Mon Sep 17 00:00:00 2001 From: ehsk Date: Wed, 13 May 2026 13:28:53 +0000 Subject: [PATCH 5/5] unused parameters removed --- pipelinerl/finetune/value_model.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pipelinerl/finetune/value_model.py b/pipelinerl/finetune/value_model.py index d258bfc3..a14214f7 100644 --- a/pipelinerl/finetune/value_model.py +++ b/pipelinerl/finetune/value_model.py @@ -15,8 +15,6 @@ class CausalLMOutputWithValue(ModelOutput): Output type for causal language models with an additional value head. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*): - Language modeling loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): Prediction scores of the language modeling head. value (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -29,7 +27,6 @@ class CausalLMOutputWithValue(ModelOutput): Attention weights after the attention softmax. """ - loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None value: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -76,7 +73,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -118,7 +114,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): self.pretrained_model.gradient_checkpointing_enable( gradient_checkpointing_kwargs ) - + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -130,14 +126,14 @@ def save_pretrained( ): """Save model and value head separately.""" import os - + if state_dict is None: state_dict = self.state_dict() - + # Extract pretrained model and value head state dicts pretrained_model_state_dict = {} value_head_state_dict = {} - + for key, value in state_dict.items(): if key.startswith("value_head."): # Remove the "value_head." prefix @@ -152,7 +148,7 @@ def save_pretrained( f"Unexpected key in state dict: {key}. " "Expected keys should start with 'value_head.' or 'pretrained_model.'." ) - + # Save the pretrained model which can be easily loaded by vllm, etc. self.pretrained_model.save_pretrained( save_directory, @@ -162,7 +158,7 @@ def save_pretrained( safe_serialization=safe_serialization, **kwargs, ) - + # Save value head separately if is_main_process: value_head_path = os.path.join(save_directory, "value_head.pt")