diff --git a/conf/finetune/gspo.yaml b/conf/finetune/gspo.yaml new file mode 100644 index 00000000..5b270d04 --- /dev/null +++ b/conf/finetune/gspo.yaml @@ -0,0 +1,9 @@ +defaults: + - base + - _self_ + +attempts: 8 +rl: + policy_loss: gspo + epsilon_high: 4e-4 + epsilon_low: 3e-4 \ No newline at end of file diff --git a/conf/swe.yaml b/conf/swe.yaml new file mode 100644 index 00000000..6f6ba48d --- /dev/null +++ b/conf/swe.yaml @@ -0,0 +1,53 @@ +defaults: + - base + - _self_ + - override finetune: gspo + +model_path: Qwen/Qwen3-8B + +actor: + rollout_policy: pipelinerl.domains.swe.rollouts.generate_swe_rollout + success_threshold: 0.8 + +environments: null + +dataset_loader: pipelinerl.domains.swe.load_datasets.load_local_swe_dataset +dataset_loader_params: + seed: ${seed} + # max_samples: 1000 # uncomment to cap the number of loaded samples (applies to both train and test) + +# HuggingFace Hub dataset IDs (or local disk paths). +# Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train +# NOTE: SWE-smith and SWE-bench_Verified do not ship with gold_file_contents — +# run the preprocessor once per dataset before training, e.g.: +# python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess \ +# hf_dataset_name=SWE-bench/SWE-bench_Verified hf_split_name=test \ +# dataset_path=/your/output/path repo_path=/your/repos/cache +# then point the entry below at the resulting dataset_path on disk. +train_dataset_names: + - SWE-bench/SWE-smith +test_dataset_names: + - SWE-bench/SWE-bench_Verified + +finetune: + seq_length: 24000 + rl: + filter_zero_advantage_groups: false + +vllm_config: + vllm_kwargs: + max_model_len: 24000 + +llm: + parameters: + max_tokens: 4096 + temperature: 1.0 + chat_template_kwargs: + enable_thinking: false + +test_llm: + parameters: + max_tokens: 4096 + temperature: 0.0 + chat_template_kwargs: + enable_thinking: false diff --git a/conf/swe/preprocess.yaml b/conf/swe/preprocess.yaml new file mode 100644 index 00000000..841a79b2 --- /dev/null +++ b/conf/swe/preprocess.yaml @@ -0,0 +1,15 @@ +# Config for swe_preprocessor.py. +# Clones repos, extracts gold_file_contents at base_commit, applies token +# filtering, and saves a training-ready HuggingFace disk dataset. +# +# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess + +hf_dataset_name: SWE-bench/SWE-smith-py +hf_split_name: train +repo_path: /path/to/repos +dataset_path: /path/to/output_ds +tokenizer_model: Qwen/Qwen3-8B +min_token_threshold: null # set to an int to filter out very short examples +max_token_threshold: 16000 # set to null to disable +num_map_processes: 32 +force_reprocess: false diff --git a/pipelinerl/domains/swe/load_datasets.py b/pipelinerl/domains/swe/load_datasets.py new file mode 100644 index 00000000..e24bfe3f --- /dev/null +++ b/pipelinerl/domains/swe/load_datasets.py @@ -0,0 +1,127 @@ +# Supported datasets +# ────────────────────────────────────────────────────────────────────────────── +# Ready to use (have gold_file_contents pre-extracted): +# SWE-bench/SWE-smith-py local preprocessed disk dataset or Hub ID +# SWE-bench/SWE-smith-java " +# SWE-bench/SWE-smith-rs " +# SWE-bench/SWE-smith-go " +# +# Require preprocessing first (clone repos, extract file contents at base_commit): +# princeton-nlp/SWE-bench +# princeton-nlp/SWE-bench_Lite +# princeton-nlp/SWE-bench_Verified +# SWE-bench/SWE-Pro (if/when released publicly) +# +# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess +# ────────────────────────────────────────────────────────────────────────────── + +import json +import logging +import os +import random +from typing import Any, Dict, List, Optional + +from datasets import load_dataset, load_from_disk + +logger = logging.getLogger(__name__) + + +def _parse_file_contents(raw: Any) -> Dict[str, str]: + if isinstance(raw, dict): + return {str(k): str(v) for k, v in raw.items()} + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return {} + if isinstance(parsed, dict): + return {str(k): str(v) for k, v in parsed.items()} + return {} + + +def _load_single_dataset(path: str) -> List[Dict]: + """Load a dataset from a local disk path or a HuggingFace Hub ID. + + Local path: /path/to/ds_train + Hub ID: SWE-bench/SWE-smith-py (all splits concatenated) + Hub ID+split: SWE-bench/SWE-smith-py:train + """ + if os.path.exists(path): + logger.info("Loading from disk: %s", path) + dataset = load_from_disk(path) + else: + # Hub ID, optionally with ":split" suffix + if ":" in path: + hub_id, split = path.rsplit(":", 1) + else: + hub_id, split = path, None + + logger.info("Loading from HuggingFace Hub: %s (split=%s)", hub_id, split or "all") + loaded = load_dataset(hub_id, split=split) + + if split is None: + # DatasetDict — concatenate all splits + from datasets import concatenate_datasets + dataset = concatenate_datasets(list(loaded.values())) + else: + dataset = loaded + + logger.info("Loaded %d rows from %s", len(dataset), path) + + samples = [] + for row in dataset: + item = dict(row) + try: + file_contents = _parse_file_contents(item.get("gold_file_contents", "{}")) + if not file_contents: + continue + samples.append({ + "id": item.get("id", "") or item.get("instance_id", "") or item.get("issue_id", ""), + "dataset": item.get("dataset", "") or path, + "repo": item.get("repo", ""), + "base_commit": item.get("base_commit", ""), + "problem_statement": item.get("problem_statement", ""), + "patch": item.get("patch", ""), + "file_contents": file_contents, + }) + except Exception as e: + logger.warning("Skipping malformed item: %s", e) + + return samples + + +def load_local_swe_dataset( + dataset_paths: List[str], + seed: int = 42, + max_samples: Optional[int] = None, +) -> List[Dict]: + """ + Load one or more SWE-style datasets from disk and return a combined, shuffled list. + + Args: + dataset_paths: Passed via cfg.train_dataset_names / cfg.test_dataset_names. + Each entry is a filesystem path to a HuggingFace disk dataset. + Add multiple paths to mix datasets (e.g. swe-smith + swe-bench). + seed: Random seed for shuffling (inherit from cfg.seed via dataset_loader_params). + max_samples: Optional cap on the total number of returned samples. + """ + if not dataset_paths: + logger.error("No dataset paths provided") + return [] + + all_samples: List[Dict] = [] + for path in dataset_paths: + try: + all_samples.extend(_load_single_dataset(path)) + except Exception as e: + logger.error("Failed to load dataset from %s: %s", path, e, exc_info=True) + + random.Random(seed).shuffle(all_samples) + logger.info("Shuffled %d samples (seed=%d)", len(all_samples), seed) + + if max_samples and len(all_samples) > max_samples: + all_samples = all_samples[:max_samples] + logger.info("Trimmed to max_samples=%d", max_samples) + + logger.info("Returning %d samples total", len(all_samples)) + return all_samples diff --git a/pipelinerl/domains/swe/repair.py b/pipelinerl/domains/swe/repair.py new file mode 100644 index 00000000..688f89f2 --- /dev/null +++ b/pipelinerl/domains/swe/repair.py @@ -0,0 +1,116 @@ +import logging +from typing import Dict, List + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT = "You are a helpful coding assistant that analyzes code and fixes bugs." + +USER_PROMPT_TEMPLATE = ( + "Analyze the following code to find and fix bugs. Use this format:\n\n" + "\n" + "[Your analysis process - be as detailed as you want until you're confident in your solution]\n" + "\n\n" + "\n" + "[Your SEARCH/REPLACE edits using this format:]\n\n" + "```\n" + "### filename.py\n" + "<<<<<<< SEARCH\n" + "[exact code to find]\n" + "=======\n" + "[replacement code]\n" + ">>>>>>> REPLACE\n" + "```\n" + "\n\n" + "IMPORTANT REQUIREMENTS:\n" + "- Every SEARCH/REPLACE edit must use the exact format above\n" + "- The SEARCH block must contain a contiguous chunk of lines that exist in the source code\n" + "- PROPER INDENTATION IS CRITICAL - if you want to add ' print(x)', you must include all those spaces\n" + "- Wrap each SEARCH/REPLACE edit in a code block\n" + "- Use separate code blocks for multiple edits\n\n" + "Example:\n" + "```python\n" + "### mathweb/flask/app.py\n" + "<<<<<<< SEARCH\n" + "from flask import Flask\n" + "=======\n" + "import math\n" + "from flask import Flask\n" + ">>>>>>> REPLACE\n" + "```\n\n" + "Here is the issue:\n" + "--- BEGIN ISSUE ---\n" + "{problem_statement}\n" + "--- END ISSUE ---\n\n" + "Below are the code files that may contain bugs:\n" + "{file_contents}" +) + + +def build_messages(problem_statement: str, file_contents: Dict[str, str]) -> List[dict]: + """Build the chat messages for a single-turn repair prompt.""" + formatted_files = "".join( + f"### {path}\n```\n{content}\n```\n\n" + for path, content in file_contents.items() + ) + user_content = USER_PROMPT_TEMPLATE.format( + problem_statement=problem_statement, + file_contents=formatted_files, + ) + return [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ] + + +def parse_edits(completion: str) -> List[dict]: + """ + Parse SEARCH/REPLACE blocks from a model completion. + + Each block is a '### filepath' line followed by a + <<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. Triple-backtick code + fences around the block are accepted but not required. + Returns a list of {'file_path', 'search', 'replace'} dicts. + """ + edits: List[dict] = [] + lines = completion.split('\n') + n = len(lines) + i = 0 + while i < n: + if '<<<<<<< SEARCH' not in lines[i]: + i += 1 + continue + + # Walk back to the most recent '### filepath' line, but don't cross a + # previous '>>>>>>> REPLACE' marker (that path belongs to an earlier edit). + file_path = None + for j in range(i - 1, -1, -1): + if '>>>>>>> REPLACE' in lines[j]: + break + stripped = lines[j].strip() + if stripped.startswith('###'): + file_path = stripped[3:].strip() + break + if not file_path: + i += 1 + continue + + search_start = i + 1 + sep = replace_end = None + for k in range(search_start, n): + if sep is None and '=======' in lines[k]: + sep = k + elif sep is not None and '>>>>>>> REPLACE' in lines[k]: + replace_end = k + break + + if sep is None or replace_end is None: + i += 1 + continue + + edits.append({ + 'file_path': file_path, + 'search': '\n'.join(lines[search_start:sep]), + 'replace': '\n'.join(lines[sep + 1:replace_end]), + }) + i = replace_end + 1 + return edits diff --git a/pipelinerl/domains/swe/reward.py b/pipelinerl/domains/swe/reward.py new file mode 100644 index 00000000..22cc611b --- /dev/null +++ b/pipelinerl/domains/swe/reward.py @@ -0,0 +1,163 @@ +import difflib +import logging +import re +from typing import Dict, List, Tuple, TypedDict + +from unidiff import PatchSet +from unidiff.errors import UnidiffParseError + +logger = logging.getLogger(__name__) + + +class FormatError(Exception): + pass + + +class ChangeSimilarity(TypedDict): + path: str + pred_change: str + oracle_change: str + similarity: float + + +def parse_patch_for_gold_files(patch_text: str) -> List[str]: + """Extract modified file paths from a unified diff patch.""" + if not patch_text: + return [] + return re.findall(r'^--- a/(.+)$', patch_text, re.MULTILINE) + + +def generate_unified_diff(old_code: str, new_code: str, n_context: int = 3) -> str: + diff = difflib.unified_diff( + old_code.splitlines(), + new_code.splitlines(), + fromfile="old", + tofile="new", + lineterm="", + n=n_context, + ) + try: + next(diff) + next(diff) + return "\n".join(diff) + except StopIteration: + return "" + + +def apply_edits_to_files( + file_contents: Dict[str, str], + edits: List[Dict], + silent: bool = False, +) -> Dict[str, str]: + new_content_dict = dict(file_contents) + for edit in edits: + file_path = edit.get('file_path', '') + search_text = edit.get('search', '') + replace_text = edit.get('replace', '') + + if not silent and search_text == replace_text: + raise FormatError("Search and replace blocks are identical") + + if file_path not in new_content_dict: + if not silent: + raise FormatError(f"File {file_path} not found in file_contents") + logger.warning("File %s not found in file_contents", file_path) + continue + + current_content = new_content_dict[file_path] + if search_text not in current_content: + if not silent: + raise FormatError(f"Search text not found in {file_path}: {search_text}") + logger.warning("Search text not found in %s", file_path) + continue + + new_content_dict[file_path] = current_content.replace(search_text, replace_text, 1) + + return new_content_dict + + +def get_normalized_patch( + code_context: Dict[str, str], + new_content_dict: Dict[str, str], +) -> Dict[str, str]: + patch_dict = {} + for path, new_content in new_content_dict.items(): + old_content = code_context.get(path, "") + patch = generate_unified_diff(old_content, new_content) + if patch: + patch_dict[path] = patch + return patch_dict + + +def get_filelevel_diff(patch_text: str) -> Dict[str, str]: + try: + patch = PatchSet(patch_text) + except UnidiffParseError: + return {} + except Exception as e: + logger.warning("Unexpected unidiff parsing error: %s", e) + return {} + + result = {} + for patchfile in patch: + body = "\n".join(str(hunk).strip() for hunk in patchfile) + result[patchfile.path] = body.strip() + return result + + +def compute_change_similarities( + pred_patch: Dict[str, str], + oracle_patch: Dict[str, str], +) -> List[ChangeSimilarity]: + all_file_paths = set(oracle_patch) | set(pred_patch) + similarities = [] + for path in all_file_paths: + pred_change = pred_patch.get(path, "") + oracle_change = oracle_patch.get(path, "") + if not oracle_change or not pred_change: + change_similarity = 0.0 + else: + change_similarity = difflib.SequenceMatcher( + None, pred_change, oracle_change, autojunk=False + ).ratio() + similarities.append(ChangeSimilarity( + path=path, + pred_change=pred_change, + oracle_change=oracle_change, + similarity=change_similarity, + )) + return similarities + + +def calculate_precise_reward( + file_contents: Dict[str, str], + oracle_patch_text: str, + predicted_edits: List[Dict], +) -> Tuple[float, Dict]: + try: + if not predicted_edits: + raise FormatError("No valid search blocks found") + + oracle_patch = get_filelevel_diff(oracle_patch_text) + pred_new_content = apply_edits_to_files(file_contents, predicted_edits) + pred_patch = get_normalized_patch(file_contents, pred_new_content) + similarities = compute_change_similarities(pred_patch, oracle_patch) + + if not similarities: + assert not oracle_patch and not pred_patch + return 1.0, {"similarities": []} + + reward = sum(s["similarity"] for s in similarities) / len(similarities) + return reward, { + "similarities": similarities, + "num_files_changed": len(similarities), + "oracle_files": list(oracle_patch.keys()), + "predicted_files": list(pred_patch.keys()), + } + + except FormatError as e: + # logger.warning("Format error in reward calculation: %s", str(e)) + return 0.0, {"format_error": True, "error_message": str(e)} + except Exception as e: + logger.error("Unexpected error in reward calculation: %s", e) + return 0.0, {"error": str(e)} diff --git a/pipelinerl/domains/swe/rollouts.py b/pipelinerl/domains/swe/rollouts.py new file mode 100644 index 00000000..c724266c --- /dev/null +++ b/pipelinerl/domains/swe/rollouts.py @@ -0,0 +1,71 @@ +import logging +import math +import time + +import aiohttp +from omegaconf import DictConfig + +from pipelinerl.async_llm import llm_async_generate, make_training_text +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.rollouts import BaseMetrics, RolloutResult + +from pipelinerl.domains.swe.repair import build_messages, parse_edits +from pipelinerl.domains.swe.reward import calculate_precise_reward + +logger = logging.getLogger(__name__) + + +class SWEMetrics(BaseMetrics): + format_error: bool = False + + +async def generate_swe_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + file_contents = problem.get("file_contents", {}) + problem_statement = problem["problem_statement"] + gold_patch = problem.get("patch", "") + + messages = build_messages(problem_statement, file_contents) + prompt = Prompt(messages=messages) + + time_start = time.time() + llm_call = await llm_async_generate(llm, prompt, session) + latency = time.time() - time_start + + raw_output = llm_call.output.content or "" + edits = parse_edits(raw_output) + + reward, reward_meta = calculate_precise_reward(file_contents, gold_patch, edits) + + if hasattr(cfg.actor, 'discount_factor'): + reward *= cfg.actor.discount_factor ** llm_call.output_length_tokens + + trace = make_training_text(llm, llm_call) + trace.reward = reward if not math.isnan(reward) else 0.0 + trace.metadata["stage"] = "repair" + trace.metadata["dataset"] = problem.get("dataset") + trace.metadata["problem_id"] = problem.get("id") or problem.get("instance_id") + + format_error = bool(reward_meta.get("format_error", False)) + success_threshold = getattr(cfg.actor, 'success_threshold', 0.8) + success = (not format_error) and reward >= success_threshold + + metrics = SWEMetrics( + reward=trace.reward, + success=success, + no_error=not format_error, + no_answer=len(edits) == 0, + format_error=format_error, + ) + + return RolloutResult( + training_texts=[trace], + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + domain="swe", + ) diff --git a/pipelinerl/domains/swe/swe_preprocessor.py b/pipelinerl/domains/swe/swe_preprocessor.py new file mode 100644 index 00000000..253dcf38 --- /dev/null +++ b/pipelinerl/domains/swe/swe_preprocessor.py @@ -0,0 +1,287 @@ +""" +SWE dataset preprocessing utility. + +Clones repos, extracts gold_file_contents at base_commit, applies token-count +filtering, and optionally computes per-repo file stats for BM25 retrieval. +Run this once per dataset before training; output is saved as a HuggingFace disk dataset. + +Usage: + python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess +""" + +import json +import logging +import math +import os +import re +import shutil +from collections import Counter +from pathlib import Path +from typing import Dict, List, Optional, Set + +import git +import hydra +from datasets import Dataset, load_dataset, load_from_disk +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import AutoTokenizer + +logger = logging.getLogger(__name__) + + +class RepoManager: + """Clones or updates GitHub repositories to a local directory.""" + + def __init__(self, repos_base_dir: str): + self.repos_base_dir = Path(repos_base_dir) + os.makedirs(self.repos_base_dir, exist_ok=True) + + def clone_or_update_repo(self, repo_name: str) -> Path: + repo_url = f"https://github.com/{repo_name}.git" + local_path = self.repos_base_dir / repo_name.replace("/", "_") + + try: + if local_path.exists() and (local_path / ".git").exists(): + repo = git.Repo(local_path) + repo.remotes.origin.fetch() + else: + if local_path.exists(): + logger.warning("Removing broken directory %s before fresh clone", local_path) + shutil.rmtree(local_path) + git.Repo.clone_from(repo_url, local_path) + return local_path + except Exception as e: + logger.error("Failed to process repo %s: %s", repo_name, e) + raise + + def clone_or_update_repos(self, repo_names: List[str]) -> Dict[str, Path]: + results: Dict[str, Path] = {} + failed: List[str] = [] + for name in tqdm(repo_names, desc="Cloning/updating repos"): + try: + path = self.clone_or_update_repo(name) + if (path / ".git").exists(): + results[name] = path + else: + failed.append(name) + except Exception: + failed.append(name) + if failed: + logger.warning("Failed to clone/update %d repos: %s", len(set(failed)), list(set(failed))) + return results + + +class SwePreprocessor: + """Processes a SWE-style HuggingFace dataset into a training-ready disk dataset.""" + + SOURCE_EXTENSIONS = { + ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", ".cs", ".php", + ".rb", ".go", ".rs", ".kt", ".scala", ".swift", ".m", ".mm", ".sh", ".bash", + ".zsh", ".fish", ".pl", ".r", ".R", ".sql", ".html", ".css", ".scss", ".sass", + ".less", ".vue", ".jsx", ".tsx", ".json", ".yaml", ".yml", ".xml", ".toml", + ".ini", ".cfg", ".conf", ".properties", ".gradle", ".cmake", ".make", + ".dockerfile", ".md", ".rst", ".txt", ".lock", ".requirements", ".pyx", ".ipynb", + ".pxd", ".pyi", ".pxi.in", + } + SKIP_DIRS = { + "test", "tests", "__pycache__", ".git", ".svn", ".hg", "node_modules", + ".pytest_cache", ".tox", "venv", ".env", "dist", ".idea", ".vscode", + "target", "out", "bin", "obj", ".gradle", "coverage", ".coverage", + ".nyc_output", "htmlcov", + } + + def __init__(self, cfg: DictConfig): + self.cfg = cfg + self.repos_base_dir = Path(cfg.repo_path) + self.dataset_path = Path(cfg.dataset_path) + self.min_token_threshold = cfg.min_token_threshold + self.max_token_threshold = cfg.max_token_threshold + self.num_map_processes = cfg.num_map_processes + self.tokenizer_model = cfg.tokenizer_model + self.repo_manager = RepoManager(self.repos_base_dir) + self.file_stats_cache: Dict = {} + self.tokenizer = self._init_tokenizer() + + def _init_tokenizer(self): + try: + tok = AutoTokenizer.from_pretrained(self.tokenizer_model) + logger.info("Tokenizer loaded: %s", self.tokenizer_model) + return tok + except Exception as e: + logger.warning("Could not load tokenizer (%s): %s — token filtering disabled", self.tokenizer_model, e) + return None + + # ── File helpers ───────────────────────────────────────────────────────── + + def _is_source_file(self, filepath: str) -> bool: + path = Path(filepath) + if any(p.lower() in self.SKIP_DIRS for p in path.parts): + return False + if path.suffix.lower() in self.SOURCE_EXTENSIONS: + return True + return not path.suffix and path.name.lower() in { + "makefile", "dockerfile", "readme", "license", "changelog", + "requirements", "pipfile", "gemfile", "rakefile", + } + + def _get_file_content(self, repo_path: Path, commit: str, filepath: str) -> Optional[str]: + if not (repo_path / ".git").exists(): + return None + try: + return git.Repo(repo_path).git.show(f"{commit}:{filepath}") + except Exception: + return None + + def _parse_patch(self, patch: str) -> List[str]: + return re.findall(r"^--- a/(.+)$", patch or "", re.MULTILINE) + + # ── Token filtering ─────────────────────────────────────────────────────── + + def _filter_by_token_count(self, example: Dict) -> bool: + if not self.tokenizer: + return True + try: + contents = json.loads(example.get("gold_file_contents", "{}")) + if not contents: + return False + text = " ".join(contents.values()) + " " + (example.get("problem_statement") or "") + n = len(self.tokenizer.encode(text, add_special_tokens=False)) + if self.min_token_threshold is not None and n < self.min_token_threshold: + return False + if self.max_token_threshold is not None and n > self.max_token_threshold: + return False + return True + except Exception as e: + logger.error("Token filter error: %s", e) + return False + + # ── File stats (for BM25, optional) ────────────────────────────────────── + + def _tokenize_content(self, content: str) -> Counter: + tokens = re.findall(r"[a-zA-Z0-9_]+", content.lower()) + return Counter(t for t in tokens if 2 <= len(t) <= 50 and not (t.isdigit() and len(t) > 4)) + + def _get_all_file_stats(self, repo_path: Path, commit: str) -> Dict[str, Dict]: + key = (str(repo_path), commit) + if key in self.file_stats_cache: + return self.file_stats_cache[key] + + stats: Dict[str, Dict] = {} + if not (repo_path / ".git").exists(): + self.file_stats_cache[key] = stats + return stats + + try: + repo = git.Repo(repo_path) + files = repo.git.execute(["git", "ls-tree", "-r", "--name-only", commit]) + for filepath in (files.strip().split("\n") if files.strip() else []): + if not filepath or not self._is_source_file(filepath): + continue + try: + content = repo.git.show(f"{commit}:{filepath}") + content.encode("utf-8") # skip binary + term_counts = self._tokenize_content(content) + stats[filepath] = {"path": filepath, "length": len(content), "term_counts": dict(term_counts)} + except Exception: + continue + except Exception as e: + logger.error("File stats error for %s@%s: %s", repo_path, commit, e) + + self.file_stats_cache[key] = stats + return stats + + # ── Per-example processing passes ──────────────────────────────────────── + + def _extract_gold_file_contents_only(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict: + repo = example.get("repo") + commit = example.get("base_commit") + patch = example.get("patch") + contents: Dict[str, str] = {} + + repo_path = repo_paths.get(repo) if repo else None + if repo_path and commit and patch: + for fp in self._parse_patch(patch): + c = self._get_file_content(repo_path, commit, fp) + if c is not None: + contents[fp] = c + + example["gold_file_contents"] = json.dumps(contents) + return example + + def _add_file_stats_to_example(self, example: Dict, repo_paths: Dict[str, Path]) -> Dict: + example["all_file_stats"] = json.dumps({}) + example["_invalid_example"] = True + + repo = example.get("repo") + commit = example.get("base_commit") + patch = example.get("patch") + repo_path = repo_paths.get(repo) if repo else None + + if not (repo_path and commit and patch): + return example + + gold_files = self._parse_patch(patch) + if any(not self._is_source_file(f) for f in gold_files): + return example + + all_stats = self._get_all_file_stats(repo_path, commit) + example["all_file_stats"] = json.dumps(all_stats) + example["_invalid_example"] = any(f not in all_stats for f in gold_files) + return example + + # ── Main entry point ───────────────────────────────────────────────────── + + def process(self) -> Set[str]: + if self.dataset_path.exists() and not self.cfg.force_reprocess: + logger.info("Found existing dataset at %s, loading from disk", self.dataset_path) + try: + ds = load_from_disk(str(self.dataset_path)) + return set(ds["repo"]) + except Exception as e: + logger.warning("Could not load existing dataset (%s), reprocessing", e) + + logger.info("Loading %s (split=%s) from Hub", self.cfg.hf_dataset_name, self.cfg.hf_split_name) + dataset = load_dataset(self.cfg.hf_dataset_name, split=self.cfg.hf_split_name) + logger.info("Loaded %d examples", len(dataset)) + + unique_repos = set(dataset["repo"]) + repo_paths = self.repo_manager.clone_or_update_repos(list(unique_repos)) + + # Pass 1: extract gold file contents (fast) + dataset = dataset.map( + lambda ex: self._extract_gold_file_contents_only(ex, repo_paths), + batched=False, num_proc=self.num_map_processes, + load_from_cache_file=False, desc="Extracting gold file contents", + ) + + # Token-count filtering (before expensive file stats) + if self.tokenizer: + before = len(dataset) + dataset = dataset.filter(self._filter_by_token_count, desc="Token-count filtering") + logger.info("Token filtering: %d → %d examples", before, len(dataset)) + + # Pass 2: compute file stats (expensive, only for surviving examples) + dataset = dataset.map( + lambda ex: self._add_file_stats_to_example(ex, repo_paths), + batched=False, num_proc=self.num_map_processes, + load_from_cache_file=False, desc="Computing file stats", + ) + + before = len(dataset) + dataset = dataset.filter(lambda ex: not ex.get("_invalid_example", True), desc="Filtering invalid examples") + dataset = dataset.remove_columns(["_invalid_example"]) + logger.info("Validity filtering: %d → %d examples", before, len(dataset)) + + logger.info("Saving to %s", self.dataset_path) + dataset.save_to_disk(str(self.dataset_path)) + logger.info("Done. %d examples saved.", len(dataset)) + return unique_repos + + +@hydra.main(config_path="../../../../conf", config_name="swe/preprocess", version_base="1.3.2") +def main(cfg: DictConfig): + SwePreprocessor(cfg).process() + + +if __name__ == "__main__": + main() diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 2ba28c5e..e61dc730 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -186,7 +186,6 @@ def rl_step( model_inputs = { "input_ids": batch.input_ids, "attention_mask": batch.attention_mask, - "labels": batch.labels, } if batch.is_packed: model_inputs["position_ids"] = batch.position_ids diff --git a/pipelinerl/finetune/value_model.py b/pipelinerl/finetune/value_model.py index 7c8f1136..a14214f7 100644 --- a/pipelinerl/finetune/value_model.py +++ b/pipelinerl/finetune/value_model.py @@ -15,8 +15,6 @@ class CausalLMOutputWithValue(ModelOutput): Output type for causal language models with an additional value head. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*): - Language modeling loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`): Prediction scores of the language modeling head. value (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): @@ -29,7 +27,6 @@ class CausalLMOutputWithValue(ModelOutput): Attention weights after the attention softmax. """ - loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None value: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -76,7 +73,6 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -93,7 +89,6 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=True, @@ -107,7 +102,6 @@ def forward( values = self.value_head(hidden_states) return CausalLMOutputWithValue( - loss=outputs.loss, logits=outputs.logits, value=values, past_key_values=outputs.past_key_values, @@ -120,7 +114,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): self.pretrained_model.gradient_checkpointing_enable( gradient_checkpointing_kwargs ) - + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -132,14 +126,14 @@ def save_pretrained( ): """Save model and value head separately.""" import os - + if state_dict is None: state_dict = self.state_dict() - + # Extract pretrained model and value head state dicts pretrained_model_state_dict = {} value_head_state_dict = {} - + for key, value in state_dict.items(): if key.startswith("value_head."): # Remove the "value_head." prefix @@ -154,7 +148,7 @@ def save_pretrained( f"Unexpected key in state dict: {key}. " "Expected keys should start with 'value_head.' or 'pretrained_model.'." ) - + # Save the pretrained model which can be easily loaded by vllm, etc. self.pretrained_model.save_pretrained( save_directory, @@ -164,7 +158,7 @@ def save_pretrained( safe_serialization=safe_serialization, **kwargs, ) - + # Save value head separately if is_main_process: value_head_path = os.path.join(save_directory, "value_head.pt") diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py index 632b760e..bd49ba04 100644 --- a/pipelinerl/streams.py +++ b/pipelinerl/streams.py @@ -110,7 +110,7 @@ def connect_to_redis(config: RedisConfig): logger.debug(f"Trying to connect to Redis server at {config.host}:{config.port}") client = redis.Redis(host=config.host, port=config.port) client.ping() - logger.info(f"Connected to Redis server") + logger.info("Connected to Redis server") return client except (redis.exceptions.TimeoutError, redis.ConnectionError) as e: logger.warning(f"Waiting for Redis server ({type(e)}). Retrying in 5 seconds.") diff --git a/pyproject.toml b/pyproject.toml index 81ecbd42..568fc65f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,9 @@ coding = [ fn_calling = [ "bfcl-eval>=2025.6.8", ] +swe = [ + "unidiff>=0.7.5", +] logic = [ # i3-logic verification code is vendored in pipelinerl/domains/logic/i3_logic/ # (source: https://github.com/PrimeIntellect/i3-logic)