diff --git a/README.md b/README.md index 1fc90c3..14d4ac0 100644 --- a/README.md +++ b/README.md @@ -150,10 +150,19 @@ auralock batch ./.cache_ref/Anti-DreamBooth/data/n000050/set_B ./protected_subje --working-size 384 ^ --report reports/batch-collective.json -# Compare multiple profiles -auralock benchmark artwork.png --profiles safe,balanced,strong --report reports/benchmark.json +# Create a leak-free split manifest then benchmark on the test split +auralock split create ./artworks --output splits.json +auralock benchmark artwork.png ^ + --profiles safe,balanced,strong ^ + --split-manifest splits.json ^ + --split-type test ^ + --report reports/benchmark.json ``` +### Split Management + +Use `auralock split create` to generate deterministic train/val/test manifests (with hashes and ratios) and `auralock split validate` to assert that splits are non-overlapping. The `benchmark` command now requires `--split-manifest` and accepts `--split-type` (default: `test`), so reported scores always reference a declared dataset split. Every benchmarked image must be listed in the selected split and, when the manifest records a `dataset_root`, be located under it; otherwise the command exits with an error. 
+ ### Optional Web UI ```bash diff --git a/src/auralock/benchmarks/__init__.py b/src/auralock/benchmarks/__init__.py index 1a88588..6d67b83 100644 --- a/src/auralock/benchmarks/__init__.py +++ b/src/auralock/benchmarks/__init__.py @@ -1,55 +1,21 @@ -"""Benchmark harnesses for real-world protection evaluation.""" +"""Benchmark helpers and split utilities.""" -from auralock.benchmarks.antidreambooth import ( - DEFAULT_ANTI_DREAMBOOTH_CLASS_PROMPT, - DEFAULT_ANTI_DREAMBOOTH_INFER_SCRIPT, - DEFAULT_ANTI_DREAMBOOTH_INSTANCE_PROMPT, - DEFAULT_ANTI_DREAMBOOTH_TRAIN_SCRIPT, - AntiDreamBoothBenchmarkManifest, - AntiDreamBoothSubjectBenchmarkHarness, - AntiDreamBoothSubjectLayout, - resolve_subject_layout, -) -from auralock.benchmarks.docker_runtime import ( - DEFAULT_BENCHMARK_BASE_IMAGE, - DEFAULT_COMPOSE_FILE, - DEFAULT_GPU_SMOKE_IMAGE, - DEFAULT_SERVICE_NAME, - DockerLoraBenchmarkConfig, - DockerLoraBenchmarkPlan, - build_docker_lora_benchmark_plan, -) -from auralock.benchmarks.lora import ( - LoraBenchmarkConfig, - LoraBenchmarkHarness, - LoraBenchmarkManifest, - LoraPreflightReport, - build_lora_infer_command, - build_lora_train_command, - evaluate_lora_preflight, +from auralock.benchmarks.splits import ( + SplitMetadata, + SplitType, + collect_supported_images, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_no_overlap, ) __all__ = [ - "DEFAULT_ANTI_DREAMBOOTH_CLASS_PROMPT", - "DEFAULT_ANTI_DREAMBOOTH_INFER_SCRIPT", - "DEFAULT_ANTI_DREAMBOOTH_INSTANCE_PROMPT", - "DEFAULT_ANTI_DREAMBOOTH_TRAIN_SCRIPT", - "DEFAULT_BENCHMARK_BASE_IMAGE", - "DEFAULT_COMPOSE_FILE", - "DEFAULT_GPU_SMOKE_IMAGE", - "DEFAULT_SERVICE_NAME", - "AntiDreamBoothBenchmarkManifest", - "AntiDreamBoothSubjectBenchmarkHarness", - "AntiDreamBoothSubjectLayout", - "DockerLoraBenchmarkConfig", - "DockerLoraBenchmarkPlan", - "LoraBenchmarkConfig", - "LoraBenchmarkHarness", - "LoraBenchmarkManifest", - "LoraPreflightReport", - "build_docker_lora_benchmark_plan", - 
"build_lora_infer_command", - "build_lora_train_command", - "evaluate_lora_preflight", - "resolve_subject_layout", + "SplitMetadata", + "SplitType", + "collect_supported_images", + "create_random_split", + "load_split_manifest", + "save_split_manifest", + "validate_no_overlap", ] diff --git a/src/auralock/benchmarks/antidreambooth.py b/src/auralock/benchmarks/antidreambooth.py index 42eae32..143f36e 100644 --- a/src/auralock/benchmarks/antidreambooth.py +++ b/src/auralock/benchmarks/antidreambooth.py @@ -315,7 +315,7 @@ def run( notes = [ "This benchmark follows the Anti-DreamBooth paper-style set_A/set_B/set_C split.", - "set_A is retained as a clean reference split, set_B is treated as the published split, and set_C is preserved as holdout metadata.", + "set_A is retained as a clean reference split, set_B is treated as the published split, and set_C must remain a held-out validation split.", "AuraLock still uses its own protection pipeline; this workflow is a benchmark alignment layer, not an ASPL/FSMG reproduction.", ] diff --git a/src/auralock/benchmarks/splits.py b/src/auralock/benchmarks/splits.py new file mode 100644 index 0000000..b64fa1f --- /dev/null +++ b/src/auralock/benchmarks/splits.py @@ -0,0 +1,226 @@ +"""Dataset split utilities to prevent benchmark data leakage.""" + +from __future__ import annotations + +import hashlib +import json +import random +from collections.abc import Iterable +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +from auralock.core.image import SUPPORTED_EXTENSIONS + + +def _normalize_image_id(path: str | Path) -> str: + """Normalize paths for stable manifest hashing and comparisons.""" + return str(Path(path).resolve()) + + +class SplitType(Enum): + """Canonical split types used across benchmarks.""" + + TRAIN = "train" + VALIDATION = "val" + TEST = "test" + DEVELOPMENT = "dev" + + +@dataclass(slots=True) +class SplitMetadata: + """Reproducible manifest for one split of a dataset.""" + + 
split_type: SplitType + dataset_name: str + dataset_version: str + split_method: str + split_ratio: dict[str, float] + image_ids: list[str] + random_seed: int | None = None + split_hash: str | None = None + dataset_root: str | None = None + + def __post_init__(self) -> None: + self.image_ids = [_normalize_image_id(path) for path in self.image_ids] + if len(set(self.image_ids)) != len(self.image_ids): + raise ValueError("image_ids must be unique for each split.") + if self.split_hash is None: + self.split_hash = self._compute_split_hash() + + def _compute_split_hash(self) -> str: + payload = { + "dataset_name": self.dataset_name, + "dataset_version": self.dataset_version, + "dataset_root": self.dataset_root, + "image_ids": sorted(self.image_ids), + "random_seed": self.random_seed, + "split_method": self.split_method, + "split_ratio": self.split_ratio, + "split_type": self.split_type.value, + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + return hashlib.sha256(encoded).hexdigest()[:16] + + @property + def normalized_image_ids(self) -> set[str]: + """Normalized image identifiers for membership validation.""" + return set(self.image_ids) + + def verify_no_leakage(self, other: SplitMetadata) -> bool: + """Check that no images overlap between two splits.""" + return self.normalized_image_ids.isdisjoint(other.normalized_image_ids) + + def contains_all(self, paths: Iterable[str | Path]) -> list[str]: + """Return any paths missing from the split.""" + normalized = {_normalize_image_id(path) for path in paths} + return sorted(normalized - self.normalized_image_ids) + + def to_dict(self) -> dict[str, object]: + return { + "split_type": self.split_type.value, + "dataset_name": self.dataset_name, + "dataset_version": self.dataset_version, + "dataset_root": self.dataset_root, + "split_method": self.split_method, + "split_ratio": self.split_ratio, + "random_seed": self.random_seed, + "split_hash": self.split_hash, + "image_ids": 
list(self.image_ids), + } + + @classmethod + def from_dict(cls, payload: dict[str, object]) -> SplitMetadata: + split_type = SplitType(str(payload["split_type"])) + return cls( + split_type=split_type, + dataset_name=str(payload["dataset_name"]), + dataset_version=str(payload.get("dataset_version", "unknown")), + dataset_root=( + str(payload["dataset_root"]) + if payload.get("dataset_root") is not None + else None + ), + split_method=str(payload.get("split_method", "manual")), + split_ratio=dict(payload.get("split_ratio", {})), + random_seed=payload.get("random_seed"), # type: ignore[arg-type] + split_hash=str(payload.get("split_hash") or ""), + image_ids=list(payload.get("image_ids", [])), # type: ignore[list-item] + ) + + +def _assert_ratio_sum(train_ratio: float, val_ratio: float, test_ratio: float) -> None: + total = train_ratio + val_ratio + test_ratio + if not abs(total - 1.0) < 1e-6: + raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0") + + +def collect_supported_images(dataset_root: Path) -> list[Path]: + """Collect supported images under a dataset root.""" + if not dataset_root.exists() or not dataset_root.is_dir(): + raise ValueError("dataset_root must be an existing directory.") + return [ + candidate + for candidate in sorted(dataset_root.rglob("*")) + if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_EXTENSIONS + ] + + +def create_random_split( + image_paths: list[Path], + *, + train_ratio: float = 0.7, + val_ratio: float = 0.15, + test_ratio: float = 0.15, + random_seed: int = 42, + dataset_name: str = "dataset", + dataset_version: str = "v1", + split_method: str = "random", + dataset_root: Path | None = None, +) -> dict[SplitType, SplitMetadata]: + """Create a reproducible random split manifest.""" + if not image_paths: + raise ValueError("image_paths must contain at least one image.") + _assert_ratio_sum(train_ratio, val_ratio, test_ratio) + + rng = random.Random(random_seed) + shuffled = list(image_paths) + 
rng.shuffle(shuffled) + + n_train = int(len(shuffled) * train_ratio) + n_val = int(len(shuffled) * val_ratio) + train_images = shuffled[:n_train] + val_images = shuffled[n_train : n_train + n_val] + test_images = shuffled[n_train + n_val :] + + ratio = {"train": train_ratio, "val": val_ratio, "test": test_ratio} + root_str = str(dataset_root.resolve()) if dataset_root is not None else None + splits = { + SplitType.TRAIN: SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name=dataset_name, + dataset_version=dataset_version, + dataset_root=root_str, + split_method=split_method, + split_ratio=ratio, + random_seed=random_seed, + image_ids=[str(path.resolve()) for path in train_images], + ), + SplitType.VALIDATION: SplitMetadata( + split_type=SplitType.VALIDATION, + dataset_name=dataset_name, + dataset_version=dataset_version, + dataset_root=root_str, + split_method=split_method, + split_ratio=ratio, + random_seed=random_seed, + image_ids=[str(path.resolve()) for path in val_images], + ), + SplitType.TEST: SplitMetadata( + split_type=SplitType.TEST, + dataset_name=dataset_name, + dataset_version=dataset_version, + dataset_root=root_str, + split_method=split_method, + split_ratio=ratio, + random_seed=random_seed, + image_ids=[str(path.resolve()) for path in test_images], + ), + } + if not splits[SplitType.TEST].image_ids: + raise ValueError("test split would be empty; adjust ratios or dataset size.") + return splits + + +def save_split_manifest( + splits: dict[SplitType, SplitMetadata], output_path: Path +) -> None: + """Persist split metadata to a JSON manifest.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + payload = {split_type.value: meta.to_dict() for split_type, meta in splits.items()} + output_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +def load_split_manifest(manifest_path: Path) -> dict[SplitType, SplitMetadata]: + """Load a split manifest from disk.""" + if not manifest_path.exists(): + raise 
FileNotFoundError(f"Split manifest not found: {manifest_path}") + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + splits: dict[SplitType, SplitMetadata] = {} + for key, value in payload.items(): + split_type = SplitType(key) + splits[split_type] = SplitMetadata.from_dict(value) + return splits + + +def validate_no_overlap(splits: dict[SplitType, SplitMetadata]) -> None: + """Raise when any split pair overlaps.""" + split_items = list(splits.items()) + for idx, (split_type, split_meta) in enumerate(split_items): + for other_type, other_meta in split_items[idx + 1 :]: + if not split_meta.verify_no_leakage(other_meta): + raise ValueError( + f"Split {split_type.value} overlaps with {other_type.value}." + ) diff --git a/src/auralock/cli.py b/src/auralock/cli.py index 0ccd520..1ede39e 100644 --- a/src/auralock/cli.py +++ b/src/auralock/cli.py @@ -15,19 +15,31 @@ from rich.table import Table from auralock import __version__ -from auralock.benchmarks import ( +from auralock.benchmarks.antidreambooth import ( DEFAULT_ANTI_DREAMBOOTH_CLASS_PROMPT, DEFAULT_ANTI_DREAMBOOTH_INFER_SCRIPT, DEFAULT_ANTI_DREAMBOOTH_INSTANCE_PROMPT, DEFAULT_ANTI_DREAMBOOTH_TRAIN_SCRIPT, + AntiDreamBoothSubjectBenchmarkHarness, +) +from auralock.benchmarks.docker_runtime import ( DEFAULT_BENCHMARK_BASE_IMAGE, DEFAULT_COMPOSE_FILE, DEFAULT_SERVICE_NAME, - AntiDreamBoothSubjectBenchmarkHarness, DockerLoraBenchmarkConfig, - LoraBenchmarkHarness, build_docker_lora_benchmark_plan, ) +from auralock.benchmarks.lora import ( + LoraBenchmarkHarness, +) +from auralock.benchmarks.splits import ( + SplitType, + collect_supported_images, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_no_overlap, +) from auralock.core.image import save_image from auralock.services import ( BatchProtectionSummary, @@ -40,6 +52,12 @@ help="Protect your artwork from AI style mimicry with a consistent production pipeline.", add_completion=False, ) +split_app = typer.Typer( + 
name="split", + help="Manage dataset split manifests for leak-free benchmarking.", + add_completion=False, +) +app.add_typer(split_app, name="split") console = Console() @@ -307,6 +325,80 @@ def _adaptive_thresholds_met( ) +@split_app.command("create") +def split_create( + dataset_root: Path = typer.Argument( + ..., help="Root directory containing the dataset to split" + ), + output: Path = typer.Option( + Path("split_manifest.json"), + "--output", + help="Output path for the split manifest JSON", + ), + train_ratio: float = typer.Option(0.7, "--train-ratio", help="Train ratio"), + val_ratio: float = typer.Option(0.15, "--val-ratio", help="Validation ratio"), + test_ratio: float = typer.Option(0.15, "--test-ratio", help="Test ratio"), + seed: int = typer.Option(42, "--seed", help="Random seed for deterministic splits"), + dataset_name: str | None = typer.Option( + None, + "--dataset-name", + help="Friendly dataset name to embed in the manifest", + ), + dataset_version: str = typer.Option( + "v1", "--dataset-version", help="Dataset version string" + ), +) -> None: + """Create a deterministic train/val/test split manifest.""" + try: + image_paths = collect_supported_images(dataset_root) + except ValueError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc + + resolved_name = dataset_name or dataset_root.name + try: + splits = create_random_split( + image_paths, + train_ratio=train_ratio, + val_ratio=val_ratio, + test_ratio=test_ratio, + random_seed=seed, + dataset_name=resolved_name, + dataset_version=dataset_version, + dataset_root=dataset_root, + ) + validate_no_overlap(splits) + save_split_manifest(splits, output) + except ValueError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc + + console.print( + f"[green]Saved split manifest[/green] to {output} " + f"(train={len(splits[SplitType.TRAIN].image_ids)}, " + f"val={len(splits[SplitType.VALIDATION].image_ids)}, " + 
f"test={len(splits[SplitType.TEST].image_ids)})" + ) + + +@split_app.command("validate") +def split_validate(manifest: Path = typer.Argument(..., help="Split manifest path")): + """Validate that a split manifest is well-formed and non-overlapping.""" + try: + splits = load_split_manifest(manifest) + validate_no_overlap(splits) + except Exception as exc: # noqa: BLE001 + console.print(f"[red]Invalid manifest:[/red] {exc}") + raise typer.Exit(1) from exc + + console.print("[green]Split manifest is valid and non-overlapping.[/green]") + for split_type, meta in splits.items(): + console.print( + f"- {split_type.value}: {len(meta.image_ids)} images " + f"(dataset={meta.dataset_name} v{meta.dataset_version})" + ) + + @app.command() def protect( input_path: Path = typer.Argument(..., help="Path to input image"), @@ -634,6 +726,16 @@ def benchmark( recursive: bool = typer.Option( False, "--recursive", help="Benchmark nested directories recursively" ), + split_manifest: Path = typer.Option( + ..., + "--split-manifest", + help="Split manifest JSON describing train/val/test assignments", + ), + split_type: str = typer.Option( + "test", + "--split-type", + help="Which split to benchmark: train, val, test, or dev", + ), report: Path | None = typer.Option( None, "--report", @@ -651,6 +753,22 @@ def benchmark( if not profile_names: console.print("[red]Error:[/red] At least one profile is required.") raise typer.Exit(1) + try: + requested_split = SplitType(split_type.lower()) + except ValueError as exc: + console.print("[red]Error:[/red] split-type must be train, val, test, or dev.") + raise typer.Exit(1) from exc + try: + split_manifest_map = load_split_manifest(split_manifest) + try: + split_metadata = split_manifest_map[requested_split] + except KeyError as exc: + raise KeyError( + f"Split '{requested_split.value}' not found in {split_manifest}" + ) from exc + except Exception as exc: # noqa: BLE001 + console.print(f"[red]Error loading split manifest:[/red] {exc}") + raise 
typer.Exit(1) from exc with Progress( SpinnerColumn(), @@ -668,17 +786,25 @@ def benchmark( input_path, profiles=profile_names, recursive=recursive, + split_metadata=split_metadata, ) else: summary = service.benchmark_file( input_path, profiles=profile_names, + split_metadata=split_metadata, ) except ValueError as exc: console.print(f"[red]Error:[/red] {exc}") raise typer.Exit(1) from exc progress.update(task, completed=True, description="Benchmark completed") + if split_metadata.split_type != SplitType.TEST: + console.print( + f"[yellow]Warning:[/yellow] Benchmarking on '{split_metadata.split_type.value}' split. " + "Use TEST for final reporting." + ) + console.print(_render_profile_summary_table(summary.profile_summaries)) # Print prominent warning about protection scores in benchmark summary diff --git a/src/auralock/services/protection.py b/src/auralock/services/protection.py index 73a1e18..7c4679f 100644 --- a/src/auralock/services/protection.py +++ b/src/auralock/services/protection.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from dataclasses import dataclass from pathlib import Path from time import perf_counter @@ -13,6 +14,7 @@ from PIL import Image from auralock.attacks import FGSM, PGD, StyleCloak +from auralock.benchmarks.splits import SplitMetadata, SplitType from auralock.core.image import ( SUPPORTED_EXTENSIONS, image_to_tensor, @@ -34,6 +36,8 @@ def _to_builtin(value: Any) -> Any: """Convert tensors, paths, and numpy scalars into JSON-friendly values.""" if isinstance(value, Path): return str(value) + if isinstance(value, SplitMetadata): + return value.to_dict() if isinstance(value, dict): return {str(key): _to_builtin(item) for key, item in value.items()} if isinstance(value, (list, tuple)): @@ -200,24 +204,30 @@ class BenchmarkSummary: image_count: int entries: list[BenchmarkEntry] profile_summaries: dict[str, dict[str, object]] + split_metadata: SplitMetadata | None = None def to_report_dict(self) -> dict[str, object]: 
"""Serialize the full benchmark report.""" - return _to_builtin( - { - "input_path": self.input_path, - "image_count": self.image_count, - "entries": [entry.to_report_dict() for entry in self.entries], - "profile_summaries": self.profile_summaries, - "validation_metadata": { - "is_validated": False, - "validation_status": "not_validated", - "validation_method": None, - "validation_date": None, - "notes": "Protection metrics are proxy measurements not validated against real attacks like DreamBooth or LoRA.", - }, - } - ) + payload: dict[str, object] = { + "input_path": self.input_path, + "image_count": self.image_count, + "entries": [entry.to_report_dict() for entry in self.entries], + "profile_summaries": self.profile_summaries, + "validation_metadata": { + "is_validated": False, + "validation_status": "not_validated", + "validation_method": None, + "validation_date": None, + "notes": "Protection metrics are proxy measurements not validated against real attacks like DreamBooth or LoRA.", + }, + } + if self.split_metadata is not None: + payload["split_metadata"] = self.split_metadata + payload["split_type"] = self.split_metadata.split_type.value + payload["split_hash"] = self.split_metadata.split_hash + payload["dataset_name"] = self.split_metadata.dataset_name + payload["dataset_version"] = self.split_metadata.dataset_version + return _to_builtin(payload) class ProtectionService: @@ -712,17 +722,52 @@ def analyze_files( "protection_report": protection_report, } + def _validate_split_membership( + self, image_paths: list[Path], split_metadata: SplitMetadata + ) -> None: + """Ensure benchmark targets belong to the declared split.""" + missing = split_metadata.contains_all(image_paths) + if missing: + raise ValueError( + "Benchmark inputs are missing from the declared split manifest: " + + ", ".join(missing) + ) + if ( + split_metadata.dataset_root is not None + and split_metadata.dataset_root.strip() != "" + ): + root = Path(split_metadata.dataset_root).resolve() + 
outside_root = [ + str(path) + for path in image_paths + if not Path(path).resolve().is_relative_to(root) + ] + if outside_root: + raise ValueError( + "Benchmark inputs must be within dataset_root to avoid leakage: " + + ", ".join(outside_root) + ) + if split_metadata.split_type != SplitType.TEST: + warnings.warn( + f"Benchmarking on non-test split '{split_metadata.split_type.value}'. " + "Use TEST split for final evaluation to avoid optimistic bias.", + stacklevel=2, + ) + def _collect_benchmark_entries( self, image_paths: list[Path], *, input_path: Path, profiles: tuple[str, ...], + split_metadata: SplitMetadata, ) -> BenchmarkSummary: """Benchmark the requested profiles on a list of images.""" if not image_paths: raise ValueError("No supported images were found to benchmark.") + self._validate_split_membership(image_paths, split_metadata) + entries: list[BenchmarkEntry] = [] for image_path in image_paths: for profile in profiles: @@ -774,6 +819,7 @@ def _collect_benchmark_entries( image_count=len(image_paths), entries=entries, profile_summaries=profile_summaries, + split_metadata=split_metadata, ) def benchmark_file( @@ -781,8 +827,9 @@ def benchmark_file( input_path: str | Path, *, profiles: tuple[str, ...] = ("safe", "balanced", "strong"), + split_metadata: SplitMetadata, ) -> BenchmarkSummary: - """Benchmark one image against multiple named profiles.""" + """Benchmark one image against multiple named profiles with split tracking.""" candidate = Path(input_path) if not candidate.exists() or not candidate.is_file(): raise ValueError("input_path must be an existing image file.") @@ -792,6 +839,7 @@ def benchmark_file( [candidate], input_path=candidate, profiles=profiles, + split_metadata=split_metadata, ) def benchmark_directory( @@ -800,8 +848,9 @@ def benchmark_directory( *, profiles: tuple[str, ...] 
= ("safe", "balanced", "strong"), recursive: bool = False, + split_metadata: SplitMetadata, ) -> BenchmarkSummary: - """Benchmark all supported images in a directory across profiles.""" + """Benchmark all supported images in a directory across profiles with split tracking.""" input_path = Path(input_dir) if not input_path.exists() or not input_path.is_dir(): raise ValueError("input_dir must be an existing directory.") @@ -816,6 +865,7 @@ def benchmark_directory( image_paths, input_path=input_path, profiles=profiles, + split_metadata=split_metadata, ) def protect_directory( diff --git a/src/tests/test_antidreambooth_benchmark.py b/src/tests/test_antidreambooth_benchmark.py index 9e3695d..810c611 100644 --- a/src/tests/test_antidreambooth_benchmark.py +++ b/src/tests/test_antidreambooth_benchmark.py @@ -69,8 +69,16 @@ def protect_file(self, *args, **kwargs): } assert [job["profile"] for job in manifest.jobs] == ["clean", "safe"] assert manifest.jobs[0]["variant"] == "clean_published" - assert str(manifest.jobs[0]["published_dir"]).endswith("datasets\\published\\clean") - assert str(manifest.jobs[1]["published_dir"]).endswith("datasets\\published\\safe") + assert Path(manifest.jobs[0]["published_dir"]).parts[-3:] == ( + "datasets", + "published", + "clean", + ) + assert Path(manifest.jobs[1]["published_dir"]).parts[-3:] == ( + "datasets", + "published", + "safe", + ) def test_subject_benchmark_harness_executes_collective_protection_for_set_b( diff --git a/src/tests/test_benchmark.py b/src/tests/test_benchmark.py index 0f60af9..a06880c 100644 --- a/src/tests/test_benchmark.py +++ b/src/tests/test_benchmark.py @@ -7,6 +7,7 @@ from PIL import Image from typer.testing import CliRunner +from auralock.benchmarks.splits import SplitMetadata, SplitType, save_split_manifest from auralock.cli import app from .test_pipeline import RecordingClassifier @@ -18,6 +19,19 @@ def _create_image(path: Path, color: str) -> None: Image.new("RGB", (64, 48), color=color).save(path) +def 
_make_split_metadata(image_paths: list[Path]) -> SplitMetadata: + return SplitMetadata( + split_type=SplitType.TEST, + dataset_name="tmp", + dataset_version="v1", + split_method="manual", + split_ratio={"train": 0.0, "val": 0.0, "test": 1.0}, + random_seed=123, + dataset_root=str(image_paths[0].parent), + image_ids=[str(path.resolve()) for path in image_paths], + ) + + def test_protection_service_benchmark_directory_summarizes_profiles(tmp_path: Path): """Benchmark mode should produce per-profile aggregates over image inputs.""" from auralock.core.pipeline import ImageNetModelAdapter @@ -34,6 +48,7 @@ def test_protection_service_benchmark_directory_summarizes_profiles(tmp_path: Pa summary = service.benchmark_directory( input_dir, profiles=("safe", "balanced"), + split_metadata=_make_split_metadata([input_dir / "a.png", input_dir / "b.png"]), ) assert summary.image_count == 2 @@ -67,14 +82,16 @@ def to_report_dict(self): } class FakeService: - def benchmark_directory(self, input_path, **kwargs): + def benchmark_directory(self, input_path, *, split_metadata, **kwargs): calls["input_path"] = input_path calls["kwargs"] = kwargs + calls["split_metadata"] = split_metadata return FakeSummary() - def benchmark_file(self, input_path, **kwargs): + def benchmark_file(self, input_path, *, split_metadata, **kwargs): calls["input_path"] = input_path calls["kwargs"] = kwargs + calls["split_metadata"] = split_metadata return FakeSummary() monkeypatch.setattr("auralock.cli.ProtectionService", FakeService) @@ -83,6 +100,9 @@ def benchmark_file(self, input_path, **kwargs): input_dir = tmp_path / "input" input_dir.mkdir() report_path = tmp_path / "benchmark.json" + manifest_path = tmp_path / "split.json" + metadata = _make_split_metadata([input_dir / "a.png"]) + save_split_manifest({SplitType.TEST: metadata}, manifest_path) result = runner.invoke( app, @@ -91,6 +111,8 @@ def benchmark_file(self, input_path, **kwargs): str(input_dir), "--profiles", "safe,balanced", + 
"--split-manifest", + str(manifest_path), "--report", str(report_path), ], @@ -101,3 +123,4 @@ def benchmark_file(self, input_path, **kwargs): assert report_path.exists() assert calls["input_path"] == input_dir assert calls["kwargs"]["profiles"] == ("safe", "balanced") + assert calls["split_metadata"].split_type == SplitType.TEST diff --git a/src/tests/test_docker_runtime.py b/src/tests/test_docker_runtime.py index 7b60d05..3e1f4f6 100644 --- a/src/tests/test_docker_runtime.py +++ b/src/tests/test_docker_runtime.py @@ -12,7 +12,7 @@ def test_build_docker_lora_benchmark_plan_maps_workspace_paths(tmp_path: Path): """Docker planner should convert workspace-local paths into container paths.""" - from auralock.benchmarks import ( + from auralock.benchmarks.docker_runtime import ( DockerLoraBenchmarkConfig, build_docker_lora_benchmark_plan, ) @@ -80,7 +80,7 @@ def test_build_docker_lora_benchmark_plan_rejects_paths_outside_workspace( tmp_path: Path, ): """Docker planner should reject paths that are not covered by the workspace mount.""" - from auralock.benchmarks import ( + from auralock.benchmarks.docker_runtime import ( DockerLoraBenchmarkConfig, build_docker_lora_benchmark_plan, ) @@ -122,7 +122,7 @@ def test_build_docker_lora_benchmark_plan_requires_diffusers_model_dir( tmp_path: Path, ): """Docker planner should reject model paths that are not Diffusers directories.""" - from auralock.benchmarks import ( + from auralock.benchmarks.docker_runtime import ( DockerLoraBenchmarkConfig, build_docker_lora_benchmark_plan, ) @@ -159,7 +159,7 @@ def test_build_docker_lora_benchmark_plan_requires_diffusers_model_dir( def test_benchmark_lora_docker_cli_runs_build_and_execute(monkeypatch, tmp_path: Path): """CLI should orchestrate Docker build, GPU check, and benchmark execution.""" - from auralock.benchmarks import DockerLoraBenchmarkPlan + from auralock.benchmarks.docker_runtime import DockerLoraBenchmarkPlan recorded: list[list[str]] = [] diff --git a/src/tests/test_splits.py 
b/src/tests/test_splits.py new file mode 100644 index 0000000..90c7fbf --- /dev/null +++ b/src/tests/test_splits.py @@ -0,0 +1,88 @@ +"""Tests for dataset split utilities.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from auralock.benchmarks.splits import ( + SplitMetadata, + SplitType, + collect_supported_images, + create_random_split, + load_split_manifest, + save_split_manifest, + validate_no_overlap, +) + + +def _write_fake_images(root: Path, count: int) -> list[Path]: + root.mkdir(parents=True, exist_ok=True) + paths: list[Path] = [] + for index in range(count): + path = root / f"img_{index}.png" + path.write_bytes(b"fake") + paths.append(path) + return paths + + +def test_create_random_split_generates_hash_and_manifest(tmp_path: Path): + dataset_root = tmp_path / "dataset" + image_paths = _write_fake_images(dataset_root, 6) + + splits = create_random_split( + image_paths, + train_ratio=0.5, + val_ratio=0.25, + test_ratio=0.25, + random_seed=0, + dataset_name="demo", + dataset_version="v1", + dataset_root=dataset_root, + ) + validate_no_overlap(splits) + + assert set(splits) == { + SplitType.TRAIN, + SplitType.VALIDATION, + SplitType.TEST, + } + assert splits[SplitType.TEST].split_hash is not None + assert splits[SplitType.TRAIN].verify_no_leakage(splits[SplitType.TEST]) + + manifest_path = tmp_path / "split.json" + save_split_manifest(splits, manifest_path) + loaded = load_split_manifest(manifest_path) + + assert loaded[SplitType.TEST].dataset_name == "demo" + assert loaded[SplitType.TEST].split_hash == splits[SplitType.TEST].split_hash + assert len(collect_supported_images(dataset_root)) == 6 + + +def test_validate_no_overlap_rejects_duplicate_images(tmp_path: Path): + img = tmp_path / "x.png" + img.write_bytes(b"fake") + train_meta = SplitMetadata( + split_type=SplitType.TRAIN, + dataset_name="demo", + dataset_version="v1", + split_method="manual", + split_ratio={"train": 1.0}, + random_seed=None, + 
dataset_root=str(tmp_path), + image_ids=[str(img)], + ) + test_meta = SplitMetadata( + split_type=SplitType.TEST, + dataset_name="demo", + dataset_version="v1", + split_method="manual", + split_ratio={"test": 1.0}, + random_seed=None, + dataset_root=str(tmp_path), + image_ids=[str(img)], + ) + + with pytest.raises(ValueError): + validate_no_overlap({SplitType.TRAIN: train_meta, SplitType.TEST: test_meta})