Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 206 additions & 2 deletions src/auralock/core/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from __future__ import annotations

import io
from collections.abc import Callable, Iterable
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models import ResNet18_Weights, resnet18
from torchvision.models.feature_extraction import create_feature_extractor

Expand Down Expand Up @@ -214,20 +217,221 @@ def resize_restore(images: torch.Tensor, scale: float = 0.75) -> torch.Tensor:
)


def jpeg_compress_decompress(
images: torch.Tensor,
quality: int = 85,
) -> torch.Tensor:
"""Apply JPEG compression and decompression to test robustness against compression artifacts."""
if not 1 <= quality <= 100:
raise ValueError("quality must be between 1 and 100.")

batch_size = images.shape[0]
device = images.device
results = []

for i in range(batch_size):
# Convert tensor to PIL Image (0-255 range)
img_np = (
(images[i].detach().permute(1, 2, 0).cpu().numpy() * 255)
.clip(0, 255)
.astype(np.uint8)
)
img_pil = Image.fromarray(img_np)

# JPEG compress in memory
buffer = io.BytesIO()
img_pil.save(buffer, format="JPEG", quality=quality)
buffer.seek(0)

# Decompress
img_pil = Image.open(buffer)
img_np = np.array(img_pil).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_np).permute(2, 0, 1)
results.append(img_tensor)

return torch.stack(results).to(device)


def center_crop_and_resize(
images: torch.Tensor,
crop_ratio: float = 0.9,
) -> torch.Tensor:
"""Center crop to crop_ratio and resize back to original size."""
if not 0.0 < crop_ratio <= 1.0:
raise ValueError("crop_ratio must be in (0, 1].")

_, _, h, w = images.shape
crop_h = max(1, int(h * crop_ratio))
crop_w = max(1, int(w * crop_ratio))

start_h = (h - crop_h) // 2
start_w = (w - crop_w) // 2

cropped = images[:, :, start_h : start_h + crop_h, start_w : start_w + crop_w]

return F.interpolate(
cropped,
size=(h, w),
mode="bilinear",
align_corners=False,
antialias=True,
)


def random_crop_and_resize(
images: torch.Tensor,
crop_ratio: float = 0.9,
) -> torch.Tensor:
"""Random crop to crop_ratio and resize back to original size."""
if not 0.0 < crop_ratio <= 1.0:
raise ValueError("crop_ratio must be in (0, 1].")

_, _, h, w = images.shape
crop_h = max(1, int(h * crop_ratio))
crop_w = max(1, int(w * crop_ratio))

# Random start position
max_start_h = h - crop_h
max_start_w = w - crop_w
start_h = torch.randint(0, max_start_h + 1, (1,)).item() if max_start_h > 0 else 0
start_w = torch.randint(0, max_start_w + 1, (1,)).item() if max_start_w > 0 else 0

cropped = images[:, :, start_h : start_h + crop_h, start_w : start_w + crop_w]

return F.interpolate(
cropped,
size=(h, w),
mode="bilinear",
align_corners=False,
antialias=True,
)


def color_jitter(
images: torch.Tensor,
brightness: float = 0.1,
contrast: float = 0.1,
saturation: float = 0.1,
hue: float = 0.05,
) -> torch.Tensor:
"""Apply color jitter transformations (brightness, contrast, saturation, hue)."""
if any(x < 0 for x in [brightness, contrast, saturation, hue]):
raise ValueError("All jitter parameters must be non-negative.")

result = images.clone()

# Apply brightness jitter
if brightness > 0:
brightness_factor = (
1.0 + torch.empty(1).uniform_(-brightness, brightness).item()
)
result = (result * brightness_factor).clamp(0.0, 1.0)

# Apply contrast jitter
if contrast > 0:
contrast_factor = 1.0 + torch.empty(1).uniform_(-contrast, contrast).item()
mean = result.mean(dim=(2, 3), keepdim=True)
result = ((result - mean) * contrast_factor + mean).clamp(0.0, 1.0)

# Apply saturation jitter (convert to grayscale and interpolate)
if saturation > 0:
saturation_factor = (
1.0 + torch.empty(1).uniform_(-saturation, saturation).item()
)
# Convert to grayscale using standard weights
gray = (
0.299 * result[:, 0:1, :, :]
+ 0.587 * result[:, 1:2, :, :]
+ 0.114 * result[:, 2:3, :, :]
)
gray = gray.expand_as(result)
result = (saturation_factor * result + (1 - saturation_factor) * gray).clamp(
0.0, 1.0
)

# Apply hue jitter (convert to HSV and back)
if hue > 0:
# Simple approximation: shift each channel differently
hue_factor = torch.empty(1).uniform_(-hue, hue).item()
# Rotate RGB channels
if abs(hue_factor) > 0.01:
# This is a simplified hue shift - not perfect but preserves differentiability
shift = int(hue_factor * 255) % 3
if shift != 0:
result = torch.roll(result, shifts=shift, dims=1)

return result


def add_gaussian_noise(
images: torch.Tensor,
std: float = 0.01,
) -> torch.Tensor:
"""Add Gaussian noise to images for robustness testing."""
if std < 0:
raise ValueError("std must be non-negative.")

noise = torch.randn_like(images) * std
return (images + noise).clamp(0.0, 1.0)


def build_style_transform_suite() -> tuple[tuple[str, StyleTransform], ...]:
"""Transforms used for both robust optimization and benchmark reporting."""
"""Transforms used for both robust optimization and benchmark reporting.

This suite now includes critical preprocessing transformations that real-world
mimicry pipelines actually use, including JPEG compression, cropping, and
color augmentations.
"""

def identity(images: torch.Tensor) -> torch.Tensor:
return images

return (
# Baseline
("identity", identity),
# Existing transforms (kept for backward compatibility)
(
"gaussian_blur",
"gaussian_blur_mild",
lambda images: gaussian_blur(images, kernel_size=5, sigma=1.0),
),
("resize_restore_75", lambda images: resize_restore(images, scale=0.75)),
("resize_restore_50", lambda images: resize_restore(images, scale=0.5)),
# NEW: JPEG Compression (critical - most effective purification)
(
"jpeg_quality_95",
lambda images: jpeg_compress_decompress(images, quality=95),
),
(
"jpeg_quality_85",
lambda images: jpeg_compress_decompress(images, quality=85),
),
(
"jpeg_quality_75",
lambda images: jpeg_compress_decompress(images, quality=75),
),
# NEW: Cropping (standard data augmentation in training pipelines)
(
"center_crop_90",
lambda images: center_crop_and_resize(images, crop_ratio=0.9),
),
(
"center_crop_80",
lambda images: center_crop_and_resize(images, crop_ratio=0.8),
),
# NEW: Stronger blur variants
(
"gaussian_blur_medium",
lambda images: gaussian_blur(images, kernel_size=7, sigma=2.0),
),
# NEW: Color augmentation (common in training)
(
"color_jitter_mild",
lambda images: color_jitter(
images, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05
),
),
# NEW: Noise injection (regularization during training)
("gaussian_noise_small", lambda images: add_gaussian_noise(images, std=0.01)),
)


Expand Down
Loading
Loading