-
Notifications
You must be signed in to change notification settings - Fork 16
feat: implementing Kernel Inception Metric #65 : #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
2af982a
f23efa4
c4bea1b
85f88a8
87b736a
f13d26e
c1c30be
24b16d0
88f9ec6
6bdd189
4cb3d65
f9d5954
965eab5
075e68d
878876c
d40ee0c
d3df741
8969ce9
74cf343
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,4 +7,6 @@ keras-rs | |
| pytest | ||
| rouge-score | ||
| scikit-learn | ||
| tensorflow | ||
| tensorflow | ||
| torchmetrics | ||
| torch-fidelity | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,20 @@ | |
|
|
||
| """A collection of different metrics for image models.""" | ||
|
|
||
| import jax.numpy as jnp | ||
| from jax import random | ||
| import flax | ||
| import jax | ||
| from jax import lax | ||
| import jax.numpy as jnp | ||
| from clu import metrics as clu_metrics | ||
| from metrax import base | ||
| import numpy as np | ||
| from PIL import Image | ||
|
|
||
# Default hyper-parameters for the Kernel Inception Distance (KID) metric.
# Number of subsets to use for the KID calculation.
KID_DEFAULT_SUBSETS = 100
# Number of samples in each subset.
KID_DEFAULT_SUBSET_SIZE = 1000
# Degree of the polynomial kernel.
KID_DEFAULT_DEGREE = 3
# Kernel coefficient; None makes polynomial_kernel use 1 / x.shape[1].
KID_DEFAULT_GAMMA = None
# Independent term of the polynomial kernel.
KID_DEFAULT_COEF = 1.0
|
|
||
|
|
||
| def _gaussian_kernel1d(sigma, radius): | ||
|
|
@@ -53,6 +62,181 @@ def _gaussian_kernel1d(sigma, radius): | |
| return phi_x | ||
|
|
||
|
|
||
| def polynomial_kernel(x: jax.Array, y: jax.Array, degree: int, gamma: float, coef: float) -> jax.Array: | ||
|
dhruvmalik007 marked this conversation as resolved.
Outdated
|
||
| """ | ||
| Compute the polynomial kernel between two sets of features. | ||
| Args: | ||
| x: First set of features. | ||
| y: Another set of features to be computed with. | ||
| degree: Degree of the polynomial kernel. | ||
| gamma: Kernel coefficient for the polynomial kernel. If None, uses 1 / x.shape[1]. | ||
| coef: Independent term in the polynomial kernel. | ||
| Returns: | ||
| Polynomial kernel value of Array type. | ||
| """ | ||
| if gamma is None: | ||
| gamma = 1.0 / x.shape[1] | ||
| return (jnp.dot(x, y.T) * gamma + coef) ** degree | ||
|
|
||
|
|
||
def random_images(seed: int, n: int, image_size: int = 299) -> np.ndarray:
    """Generate ``n`` random RGB images in channel-first layout.

    Each image is drawn independently, in order, from a
    ``np.random.RandomState(seed)`` stream, so the output for a given seed is
    reproducible.

    Args:
        seed: Random seed for reproducibility.
        n: Number of images to generate. ``0`` yields an empty batch.
        image_size: Height and width of each square image. Defaults to 299.

    Returns:
        A uint8 numpy array of shape ``(n, 3, image_size, image_size)``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative.")
    rng = np.random.RandomState(seed)
    # Pre-allocating the batch keeps n == 0 well-defined; stacking an empty
    # list of arrays would raise instead.
    images = np.empty((n, 3, image_size, image_size), dtype=np.uint8)
    for index in range(n):
        # One draw per image in (H, W, C) layout. A uint8 RGB array passes
        # through a PIL image unchanged, so no PIL round-trip is needed.
        pixels = rng.randint(
            0, 256, size=(image_size, image_size, 3), dtype=np.uint8
        )
        # Reorder (H, W, C) -> (C, H, W).
        images[index] = pixels.transpose(2, 0, 1)
    return images
|
|
||
|
|
||
@flax.struct.dataclass
class KernelInceptionDistance(base.Average):
    r"""Computes the Kernel Inception Distance (KID) to assess generated images.

    KID evaluates the quality of generated images by comparing the
    distribution of features extracted from generated images to the
    distribution of features extracted from real images. It uses a kernelized
    version of the Maximum Mean Discrepancy (MMD) to measure the distance
    between the two distributions.

    The KID is computed as follows:

    .. math::
        KID = MMD(f_{real}, f_{fake})^2

    where :math:`MMD` is the maximum mean discrepancy and
    :math:`f_{real}, f_{fake}` are features extracted from real and fake
    images. Calculating the MMD requires the evaluation of a polynomial
    kernel function :math:`k`:

    .. math::
        k(x,y) = (\gamma * x^T y + coef)^{degree}

    Args:
        subsets: Number of subsets to use for KID calculation.
        subset_size: Number of samples in each subset.
        degree: Degree of the polynomial kernel.
        gamma: Kernel coefficient for the polynomial kernel. ``None`` defers
            to the default chosen by ``polynomial_kernel``.
        coef: Independent term in the polynomial kernel.
    """

    # NOTE(review): a collaborator asked whether these member variables could
    # be dropped in favor of default argument values on `from_model_output`;
    # the author agreed. They are kept here so the constructor interface is
    # unchanged.
    subsets: int = KID_DEFAULT_SUBSETS
    subset_size: int = KID_DEFAULT_SUBSET_SIZE
    degree: int = KID_DEFAULT_DEGREE
    gamma: float | None = KID_DEFAULT_GAMMA
    coef: float = KID_DEFAULT_COEF

    @classmethod
    def from_model_output(
        cls,
        real_features: jax.Array,
        fake_features: jax.Array,
        subsets: int = KID_DEFAULT_SUBSETS,
        subset_size: int = KID_DEFAULT_SUBSET_SIZE,
        degree: int = KID_DEFAULT_DEGREE,
        gamma: float | None = KID_DEFAULT_GAMMA,
        coef: float = KID_DEFAULT_COEF,
    ) -> "KernelInceptionDistance":
        """Build a metric instance holding the mean KID over random subsets.

        For each of ``subsets`` rounds, ``subset_size`` rows are sampled
        without replacement from each feature array, the MMD estimate between
        the two samples is computed, and the estimates are averaged.

        Args:
            real_features: Features of real images; rows are sampled along
                the first axis.
            fake_features: Features of generated images; rows are sampled
                along the first axis.
            subsets: Number of subsets to use for KID calculation.
            subset_size: Number of samples in each subset.
            degree: Degree of the polynomial kernel.
            gamma: Kernel coefficient for the polynomial kernel, or None.
            coef: Independent term in the polynomial kernel.

        Returns:
            An instance with ``total`` set to the mean KID and ``count`` 1.0.

        Raises:
            ValueError: If a parameter is not positive, if ``subset_size`` is
                below 2, or if ``subset_size`` exceeds the number of rows in
                either feature array.
        """
        if (
            subsets <= 0
            or subset_size <= 0
            or degree <= 0
            or (gamma is not None and gamma <= 0)
            or coef <= 0
        ):
            raise ValueError("All parameters must be positive and non-zero.")
        # The MMD estimate divides by m * (m - 1); a single-sample subset
        # would make that denominator zero.
        if subset_size < 2:
            raise ValueError("Subset size must be at least 2.")
        num_real = real_features.shape[0]
        num_fake = fake_features.shape[0]
        if num_real < subset_size or num_fake < subset_size:
            raise ValueError(
                "Subset size must not exceed the number of samples."
            )

        # Fixed key: the subset selection is deterministic across calls.
        master_key = jax.random.PRNGKey(42)
        kid_scores = []
        for i in range(subsets):
            # Derive independent keys per round for the real and fake draws.
            key_real, key_fake = jax.random.split(
                jax.random.fold_in(master_key, i)
            )
            real_indices = jax.random.choice(
                key_real, num_real, (subset_size,), replace=False
            )
            fake_indices = jax.random.choice(
                key_fake, num_fake, (subset_size,), replace=False
            )
            kid_scores.append(
                cls.__compute_mmd_static(
                    real_features[real_indices],
                    fake_features[fake_indices],
                    degree,
                    gamma,
                    coef,
                )
            )
        kid_mean = jnp.mean(jnp.stack(kid_scores))
        # Stored as a (total, count) pair so instances can be merged and
        # averaged later.
        return cls(
            total=kid_mean,
            count=1.0,
            subsets=subsets,
            subset_size=subset_size,
            degree=degree,
            gamma=gamma,
            coef=coef,
        )

    # NOTE(review): a collaborator noted that, since this class inherits
    # base.Average, `empty` should not need to be implemented here.
    @classmethod
    def empty(cls) -> "KernelInceptionDistance":
        """Create an empty instance with zero total/count and default settings."""
        return cls(
            total=0.0,
            count=0.0,
            subsets=KID_DEFAULT_SUBSETS,
            subset_size=KID_DEFAULT_SUBSET_SIZE,
            degree=KID_DEFAULT_DEGREE,
            gamma=KID_DEFAULT_GAMMA,
            coef=KID_DEFAULT_COEF,
        )

    @staticmethod
    def __compute_mmd_static(
        f_real: jax.Array,
        f_fake: jax.Array,
        degree: int,
        gamma: float | None,
        coef: float,
    ) -> jax.Array:
        """Compute the MMD estimate between two equally sized feature samples.

        Args:
            f_real: Sampled real features; its first-axis length ``m`` is used
                for normalization.
            f_fake: Sampled fake features.
            degree: Degree of the polynomial kernel.
            gamma: Kernel coefficient for the polynomial kernel, or None.
            coef: Independent term in the polynomial kernel.

        Returns:
            The MMD estimate as a scalar array.
        """
        k_11 = polynomial_kernel(f_real, f_real, degree, gamma, coef)
        k_22 = polynomial_kernel(f_fake, f_fake, degree, gamma, coef)
        k_12 = polynomial_kernel(f_real, f_fake, degree, gamma, coef)

        m = f_real.shape[0]

        # Row sums of the within-sample kernels with the diagonal
        # (self-similarity) terms removed.
        kt_xx_sum = jnp.sum(k_11, axis=-1) - jnp.diag(k_11)
        kt_yy_sum = jnp.sum(k_22, axis=-1) - jnp.diag(k_22)
        k_xy_sum = jnp.sum(k_12, axis=0)

        value = (jnp.sum(kt_xx_sum) + jnp.sum(kt_yy_sum)) / (m * (m - 1))
        value -= 2 * jnp.sum(k_xy_sum) / (m**2)
        return value

    # NOTE(review): a collaborator noted that, since this class inherits
    # base.Average, `compute` and `merge` should not need to be implemented
    # here.
    def compute(self) -> jax.Array:
        """Compute the average KID value from the accumulated total and count.

        The array is returned as-is rather than converted with ``float()``,
        so the method also works on traced values.
        """
        return base.divide_no_nan(self.total, self.count)

    def merge(
        self, other: "KernelInceptionDistance"
    ) -> "KernelInceptionDistance":
        """Merge two instances by summing totals and counts.

        The kernel and subset settings of ``self`` are carried over; those of
        ``other`` are ignored.
        """
        return type(self)(
            total=self.total + other.total,
            count=self.count + other.count,
            subsets=self.subsets,
            subset_size=self.subset_size,
            degree=self.degree,
            gamma=self.gamma,
            coef=self.coef,
        )
|
|
||
|
|
||
| @flax.struct.dataclass | ||
| class SSIM(base.Average): | ||
| r"""SSIM (Structural Similarity Index Measure) Metric. | ||
|
|
@@ -360,3 +544,4 @@ def from_model_output( # type: ignore[override] | |
| k2=k2, | ||
| ) | ||
| return super().from_model_output(values=batch_ssim_values) | ||
|
|
||
|
jshin1394 marked this conversation as resolved.
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove this line, since image_metrics was already added on line 17 as part of a previous PR.