From 7c8a758cc430501c3f068b53a55bc37fc159dce1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:17:10 +0000 Subject: [PATCH 1/2] Initial plan From b1b9cd565836156d44fbadc2855c5e925e633e68 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 13:36:43 +0000 Subject: [PATCH 2/2] Implement dataset split methodology: SplitMetadata, benchmark split enforcement, CLI split commands Agent-Logs-Url: https://github.com/VoDaiLocz/Lock-ART./sessions/8c1bc471-0b09-4dd2-9406-3208cb4888c3 Co-authored-by: VoDaiLocz <88762074+VoDaiLocz@users.noreply.github.com> --- src/auralock/benchmarks/__init__.py | 18 + src/auralock/benchmarks/antidreambooth.py | 4 +- src/auralock/benchmarks/splits.py | 27 + src/auralock/cli.py | 184 ++++++ src/auralock/core/splits.py | 268 ++++++++ src/auralock/services/protection.py | 77 ++- src/tests/test_splits.py | 770 ++++++++++++++++++++++ 7 files changed, 1345 insertions(+), 3 deletions(-) create mode 100644 src/auralock/benchmarks/splits.py create mode 100644 src/auralock/core/splits.py create mode 100644 src/tests/test_splits.py diff --git a/src/auralock/benchmarks/__init__.py b/src/auralock/benchmarks/__init__.py index 1a88588..88839a6 100644 --- a/src/auralock/benchmarks/__init__.py +++ b/src/auralock/benchmarks/__init__.py @@ -28,8 +28,26 @@ build_lora_train_command, evaluate_lora_preflight, ) +from auralock.benchmarks.splits import ( + SplitMetadata, + SplitType, + compute_split_hash, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_split_manifest, + warn_non_test_split, +) __all__ = [ + "SplitMetadata", + "SplitType", + "compute_split_hash", + "create_random_split", + "load_split_manifest", + "save_split_manifest", + "validate_split_manifest", + "warn_non_test_split", "DEFAULT_ANTI_DREAMBOOTH_CLASS_PROMPT", "DEFAULT_ANTI_DREAMBOOTH_INFER_SCRIPT", "DEFAULT_ANTI_DREAMBOOTH_INSTANCE_PROMPT", diff --git a/src/auralock/benchmarks/antidreambooth.py b/src/auralock/benchmarks/antidreambooth.py index 42eae32..e8a7c61 100644 --- a/src/auralock/benchmarks/antidreambooth.py +++ b/src/auralock/benchmarks/antidreambooth.py @@ -315,7 +315,9 @@ def run( notes = [ "This benchmark follows the Anti-DreamBooth paper-style set_A/set_B/set_C split.", - "set_A is retained as a clean reference split, set_B is treated as the published split, and set_C is preserved as holdout metadata.", + "set_A is retained as a clean reference split, set_B is treated as the published (training) split, " + "and set_C is the held-out validation split used to measure out-of-sample protection effectiveness.", + "Evaluate mimicry success on set_C (holdout) images—never on set_B—to avoid in-sample bias.", "AuraLock still uses its own protection pipeline; this workflow is a benchmark alignment layer, not an ASPL/FSMG reproduction.", ] diff --git a/src/auralock/benchmarks/splits.py b/src/auralock/benchmarks/splits.py new file mode 100644 index 0000000..03fc44f --- /dev/null +++ b/src/auralock/benchmarks/splits.py @@ -0,0 +1,27 @@ +"""Dataset split methodology utilities for reproducible benchmark evaluation. + +This module re-exports from :mod:`auralock.core.splits` for convenience. +The canonical implementation lives in the core package to avoid circular imports. +""" + +from auralock.core.splits import ( + SplitMetadata, + SplitType, + compute_split_hash, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_split_manifest, + warn_non_test_split, +) + +__all__ = [ + "SplitMetadata", + "SplitType", + "compute_split_hash", + "create_random_split", + "load_split_manifest", + "save_split_manifest", + "validate_split_manifest", + "warn_non_test_split", +] diff --git a/src/auralock/cli.py b/src/auralock/cli.py index 0ccd520..f401f1f 100644 --- a/src/auralock/cli.py +++ b/src/auralock/cli.py @@ -26,7 +26,13 @@ AntiDreamBoothSubjectBenchmarkHarness, DockerLoraBenchmarkConfig, LoraBenchmarkHarness, + SplitMetadata, + SplitType, build_docker_lora_benchmark_plan, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_split_manifest, ) from auralock.core.image import save_image from auralock.services import ( @@ -42,6 +48,13 @@ ) console = Console() +split_app = typer.Typer( + name="split", + help="Dataset split management for reproducible benchmark evaluation.", + add_completion=False, +) +app.add_typer(split_app, name="split") + def _to_builtin(value: Any) -> Any: """Convert report payloads into JSON-friendly values.""" @@ -634,6 +647,18 @@ def benchmark( recursive: bool = typer.Option( False, "--recursive", help="Benchmark nested directories recursively" ), + split_manifest: Path | None = typer.Option( + None, + "--split-manifest", + help="Path to a JSON split manifest (from 'auralock split create'). " + "When provided, only images in the declared split are benchmarked.", + ), + split_type: str = typer.Option( + "test", + "--split-type", + help="Split to use from the manifest: train, val, test, or dev. " + "Using a non-test split emits a bias warning.", + ), report: Path | None = typer.Option( None, "--report", @@ -652,6 +677,42 @@ def benchmark( console.print("[red]Error:[/red] At least one profile is required.") raise typer.Exit(1) + resolved_split_metadata: SplitMetadata | None = None + if split_manifest is not None: + try: + all_splits = load_split_manifest(split_manifest) + except FileNotFoundError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc + except (json.JSONDecodeError, KeyError, ValueError) as exc: + console.print(f"[red]Error:[/red] Could not parse split manifest: {exc}") + raise typer.Exit(1) from exc + + try: + requested_type = SplitType(split_type) + except ValueError as exc: + valid = ", ".join(t.value for t in SplitType) + console.print( + f"[red]Error:[/red] Invalid --split-type '{split_type}'. " + f"Valid values: {valid}" + ) + raise typer.Exit(1) from exc + + if requested_type not in all_splits: + console.print( + f"[red]Error:[/red] Split type '{split_type}' not found in manifest. " + f"Available: {', '.join(k.value for k in all_splits)}" + ) + raise typer.Exit(1) + + resolved_split_metadata = all_splits[requested_type] + if requested_type != SplitType.TEST: + console.print( + f"[yellow]⚠ WARNING:[/yellow] Benchmarking on '{split_type}' split. " + "Results may be overfit to this split. " + "Use --split-type test for final evaluation." + ) + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -668,11 +729,13 @@ def benchmark( input_path, profiles=profile_names, recursive=recursive, + split_metadata=resolved_split_metadata, ) else: summary = service.benchmark_file( input_path, profiles=profile_names, + split_metadata=resolved_split_metadata, ) except ValueError as exc: console.print(f"[red]Error:[/red] {exc}") @@ -681,6 +744,13 @@ def benchmark( console.print(_render_profile_summary_table(summary.profile_summaries)) + if resolved_split_metadata is not None: + console.print( + f"[dim]Split: {resolved_split_metadata.split_type.value} | " + f"Images: {len(resolved_split_metadata.image_ids)} | " + f"Hash: {resolved_split_metadata.split_hash}[/dim]" + ) + # Print prominent warning about protection scores in benchmark summary from auralock.core.metrics import ( PROTECTION_SCORE_DISCLAIMER, @@ -1164,6 +1234,120 @@ def benchmark_lora_docker( raise typer.Exit(exc.returncode or 1) from exc +@split_app.command("create") +def split_create( + dataset_dir: Path = typer.Argument( + ..., help="Directory containing images to split" + ), + output: Path = typer.Option( + ..., "--output", help="Path to write the JSON split manifest" + ), + dataset_name: str = typer.Option( + "dataset", "--dataset-name", help="Human-readable dataset name" + ), + dataset_version: str = typer.Option( + "1.0", "--dataset-version", help="Dataset version string" + ), + train_ratio: float = typer.Option( + 0.7, "--train-ratio", help="Fraction of images for the train split" + ), + val_ratio: float = typer.Option( + 0.15, "--val-ratio", help="Fraction of images for the validation split" + ), + test_ratio: float = typer.Option( + 0.15, "--test-ratio", help="Fraction of images for the test split" + ), + seed: int = typer.Option(42, "--seed", help="Random seed for reproducibility"), + recursive: bool = typer.Option( + False, "--recursive", help="Scan sub-directories for images" + ), +) -> None: + """Create a train/val/test split manifest from a dataset directory.""" + from auralock.core.image import SUPPORTED_EXTENSIONS + + if not dataset_dir.exists() or not dataset_dir.is_dir(): + console.print(f"[red]Error:[/red] Dataset directory not found: {dataset_dir}") + raise typer.Exit(1) + + iterator = dataset_dir.rglob("*") if recursive else dataset_dir.glob("*") + image_paths = [ + p + for p in sorted(iterator) + if p.is_file() and p.suffix.lower() in SUPPORTED_EXTENSIONS + ] + if not image_paths: + console.print(f"[red]Error:[/red] No supported images found in: {dataset_dir}") + raise typer.Exit(1) + + try: + splits = create_random_split( + image_paths, + dataset_name=dataset_name, + dataset_version=dataset_version, + train_ratio=train_ratio, + val_ratio=val_ratio, + test_ratio=test_ratio, + random_seed=seed, + ) + except ValueError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc + + save_split_manifest(splits, output) + + table = Table(title="Split Manifest Summary") + table.add_column("Split", style="cyan") + table.add_column("Images", style="green") + table.add_column("Ratio", style="yellow") + table.add_column("Hash", style="magenta") + for split_type, split_meta in splits.items(): + table.add_row( + split_type.value, + str(len(split_meta.image_ids)), + f"{split_meta.split_ratio.get(split_type.value, 0):.2f}", + split_meta.split_hash, + ) + console.print(table) + console.print(f"[green]Split manifest saved to:[/green] {output}") + + +@split_app.command("validate") +def split_validate( + manifest: Path = typer.Argument(..., help="Path to the JSON split manifest"), +) -> None: + """Validate a split manifest for data leakage and hash integrity.""" + try: + splits = load_split_manifest(manifest) + except FileNotFoundError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc + except (json.JSONDecodeError, KeyError, ValueError) as exc: + console.print(f"[red]Error:[/red] Could not parse split manifest: {exc}") + raise typer.Exit(1) from exc + + issues = validate_split_manifest(splits) + + table = Table(title="Split Validation") + table.add_column("Split", style="cyan") + table.add_column("Images", style="green") + table.add_column("Hash", style="magenta") + for split_type, split_meta in splits.items(): + table.add_row( + split_type.value, + str(len(split_meta.image_ids)), + split_meta.split_hash, + ) + console.print(table) + + if issues: + for issue in issues: + console.print(f"[red]Issue:[/red] {issue}") + console.print(f"[red]{len(issues)} validation issue(s) found.[/red]") + raise typer.Exit(1) + else: + console.print("[green]✓ No issues found. Split manifest is valid.[/green]") + + @app.command() def demo() -> None: """Run a quick demo using a synthetic image.""" diff --git a/src/auralock/core/splits.py b/src/auralock/core/splits.py new file mode 100644 index 0000000..4781f41 --- /dev/null +++ b/src/auralock/core/splits.py @@ -0,0 +1,268 @@ +"""Dataset split methodology utilities for reproducible benchmark evaluation.""" + +from __future__ import annotations + +import hashlib +import json +import random +import warnings +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + + +class SplitType(Enum): + """Dataset split role used in reproducible benchmark evaluation.""" + + TRAIN = "train" + VALIDATION = "val" + TEST = "test" + DEVELOPMENT = "dev" + + +@dataclass +class SplitMetadata: + """Metadata describing one split of a dataset. + + Tracks image membership, split method, and a deterministic hash so that + split assignments can be audited and reproduced across benchmark runs. + """ + + split_type: SplitType + dataset_name: str + dataset_version: str + split_hash: str + image_ids: list[str] + split_method: str + split_ratio: dict[str, float] + random_seed: int | None = None + + def verify_no_leakage(self, other: SplitMetadata) -> bool: + """Return True when no images are shared between this split and *other*.""" + return set(self.image_ids).isdisjoint(set(other.image_ids)) + + def to_dict(self) -> dict[str, object]: + """Serialize to a JSON-friendly dictionary.""" + return { + "split_type": self.split_type.value, + "dataset_name": self.dataset_name, + "dataset_version": self.dataset_version, + "split_hash": self.split_hash, + "image_ids": self.image_ids, + "split_method": self.split_method, + "split_ratio": self.split_ratio, + "random_seed": self.random_seed, + } + + @classmethod + def from_dict(cls, data: dict[str, object]) -> SplitMetadata: + """Reconstruct from a serialized dictionary.""" + return cls( + split_type=SplitType(data["split_type"]), + dataset_name=str(data["dataset_name"]), + dataset_version=str(data["dataset_version"]), + split_hash=str(data["split_hash"]), + image_ids=list(data["image_ids"]), + split_method=str(data["split_method"]), + split_ratio=dict(data["split_ratio"]), + random_seed=data.get("random_seed"), + ) + + +def compute_split_hash( + image_ids: list[str], + split_type: SplitType, + seed: int | None, +) -> str: + """Return a short deterministic SHA-256 hash for a split assignment.""" + content = json.dumps( + { + "split_type": split_type.value, + "image_ids": sorted(image_ids), + "seed": seed, + }, + sort_keys=True, + ) + return hashlib.sha256(content.encode()).hexdigest()[:16] + + +def create_random_split( + image_paths: list[Path], + *, + dataset_name: str = "dataset", + dataset_version: str = "1.0", + train_ratio: float = 0.7, + val_ratio: float = 0.15, + test_ratio: float = 0.15, + random_seed: int = 42, +) -> dict[SplitType, SplitMetadata]: + """Create deterministic random train/val/test splits with full metadata. + + Parameters + ---------- + image_paths: + Paths to all images in the dataset. + dataset_name: + Human-readable name for the dataset (stored in metadata). + dataset_version: + Version string for the dataset (stored in metadata). + train_ratio, val_ratio, test_ratio: + Fractions for each split; must sum to 1.0. + random_seed: + Seed for reproducible shuffling. + + Returns + ------- + dict mapping each :class:`SplitType` to its :class:`SplitMetadata`. + """ + if not image_paths: + raise ValueError("image_paths must not be empty.") + + total = train_ratio + val_ratio + test_ratio + if abs(total - 1.0) > 1e-6: + raise ValueError( + f"Split ratios must sum to 1.0, got {train_ratio} + {val_ratio} + {test_ratio} = {total:.6f}." + ) + + rng = random.Random(random_seed) + shuffled = rng.sample(image_paths, len(image_paths)) + n = len(shuffled) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + partitions: list[tuple[SplitType, list[Path]]] = [ + (SplitType.TRAIN, shuffled[:n_train]), + (SplitType.VALIDATION, shuffled[n_train : n_train + n_val]), + (SplitType.TEST, shuffled[n_train + n_val :]), + ] + + split_ratio = { + SplitType.TRAIN.value: train_ratio, + SplitType.VALIDATION.value: val_ratio, + SplitType.TEST.value: test_ratio, + } + + splits: dict[SplitType, SplitMetadata] = {} + for split_type, images in partitions: + ids = [str(p) for p in images] + splits[split_type] = SplitMetadata( + split_type=split_type, + dataset_name=dataset_name, + dataset_version=dataset_version, + split_hash=compute_split_hash(ids, split_type, random_seed), + image_ids=ids, + split_method="random", + split_ratio=split_ratio, + random_seed=random_seed, + ) + + return splits + + +def save_split_manifest( + splits: dict[SplitType, SplitMetadata], + output_path: Path, +) -> None: + """Persist split assignments to a JSON manifest for reproducibility. + + Parameters + ---------- + splits: + Mapping returned by :func:`create_random_split` (or built manually). + output_path: + Destination path for the JSON manifest. + """ + manifest = { + split_type.value: split_meta.to_dict() + for split_type, split_meta in splits.items() + } + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + json.dumps(manifest, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + + +def load_split_manifest(manifest_path: Path) -> dict[SplitType, SplitMetadata]: + """Load a previously saved split manifest. + + Parameters + ---------- + manifest_path: + Path to a JSON manifest written by :func:`save_split_manifest`. + + Returns + ------- + Mapping of :class:`SplitType` to :class:`SplitMetadata`. + + Raises + ------ + FileNotFoundError + If *manifest_path* does not exist. + """ + manifest_path = Path(manifest_path) + if not manifest_path.exists(): + raise FileNotFoundError(f"Split manifest not found: {manifest_path}") + data = json.loads(manifest_path.read_text(encoding="utf-8")) + return { + SplitType(key): SplitMetadata.from_dict(value) for key, value in data.items() + } + + +def validate_split_manifest( + splits: dict[SplitType, SplitMetadata], +) -> list[str]: + """Validate splits for data leakage and hash integrity. + + Parameters + ---------- + splits: + Mapping of split types to their metadata (e.g. from + :func:`load_split_manifest` or :func:`create_random_split`). + + Returns + ------- + A list of human-readable issue descriptions. An empty list means no + problems were detected. + """ + issues: list[str] = [] + split_list = list(splits.values()) + + for i, split_a in enumerate(split_list): + for split_b in split_list[i + 1 :]: + if not split_a.verify_no_leakage(split_b): + overlap = set(split_a.image_ids) & set(split_b.image_ids) + issues.append( + f"Data leakage detected between '{split_a.split_type.value}' and " + f"'{split_b.split_type.value}' splits: " + f"{len(overlap)} shared image(s)." + ) + + for split_meta in splits.values(): + computed = compute_split_hash( + split_meta.image_ids, split_meta.split_type, split_meta.random_seed + ) + if computed != split_meta.split_hash: + issues.append( + f"Hash mismatch for '{split_meta.split_type.value}' split. " + "Split assignment may have been modified after creation." + ) + + return issues + + +def warn_non_test_split(split_type: SplitType) -> None: + """Emit a :class:`UserWarning` when benchmarking on a non-test split. + + The warning is attributed to the direct caller of this function + (``stacklevel=2``), which is typically a benchmark method. + """ + if split_type != SplitType.TEST: + warnings.warn( + f"Benchmarking on '{split_type.value}' split. " + "Results may be overfit to this split. " + "Use the TEST split for final evaluation to avoid bias.", + UserWarning, + stacklevel=2, + ) diff --git a/src/auralock/services/protection.py b/src/auralock/services/protection.py index 73a1e18..23596dd 100644 --- a/src/auralock/services/protection.py +++ b/src/auralock/services/protection.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from dataclasses import dataclass from pathlib import Path from time import perf_counter @@ -27,6 +28,7 @@ ) from auralock.core.pipeline import load_default_model, resolve_device from auralock.core.profiles import ProtectionConfig, resolve_protection_config +from auralock.core.splits import SplitMetadata, warn_non_test_split from auralock.core.style import load_default_style_feature_extractor @@ -200,6 +202,7 @@ class BenchmarkSummary: image_count: int entries: list[BenchmarkEntry] profile_summaries: dict[str, dict[str, object]] + split_metadata: SplitMetadata | None = None def to_report_dict(self) -> dict[str, object]: """Serialize the full benchmark report.""" @@ -209,6 +212,11 @@ def to_report_dict(self) -> dict[str, object]: "image_count": self.image_count, "entries": [entry.to_report_dict() for entry in self.entries], "profile_summaries": self.profile_summaries, + "split_metadata": ( + self.split_metadata.to_dict() + if self.split_metadata is not None + else None + ), "validation_metadata": { "is_validated": False, "validation_status": "not_validated", @@ -718,6 +726,7 @@ def _collect_benchmark_entries( *, input_path: Path, profiles: tuple[str, ...], + split_metadata: SplitMetadata | None = None, ) -> BenchmarkSummary: """Benchmark the requested profiles on a list of images.""" if not image_paths: @@ -774,6 +783,7 @@ def _collect_benchmark_entries( image_count=len(image_paths), entries=entries, profile_summaries=profile_summaries, + split_metadata=split_metadata, ) def benchmark_file( @@ -781,17 +791,45 @@ def benchmark_file( input_path: str | Path, *, profiles: tuple[str, ...] = ("safe", "balanced", "strong"), + split_metadata: SplitMetadata | None = None, ) -> BenchmarkSummary: - """Benchmark one image against multiple named profiles.""" + """Benchmark one image against multiple named profiles. + + Parameters + ---------- + input_path: + Path to the image to benchmark. + profiles: + Named protection profiles to evaluate. + split_metadata: + Optional split context. When provided the image is validated against + the declared split and a :class:`UserWarning` is raised when the + split type is not :attr:`SplitType.TEST`. + """ candidate = Path(input_path) if not candidate.exists() or not candidate.is_file(): raise ValueError("input_path must be an existing image file.") if candidate.suffix.lower() not in SUPPORTED_EXTENSIONS: raise ValueError("input_path must point to a supported image file.") + + if split_metadata is not None: + split_ids = set(split_metadata.image_ids) + if ( + str(candidate) not in split_ids + and str(candidate.resolve()) not in split_ids + ): + raise ValueError( + f"Image '{candidate}' is not in the declared " + f"'{split_metadata.split_type.value}' split. " + "Benchmarking an image outside its split may cause data leakage." + ) + warn_non_test_split(split_metadata.split_type) + return self._collect_benchmark_entries( [candidate], input_path=candidate, profiles=profiles, + split_metadata=split_metadata, ) def benchmark_directory( @@ -800,8 +838,23 @@ def benchmark_directory( *, profiles: tuple[str, ...] = ("safe", "balanced", "strong"), recursive: bool = False, + split_metadata: SplitMetadata | None = None, ) -> BenchmarkSummary: - """Benchmark all supported images in a directory across profiles.""" + """Benchmark all supported images in a directory across profiles. + + Parameters + ---------- + input_dir: + Directory containing images to benchmark. + profiles: + Named protection profiles to evaluate. + recursive: + When True, scan sub-directories as well. + split_metadata: + Optional split context. When provided only images whose paths are + listed in the split are benchmarked, and a :class:`UserWarning` is + raised when the split type is not :attr:`SplitType.TEST`. + """ input_path = Path(input_dir) if not input_path.exists() or not input_path.is_dir(): raise ValueError("input_dir must be an existing directory.") @@ -812,10 +865,30 @@ def benchmark_directory( for candidate in sorted(iterator) if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_EXTENSIONS ] + + if split_metadata is not None: + split_ids = set(split_metadata.image_ids) + filtered = [ + p + for p in image_paths + if str(p) in split_ids or str(p.resolve()) in split_ids + ] + excluded = len(image_paths) - len(filtered) + if excluded > 0: + warnings.warn( + f"{excluded} image(s) in '{input_path}' were excluded because " + f"they are not part of the declared '{split_metadata.split_type.value}' split.", + UserWarning, + stacklevel=2, + ) + image_paths = filtered + warn_non_test_split(split_metadata.split_type) + return self._collect_benchmark_entries( image_paths, input_path=input_path, profiles=profiles, + split_metadata=split_metadata, ) def protect_directory( diff --git a/src/tests/test_splits.py b/src/tests/test_splits.py new file mode 100644 index 0000000..40ecd2f --- /dev/null +++ b/src/tests/test_splits.py @@ -0,0 +1,770 @@ +"""Tests for dataset split methodology utilities.""" + +from __future__ import annotations + +import json +import warnings +from pathlib import Path + +import pytest +from typer.testing import CliRunner + +from auralock.benchmarks.splits import ( + SplitMetadata, + SplitType, + compute_split_hash, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_split_manifest, + warn_non_test_split, +) +from auralock.cli import app + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_images(base: Path, names: list[str]) -> list[Path]: + """Create empty stub files at *base* and return their paths.""" + base.mkdir(parents=True, exist_ok=True) + paths = [] + for name in names: + p = base / name + p.write_bytes(b"stub") + paths.append(p) + return paths + + +# --------------------------------------------------------------------------- +# SplitType +# --------------------------------------------------------------------------- + + +def test_split_type_values(): + assert SplitType.TRAIN.value == "train" + assert SplitType.VALIDATION.value == "val" + assert SplitType.TEST.value == "test" + assert SplitType.DEVELOPMENT.value == "dev" + + +# --------------------------------------------------------------------------- +# SplitMetadata +# --------------------------------------------------------------------------- + + +def test_split_metadata_verify_no_leakage_disjoint(): + a = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash="abc", + image_ids=["a.png", "b.png"], + split_method="random", + split_ratio={"train": 0.7}, + ) + b = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash="def", + image_ids=["c.png", "d.png"], + split_method="random", + split_ratio={"test": 0.15}, + ) + assert a.verify_no_leakage(b) is True + + +def test_split_metadata_verify_no_leakage_overlap(): + a = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash="abc", + image_ids=["a.png", "b.png"], + split_method="random", + split_ratio={"train": 0.7}, + ) + b = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash="def", + image_ids=["b.png", "c.png"], + split_method="random", + split_ratio={"test": 0.15}, + ) + assert a.verify_no_leakage(b) is False + + +def test_split_metadata_to_dict_and_from_dict(): + meta = SplitMetadata( + split_type=SplitType.VALIDATION, + dataset_name="mydata", + dataset_version="2.0", + split_hash="deadbeef", + image_ids=["x.png"], + split_method="random", + split_ratio={"val": 0.15}, + random_seed=99, + ) + roundtripped = SplitMetadata.from_dict(meta.to_dict()) + assert roundtripped.split_type == SplitType.VALIDATION + assert roundtripped.dataset_name == "mydata" + assert roundtripped.dataset_version == "2.0" + assert roundtripped.split_hash == "deadbeef" + assert roundtripped.image_ids == ["x.png"] + assert roundtripped.split_method == "random" + assert roundtripped.split_ratio == {"val": 0.15} + assert roundtripped.random_seed == 99 + + +# --------------------------------------------------------------------------- +# create_random_split +# --------------------------------------------------------------------------- + + +def test_create_random_split_basic(tmp_path: Path): + images = _make_images(tmp_path, [f"{i}.png" for i in range(10)]) + splits = create_random_split(images, random_seed=0) + + assert set(splits) == {SplitType.TRAIN, SplitType.VALIDATION, SplitType.TEST} + total = sum(len(s.image_ids) for s in splits.values()) + assert total == 10 + + train_meta = splits[SplitType.TRAIN] + assert train_meta.split_type == SplitType.TRAIN + assert train_meta.split_method == "random" + assert train_meta.random_seed == 0 + + +def test_create_random_split_ratios_must_sum_to_one(tmp_path: Path): + images = _make_images(tmp_path, ["a.png", "b.png"]) + with pytest.raises(ValueError, match="sum to 1.0"): + create_random_split(images, train_ratio=0.5, val_ratio=0.3, test_ratio=0.3) + + +def test_create_random_split_empty_raises(tmp_path: Path): + with pytest.raises(ValueError, match="must not be empty"): + create_random_split([]) + + +def test_create_random_split_no_leakage_across_splits(tmp_path: Path): + images = _make_images(tmp_path, [f"{i}.png" for i in range(20)]) + splits = create_random_split(images) + split_list = list(splits.values()) + for i, a in enumerate(split_list): + for b in split_list[i + 1 :]: + assert a.verify_no_leakage( + b + ), f"Leakage between {a.split_type.value} and {b.split_type.value}" + + +def test_create_random_split_deterministic(tmp_path: Path): + images = _make_images(tmp_path, [f"{i}.png" for i in range(10)]) + splits1 = create_random_split(images, random_seed=42) + splits2 = create_random_split(images, random_seed=42) + assert splits1[SplitType.TEST].image_ids == splits2[SplitType.TEST].image_ids + + +def test_create_random_split_hash_stored_in_metadata(tmp_path: Path): + images = _make_images(tmp_path, [f"{i}.png" for i in range(6)]) + splits = create_random_split(images, random_seed=7) + for split_type, meta in splits.items(): + expected = compute_split_hash(meta.image_ids, split_type, 7) + assert meta.split_hash == expected + + +# --------------------------------------------------------------------------- +# save_split_manifest / load_split_manifest +# --------------------------------------------------------------------------- + + +def test_save_and_load_split_manifest_roundtrip(tmp_path: Path): + images = _make_images(tmp_path / "data", [f"{i}.png" for i in range(8)]) + splits = create_random_split(images, random_seed=1) + + manifest_path = tmp_path / "splits.json" + save_split_manifest(splits, manifest_path) + + assert manifest_path.exists() + loaded = load_split_manifest(manifest_path) + + assert set(loaded) == {SplitType.TRAIN, SplitType.VALIDATION, SplitType.TEST} + for split_type in splits: + assert ( + splits[split_type].image_ids == loaded[split_type].image_ids + ), f"image_ids mismatch for {split_type}" + assert splits[split_type].split_hash == loaded[split_type].split_hash + + +def test_load_split_manifest_missing_file_raises(tmp_path: Path): + with pytest.raises(FileNotFoundError): + load_split_manifest(tmp_path / "nonexistent.json") + + +def test_save_split_manifest_creates_parent_dirs(tmp_path: Path): + images = _make_images(tmp_path, ["a.png"]) + splits = create_random_split(images, train_ratio=1.0, val_ratio=0.0, test_ratio=0.0) + output = tmp_path / "deep" / "nested" / "splits.json" + save_split_manifest(splits, output) + assert output.exists() + + +# --------------------------------------------------------------------------- +# validate_split_manifest +# --------------------------------------------------------------------------- + + +def test_validate_split_manifest_clean(tmp_path: Path): + images = _make_images(tmp_path, [f"{i}.png" for i in range(10)]) + splits = create_random_split(images) + issues = validate_split_manifest(splits) + assert issues == [] + + +def test_validate_split_manifest_detects_leakage(): + shared_id = "shared.png" + meta_train = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([shared_id, "a.png"], SplitType.TRAIN, None), + image_ids=[shared_id, "a.png"], + split_method="manual", + split_ratio={"train": 0.7}, + ) + meta_test = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([shared_id, "b.png"], SplitType.TEST, None), + image_ids=[shared_id, "b.png"], + split_method="manual", + split_ratio={"test": 0.15}, + ) + issues = validate_split_manifest( + {SplitType.TRAIN: meta_train, SplitType.TEST: meta_test} + ) + assert any("leakage" in issue.lower() for issue in issues) + + +def test_validate_split_manifest_detects_hash_mismatch(): + meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash="tampered_hash", + image_ids=["a.png", "b.png"], + split_method="random", + split_ratio={"test": 0.15}, + random_seed=42, + ) + issues = validate_split_manifest({SplitType.TEST: meta}) + assert any("hash mismatch" in issue.lower() for issue in issues) + + +# --------------------------------------------------------------------------- +# warn_non_test_split +# --------------------------------------------------------------------------- + + +def test_warn_non_test_split_emits_warning_for_train(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + warn_non_test_split(SplitType.TRAIN) + assert len(caught) == 1 + assert issubclass(caught[0].category, UserWarning) + assert "train" in str(caught[0].message).lower() + + +def test_warn_non_test_split_silent_for_test(): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + warn_non_test_split(SplitType.TEST) + assert len(caught) == 0 + + +# --------------------------------------------------------------------------- +# ProtectionService.benchmark_file / benchmark_directory with split_metadata +# --------------------------------------------------------------------------- + + +def test_benchmark_file_with_test_split_metadata(tmp_path: Path): + """benchmark_file should succeed silently when image is in the TEST split.""" + from PIL import Image as PILImage + + from auralock.core.pipeline import ImageNetModelAdapter + from auralock.services import ProtectionService + + from .test_pipeline import RecordingClassifier + from .test_stylecloak import DummyStyleFeatureExtractor + + img_path = tmp_path / "art.png" + PILImage.new("RGB", (32, 32), color="blue").save(img_path) + + # Manually build a TEST split with this image + test_meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([str(img_path)], SplitType.TEST, None), + image_ids=[str(img_path)], + split_method="manual", + split_ratio={"test": 1.0}, + ) + + service = ProtectionService( + model=ImageNetModelAdapter(RecordingClassifier()), + style_feature_extractor=DummyStyleFeatureExtractor(), + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + summary = service.benchmark_file( + img_path, profiles=("safe",), split_metadata=test_meta + ) + + assert summary.split_metadata is not None + assert summary.split_metadata.split_type == SplitType.TEST + # No warnings should fire for TEST split + assert all(not issubclass(w.category, UserWarning) for w in caught) + + +def test_benchmark_file_warns_on_train_split(tmp_path: Path): + """benchmark_file should emit a UserWarning when using a TRAIN split.""" + from PIL import Image as PILImage + + from auralock.core.pipeline import ImageNetModelAdapter + from auralock.services import ProtectionService + + from .test_pipeline import RecordingClassifier + from .test_stylecloak import DummyStyleFeatureExtractor + + img_path = tmp_path / "art.png" + PILImage.new("RGB", (32, 32), color="red").save(img_path) + + train_meta = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([str(img_path)], SplitType.TRAIN, None), + image_ids=[str(img_path)], + split_method="manual", + split_ratio={"train": 1.0}, + ) + + service = ProtectionService( + model=ImageNetModelAdapter(RecordingClassifier()), + style_feature_extractor=DummyStyleFeatureExtractor(), + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + summary = service.benchmark_file( + img_path, profiles=("safe",), split_metadata=train_meta + ) + + assert summary.split_metadata is not None + user_warnings = [w for w in caught if issubclass(w.category, UserWarning)] + assert len(user_warnings) == 1 + assert "train" in str(user_warnings[0].message).lower() + + +def test_benchmark_file_raises_when_image_not_in_split(tmp_path: Path): + """benchmark_file should raise ValueError if the image is not in the split.""" + from PIL import Image as PILImage + + from auralock.core.pipeline import ImageNetModelAdapter + from auralock.services import ProtectionService + + from .test_pipeline import RecordingClassifier + from .test_stylecloak import DummyStyleFeatureExtractor + + img_path = tmp_path / "art.png" + PILImage.new("RGB", (32, 32), color="green").save(img_path) + + test_meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash="x", + image_ids=["other_image.png"], # does NOT contain img_path + split_method="manual", + split_ratio={"test": 1.0}, + ) + + service = ProtectionService( + model=ImageNetModelAdapter(RecordingClassifier()), + style_feature_extractor=DummyStyleFeatureExtractor(), + ) + + with pytest.raises(ValueError, match="not in the declared"): + service.benchmark_file(img_path, profiles=("safe",), split_metadata=test_meta) + + +def test_benchmark_summary_to_report_dict_includes_split_metadata(tmp_path: Path): + """BenchmarkSummary.to_report_dict should include split_metadata when set.""" + from PIL import Image as PILImage + + from auralock.core.pipeline import ImageNetModelAdapter + from auralock.services import ProtectionService + + from .test_pipeline import RecordingClassifier + from .test_stylecloak import DummyStyleFeatureExtractor + + img_path = tmp_path / "art.png" + PILImage.new("RGB", (32, 32), color="white").save(img_path) + + test_meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([str(img_path)], SplitType.TEST, None), + image_ids=[str(img_path)], + split_method="manual", + split_ratio={"test": 1.0}, + ) + + service = ProtectionService( + model=ImageNetModelAdapter(RecordingClassifier()), + style_feature_extractor=DummyStyleFeatureExtractor(), + ) + summary = service.benchmark_file( + img_path, profiles=("safe",), split_metadata=test_meta + ) + report = summary.to_report_dict() + + assert report["split_metadata"] is not None + assert report["split_metadata"]["split_type"] == "test" + assert report["split_metadata"]["dataset_name"] == "d" + + +# --------------------------------------------------------------------------- +# CLI: split create +# --------------------------------------------------------------------------- + + +def test_split_create_cli_writes_manifest(tmp_path: Path): + """CLI 'split create' should write a valid JSON manifest.""" + from PIL import Image as PILImage + + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + for i in range(10): + PILImage.new("RGB", (8, 8), color="red").save(dataset_dir / f"{i}.png") + + manifest_path = tmp_path / "splits.json" + runner = CliRunner() + result = runner.invoke( + app, + [ + "split", + "create", + str(dataset_dir), + "--output", + str(manifest_path), + "--seed", + "42", + ], + ) + + assert result.exit_code == 0, result.output + assert manifest_path.exists() + data = json.loads(manifest_path.read_text(encoding="utf-8")) + assert "train" in data + assert "val" in data + assert "test" in data + total_images = sum(len(v["image_ids"]) for v in data.values()) + assert total_images == 10 + + +def test_split_create_cli_fails_on_missing_dir(tmp_path: Path): + runner = CliRunner() + result = runner.invoke( + app, + [ + "split", + "create", + str(tmp_path / "nonexistent"), + "--output", + str(tmp_path / "splits.json"), + ], + ) + assert result.exit_code != 0 + assert "not found" in result.output.lower() or "Error" in result.output + + +def test_split_create_cli_fails_on_bad_ratios(tmp_path: Path): + from PIL import Image as PILImage + + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + PILImage.new("RGB", (8, 8)).save(dataset_dir / "a.png") + + runner = CliRunner() + result = runner.invoke( + app, + [ + "split", + "create", + str(dataset_dir), + "--output", + str(tmp_path / "splits.json"), + "--train-ratio", + "0.5", + "--val-ratio", + "0.3", + "--test-ratio", + "0.3", + ], + ) + assert result.exit_code != 0 + + +# --------------------------------------------------------------------------- +# CLI: split validate +# --------------------------------------------------------------------------- + + +def test_split_validate_cli_clean_manifest(tmp_path: Path): + """CLI 'split validate' should exit 0 for a valid manifest.""" + from PIL import Image as PILImage + + dataset_dir = tmp_path / "data" + dataset_dir.mkdir() + images = [] + for i in range(6): + p = dataset_dir / f"{i}.png" + PILImage.new("RGB", (8, 8)).save(p) + images.append(p) + + splits = create_random_split(images) + manifest_path = tmp_path / "splits.json" + save_split_manifest(splits, manifest_path) + + runner = CliRunner() + result = runner.invoke(app, ["split", "validate", str(manifest_path)]) + + assert result.exit_code == 0, result.output + assert "valid" in result.output.lower() or "No issues" in result.output + + +def test_split_validate_cli_fails_on_leakage(tmp_path: Path): + """CLI 'split validate' should exit 1 when leakage is detected.""" + shared = "shared.png" + meta_train = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([shared], SplitType.TRAIN, None), + image_ids=[shared], + split_method="manual", + split_ratio={"train": 0.7}, + ) + meta_test = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([shared], SplitType.TEST, None), + image_ids=[shared], + split_method="manual", + split_ratio={"test": 0.15}, + ) + manifest_path = tmp_path / "bad_splits.json" + save_split_manifest( + {SplitType.TRAIN: meta_train, SplitType.TEST: meta_test}, manifest_path + ) + + runner = CliRunner() + result = runner.invoke(app, ["split", "validate", str(manifest_path)]) + + assert result.exit_code != 0 + assert "leakage" in result.output.lower() + + +def test_split_validate_cli_fails_on_missing_file(tmp_path: Path): + runner = CliRunner() + result = runner.invoke(app, ["split", "validate", str(tmp_path / "missing.json")]) + assert result.exit_code != 0 + + +# --------------------------------------------------------------------------- +# CLI: benchmark with --split-manifest / --split-type +# --------------------------------------------------------------------------- + + +def test_benchmark_cli_with_split_manifest(monkeypatch, tmp_path: Path): + """The benchmark CLI should pass split metadata to the service when a manifest is given.""" + from PIL import Image as PILImage + + captured: dict[str, object] = {} + + class FakeSummary: + profile_summaries = { + "safe": { + "image_count": 1, + "avg_psnr_db": 38.0, + "avg_ssim": 0.95, + "avg_protection_score": 10.0, + "avg_runtime_sec": 0.5, + } + } + split_metadata = None + + def to_report_dict(self): + return { + "input_path": str(tmp_path), + "image_count": 1, + "entries": [], + "profile_summaries": self.profile_summaries, + "split_metadata": None, + } + + class FakeService: + def benchmark_file(self, input_path, **kwargs): + captured["split_metadata"] = kwargs.get("split_metadata") + return FakeSummary() + + def benchmark_directory(self, input_path, **kwargs): + captured["split_metadata"] = kwargs.get("split_metadata") + return FakeSummary() + + monkeypatch.setattr("auralock.cli.ProtectionService", FakeService) + + # Build manifest with the image in the test split + img_path = tmp_path / "img" / "a.png" + img_path.parent.mkdir(parents=True) + PILImage.new("RGB", (8, 8)).save(img_path) + + test_meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([str(img_path)], SplitType.TEST, 42), + image_ids=[str(img_path)], + split_method="manual", + split_ratio={"test": 1.0}, + random_seed=42, + ) + manifest_path = tmp_path / "splits.json" + save_split_manifest({SplitType.TEST: test_meta}, manifest_path) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "benchmark", + str(img_path), + "--profiles", + "safe", + "--split-manifest", + str(manifest_path), + "--split-type", + "test", + ], + ) + + assert result.exit_code == 0, result.output + assert captured.get("split_metadata") is not None + assert captured["split_metadata"].split_type == SplitType.TEST + + +def test_benchmark_cli_warns_on_train_split(monkeypatch, tmp_path: Path): + """The benchmark CLI should print a warning when using a train split.""" + from PIL import Image as PILImage + + class FakeSummary: + profile_summaries = { + "safe": { + "image_count": 1, + "avg_psnr_db": 38.0, + "avg_ssim": 0.95, + "avg_protection_score": 10.0, + "avg_runtime_sec": 0.5, + } + } + split_metadata = None + + def to_report_dict(self): + return { + "input_path": "", + "image_count": 1, + "entries": [], + "profile_summaries": {}, + "split_metadata": None, + } + + class FakeService: + def benchmark_file(self, input_path, **kwargs): + return FakeSummary() + + def benchmark_directory(self, input_path, **kwargs): + return FakeSummary() + + monkeypatch.setattr("auralock.cli.ProtectionService", FakeService) + + img_path = tmp_path / "a.png" + PILImage.new("RGB", (8, 8)).save(img_path) + + train_meta = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="d", + dataset_version="1", + split_hash=compute_split_hash([str(img_path)], SplitType.TRAIN, 1), + image_ids=[str(img_path)], + split_method="manual", + split_ratio={"train": 1.0}, + random_seed=1, + ) + manifest_path = tmp_path / "splits.json" + save_split_manifest({SplitType.TRAIN: train_meta}, manifest_path) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "benchmark", + str(img_path), + "--profiles", + "safe", + "--split-manifest", + str(manifest_path), + "--split-type", + "train", + ], + ) + + assert result.exit_code == 0, result.output + assert "WARNING" in result.output or "overfit" in result.output.lower() + + +def test_benchmark_cli_invalid_split_type(tmp_path: Path): + """The benchmark CLI should exit with error for invalid --split-type.""" + from PIL import Image as PILImage + + img_path = tmp_path / "a.png" + PILImage.new("RGB", (8, 8)).save(img_path) + + splits = create_random_split( + [img_path], train_ratio=1.0, val_ratio=0.0, test_ratio=0.0 + ) + manifest_path = tmp_path / "splits.json" + save_split_manifest(splits, manifest_path) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "benchmark", + str(img_path), + "--profiles", + "safe", + "--split-manifest", + str(manifest_path), + "--split-type", + "invalid_type", + ], + ) + assert result.exit_code != 0