Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/auralock/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,26 @@
build_lora_train_command,
evaluate_lora_preflight,
)
from auralock.benchmarks.splits import (
SplitMetadata,
SplitType,
compute_split_hash,
create_random_split,
load_split_manifest,
save_split_manifest,
validate_split_manifest,
warn_non_test_split,
)

__all__ = [
"SplitMetadata",
"SplitType",
"compute_split_hash",
"create_random_split",
"load_split_manifest",
"save_split_manifest",
"validate_split_manifest",
"warn_non_test_split",
"DEFAULT_ANTI_DREAMBOOTH_CLASS_PROMPT",
"DEFAULT_ANTI_DREAMBOOTH_INFER_SCRIPT",
"DEFAULT_ANTI_DREAMBOOTH_INSTANCE_PROMPT",
Expand Down
4 changes: 3 additions & 1 deletion src/auralock/benchmarks/antidreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def run(

notes = [
"This benchmark follows the Anti-DreamBooth paper-style set_A/set_B/set_C split.",
"set_A is retained as a clean reference split, set_B is treated as the published split, and set_C is preserved as holdout metadata.",
"set_A is retained as a clean reference split, set_B is treated as the published (training) split, "
"and set_C is the held-out validation split used to measure out-of-sample protection effectiveness.",
"Evaluate mimicry success on set_C (holdout) images—never on set_B—to avoid in-sample bias.",
"AuraLock still uses its own protection pipeline; this workflow is a benchmark alignment layer, not an ASPL/FSMG reproduction.",
]

Expand Down
27 changes: 27 additions & 0 deletions src/auralock/benchmarks/splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Dataset split methodology utilities for reproducible benchmark evaluation.

This module re-exports from :mod:`auralock.core.splits` for convenience.
The canonical implementation lives in the core package to avoid circular imports.
"""

# Thin compatibility shim: every name below is imported from
# ``auralock.core.splits`` and re-exported unchanged, so callers may import
# the split helpers from either ``auralock.core`` or ``auralock.benchmarks``.
from auralock.core.splits import (
SplitMetadata,
SplitType,
compute_split_hash,
create_random_split,
load_split_manifest,
save_split_manifest,
validate_split_manifest,
warn_non_test_split,
)

# Public API of this module: exactly the names re-exported above, in the
# same order. Keep the two lists in sync when adding or removing a helper.
__all__ = [
"SplitMetadata",
"SplitType",
"compute_split_hash",
"create_random_split",
"load_split_manifest",
"save_split_manifest",
"validate_split_manifest",
"warn_non_test_split",
]
184 changes: 184 additions & 0 deletions src/auralock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
AntiDreamBoothSubjectBenchmarkHarness,
DockerLoraBenchmarkConfig,
LoraBenchmarkHarness,
SplitMetadata,
SplitType,
build_docker_lora_benchmark_plan,
create_random_split,
load_split_manifest,
save_split_manifest,
validate_split_manifest,
)
from auralock.core.image import save_image
from auralock.services import (
Expand All @@ -42,6 +48,13 @@
)
console = Console()

split_app = typer.Typer(
name="split",
help="Dataset split management for reproducible benchmark evaluation.",
add_completion=False,
)
app.add_typer(split_app, name="split")


def _to_builtin(value: Any) -> Any:
"""Convert report payloads into JSON-friendly values."""
Expand Down Expand Up @@ -634,6 +647,18 @@ def benchmark(
recursive: bool = typer.Option(
False, "--recursive", help="Benchmark nested directories recursively"
),
split_manifest: Path | None = typer.Option(
None,
"--split-manifest",
help="Path to a JSON split manifest (from 'auralock split create'). "
"When provided, only images in the declared split are benchmarked.",
),
split_type: str = typer.Option(
"test",
"--split-type",
help="Split to use from the manifest: train, val, test, or dev. "
"Using a non-test split emits a bias warning.",
),
report: Path | None = typer.Option(
None,
"--report",
Expand All @@ -652,6 +677,42 @@ def benchmark(
console.print("[red]Error:[/red] At least one profile is required.")
raise typer.Exit(1)

resolved_split_metadata: SplitMetadata | None = None
if split_manifest is not None:
try:
all_splits = load_split_manifest(split_manifest)
except FileNotFoundError as exc:
console.print(f"[red]Error:[/red] {exc}")
raise typer.Exit(1) from exc
except (json.JSONDecodeError, KeyError, ValueError) as exc:
console.print(f"[red]Error:[/red] Could not parse split manifest: {exc}")
raise typer.Exit(1) from exc

try:
requested_type = SplitType(split_type)
except ValueError as exc:
valid = ", ".join(t.value for t in SplitType)
console.print(
f"[red]Error:[/red] Invalid --split-type '{split_type}'. "
f"Valid values: {valid}"
)
raise typer.Exit(1) from exc

if requested_type not in all_splits:
console.print(
f"[red]Error:[/red] Split type '{split_type}' not found in manifest. "
f"Available: {', '.join(k.value for k in all_splits)}"
)
raise typer.Exit(1)

resolved_split_metadata = all_splits[requested_type]
if requested_type != SplitType.TEST:
console.print(
f"[yellow]⚠ WARNING:[/yellow] Benchmarking on '{split_type}' split. "
"Results may be overfit to this split. "
"Use --split-type test for final evaluation."
)

with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
Expand All @@ -668,11 +729,13 @@ def benchmark(
input_path,
profiles=profile_names,
recursive=recursive,
split_metadata=resolved_split_metadata,
)
else:
summary = service.benchmark_file(
input_path,
profiles=profile_names,
split_metadata=resolved_split_metadata,
)
except ValueError as exc:
console.print(f"[red]Error:[/red] {exc}")
Expand All @@ -681,6 +744,13 @@ def benchmark(

console.print(_render_profile_summary_table(summary.profile_summaries))

if resolved_split_metadata is not None:
console.print(
f"[dim]Split: {resolved_split_metadata.split_type.value} | "
f"Images: {len(resolved_split_metadata.image_ids)} | "
f"Hash: {resolved_split_metadata.split_hash}[/dim]"
)

# Print prominent warning about protection scores in benchmark summary
from auralock.core.metrics import (
PROTECTION_SCORE_DISCLAIMER,
Expand Down Expand Up @@ -1164,6 +1234,120 @@ def benchmark_lora_docker(
raise typer.Exit(exc.returncode or 1) from exc


@split_app.command("create")
def split_create(
    dataset_dir: Path = typer.Argument(
        ...,
        help="Directory containing images to split",
    ),
    output: Path = typer.Option(
        ...,
        "--output",
        help="Path to write the JSON split manifest",
    ),
    dataset_name: str = typer.Option(
        "dataset",
        "--dataset-name",
        help="Human-readable dataset name",
    ),
    dataset_version: str = typer.Option(
        "1.0",
        "--dataset-version",
        help="Dataset version string",
    ),
    train_ratio: float = typer.Option(
        0.7,
        "--train-ratio",
        help="Fraction of images for the train split",
    ),
    val_ratio: float = typer.Option(
        0.15,
        "--val-ratio",
        help="Fraction of images for the validation split",
    ),
    test_ratio: float = typer.Option(
        0.15,
        "--test-ratio",
        help="Fraction of images for the test split",
    ),
    seed: int = typer.Option(
        42,
        "--seed",
        help="Random seed for reproducibility",
    ),
    recursive: bool = typer.Option(
        False,
        "--recursive",
        help="Scan sub-directories for images",
    ),
) -> None:
    """Create a train/val/test split manifest from a dataset directory."""
    # NOTE: the docstring above is user-facing (Typer shows it as the command
    # help), so it is intentionally kept short; details live in comments.
    #
    # Flow: validate the directory -> collect supported image files in sorted
    # order -> build the splits -> write the manifest -> print a summary table.
    # Every failure path prints a red "Error:" line and exits with status 1.
    from auralock.core.image import SUPPORTED_EXTENSIONS

    # Guard clause: the argument must name an existing directory.
    if not (dataset_dir.exists() and dataset_dir.is_dir()):
        console.print(f"[red]Error:[/red] Dataset directory not found: {dataset_dir}")
        raise typer.Exit(1)

    # --recursive switches from a single-level listing to a full tree walk.
    if recursive:
        candidates = dataset_dir.rglob("*")
    else:
        candidates = dataset_dir.glob("*")

    # Sort before filtering so the list handed to the splitter has a stable
    # order regardless of filesystem enumeration order.
    image_paths = []
    for candidate in sorted(candidates):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() in SUPPORTED_EXTENSIONS:
            image_paths.append(candidate)

    if not image_paths:
        console.print(f"[red]Error:[/red] No supported images found in: {dataset_dir}")
        raise typer.Exit(1)

    try:
        splits = create_random_split(
            image_paths,
            dataset_name=dataset_name,
            dataset_version=dataset_version,
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
            random_seed=seed,
        )
    except ValueError as exc:
        # A ValueError from the splitter is reported as a plain CLI error.
        console.print(f"[red]Error:[/red] {exc}")
        raise typer.Exit(1) from exc
    else:
        save_split_manifest(splits, output)

    # Summary table: one row per split in the mapping returned above.
    summary = Table(title="Split Manifest Summary")
    column_specs = (
        ("Split", "cyan"),
        ("Images", "green"),
        ("Ratio", "yellow"),
        ("Hash", "magenta"),
    )
    for heading, colour in column_specs:
        summary.add_column(heading, style=colour)

    for kind, meta in splits.items():
        # A split with no recorded ratio is displayed as 0.00.
        ratio = meta.split_ratio.get(kind.value, 0)
        summary.add_row(
            kind.value,
            str(len(meta.image_ids)),
            f"{ratio:.2f}",
            meta.split_hash,
        )

    console.print(summary)
    console.print(f"[green]Split manifest saved to:[/green] {output}")


@split_app.command("validate")
def split_validate(
    manifest: Path = typer.Argument(
        ...,
        help="Path to the JSON split manifest",
    ),
) -> None:
    """Validate a split manifest for data leakage and hash integrity."""
    # NOTE: the docstring above is user-facing (Typer shows it as the command
    # help), so it is intentionally kept short; details live in comments.
    #
    # Flow: load the manifest -> run validation -> always print the per-split
    # table -> either report success or list every issue and exit with 1.
    try:
        splits = load_split_manifest(manifest)
    except FileNotFoundError as exc:
        console.print(f"[red]Error:[/red] {exc}")
        raise typer.Exit(1) from exc
    except (json.JSONDecodeError, KeyError, ValueError) as exc:
        # Any of these is reported as an unparseable manifest.
        console.print(f"[red]Error:[/red] Could not parse split manifest: {exc}")
        raise typer.Exit(1) from exc

    # Validation runs before any output so the table and the verdict are
    # printed together.
    issues = validate_split_manifest(splits)

    overview = Table(title="Split Validation")
    column_specs = (
        ("Split", "cyan"),
        ("Images", "green"),
        ("Hash", "magenta"),
    )
    for heading, colour in column_specs:
        overview.add_column(heading, style=colour)

    for kind, meta in splits.items():
        overview.add_row(
            kind.value,
            str(len(meta.image_ids)),
            meta.split_hash,
        )
    console.print(overview)

    # Success path first: nothing to report, normal exit.
    if not issues:
        console.print("[green]✓ No issues found. Split manifest is valid.[/green]")
        return

    # Failure path: list each issue, then a count, then a non-zero exit.
    for problem in issues:
        console.print(f"[red]Issue:[/red] {problem}")
    console.print(f"[red]{len(issues)} validation issue(s) found.[/red]")
    raise typer.Exit(1)


@app.command()
def demo() -> None:
"""Run a quick demo using a synthetic image."""
Expand Down
Loading