From c585e5251595027575fecab2f1eac86f7fed1d67 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Tue, 26 May 2026 08:30:24 -0500 Subject: [PATCH] Add sepreformer example --- docs_nnx/guides/audio.ipynb | 801 +++++++++++++++++++++++++++++++++ docs_nnx/guides/audio.md | 579 ++++++++++++++++++++++++ flax/nnx/__init__.py | 2 + flax/nnx/nn/attention.py | 100 ++++ tests/nnx/nn/attention_test.py | 36 ++ 5 files changed, 1518 insertions(+) create mode 100644 docs_nnx/guides/audio.ipynb create mode 100644 docs_nnx/guides/audio.md diff --git a/docs_nnx/guides/audio.ipynb b/docs_nnx/guides/audio.ipynb new file mode 100644 index 000000000..8adc1301f --- /dev/null +++ b/docs_nnx/guides/audio.ipynb @@ -0,0 +1,801 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1794712e", + "metadata": {}, + "source": [ + "# Building a Choral Source Separator with SepReformer in JAX\n", + "\n", + "This tutorial demonstrates how to perform audio source separation using the `SepReformer`\n", + "architecture. We'll use `flax.nnx` for the neural network, with `beartype` for runtime type checking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5b2848", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import nnx\n", + "from jaxtyping import Float, Array, jaxtyped\n", + "from beartype import beartype\n", + "import soundfile as sf" + ] + }, + { + "cell_type": "markdown", + "id": "e011d75a", + "metadata": {}, + "source": [ + "## The Task\n", + "\n", + "Given a mono mixture waveform $x \\in \\mathbb{R}^T$, produce $N$ separated stem\n", + "waveforms $\\hat{s}_1, \\ldots, \\hat{s}_N \\in \\mathbb{R}^T$ such that\n", + "$\\sum_n \\hat{s}_n \\approx x$ and each $\\hat{s}_n$ matches one isolated voice\n", + "track.\n", + "\n", + "We use the **JaCappella** corpus (35 a cappella songs) via Hugging Face. 
Each\n", + "song has 5 isolated stems — `lead_vocal`, `soprano`, `alto`, `tenor`, `bass` —\n", + "and the mixture is their sum." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e2c106f", + "metadata": {}, + "outputs": [], + "source": [ + "import librosa\n", + "from pathlib import Path\n", + "\n", + "STEM_NAMES = (\"lead_vocal\", \"soprano\", \"alto\", \"tenor\", \"bass\")\n", + "SAMPLE_RATE = 44100\n", + "\n", + "class JaCappellaDataset:\n", + " def __init__(self, root) -> None:\n", + " self.root = Path(root)\n", + " self.sample_rate = 44100\n", + " self.songs = sorted(\n", + " song\n", + " for genre in self.root.iterdir()\n", + " if genre.is_dir() and not genre.name.startswith(\".\")\n", + " for song in genre.iterdir()\n", + " if song.is_dir() and not song.name.startswith(\".\"))\n", + "\n", + " @jaxtyped(typechecker=beartype)\n", + " def _load_wav(self, path: Path) -> Float[np.ndarray, \"T\"]:\n", + " \"\"\"Load a wav file, resample if needed, return mono float32.\"\"\"\n", + " audio, sr = sf.read(path, dtype=\"float32\", always_2d=True)\n", + " audio = audio[:, 0] # take first channel if stereo\n", + " if sr != self.sample_rate:\n", + " audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)\n", + " return audio\n", + "\n", + " @jaxtyped(typechecker=beartype)\n", + " def _load_stems(self, song_dir: Path) -> Float[np.ndarray, \"N T\"]:\n", + " \"\"\"Load all available stems for a song.\"\"\"\n", + " stems = []\n", + " for i, name in enumerate(STEM_NAMES):\n", + " path = song_dir / f\"{name}.wav\"\n", + " stems.append(self._load_wav(path) if path.exists() else np.zeros(0, dtype=np.float32))\n", + " max_len = max(len(s) for s in stems)\n", + " result = np.zeros((5, max_len), dtype=np.float32)\n", + " for i, s in enumerate(stems):\n", + " result[i, :len(s)] = s\n", + " return result\n", + "\n", + " @jaxtyped(typechecker=beartype)\n", + " def __getitem__(self, idx: int) -> tuple[Float[np.ndarray, \"T\"], Float[np.ndarray, \"N 
T\"]]:\n", + " \"\"\"Load full song.\"\"\"\n", + " song_dir = self.songs[idx]\n", + " stems = self._load_stems(song_dir)\n", + " return stems.sum(axis=0), stems\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.songs)\n", + "\n", + "dataset = JaCappellaDataset(\"/space/samanklesaria/data/jacappella\")" + ] + }, + { + "cell_type": "markdown", + "id": "a811807f", + "metadata": {}, + "source": [ + "We'll feed the dataset to our model using the `grain` library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c66dea25", + "metadata": {}, + "outputs": [], + "source": [ + "import grain\n", + "\n", + "SEG_SAMPLES = SAMPLE_RATE * 2 # 2-second segments\n", + "BATCH_SIZE = 1\n", + "\n", + "def extract_segment(item, rng): # crop (or zero-pad) one random SEG_SAMPLES-long window\n", + " mixture, stems = item\n", + " T = mixture.shape[0]\n", + " if T <= SEG_SAMPLES:\n", + " pad = SEG_SAMPLES - T\n", + " return np.pad(mixture, (0, pad)), np.pad(stems, ((0, 0), (0, pad)))\n", + " start = rng.integers(0, T - SEG_SAMPLES + 1) # high is exclusive; +1 keeps the last window reachable\n", + " return mixture[start:start + SEG_SAMPLES], stems[:, start:start + SEG_SAMPLES]\n", + "\n", + "def batch_to_jax(items):\n", + " mixtures = jnp.array(np.stack([m for m, _ in items])) # (B, T)\n", + " stems = jnp.array(np.stack([s for _, s in items])) # (B, N, T)\n", + " return mixtures, stems\n", + "\n", + "loader = (grain.MapDataset.source(dataset)\n", + " .seed(0).shuffle()\n", + " .random_map(extract_segment)\n", + " .batch(BATCH_SIZE, drop_remainder=True, batch_fn=batch_to_jax))" + ] + }, + { + "cell_type": "markdown", + "id": "0200fb2e", + "metadata": {}, + "source": [ + "To make sure the data is loading correctly, we can sample a batch and log it to Tensorboard." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62cdbc65", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorboardX import SummaryWriter\n", + "\n", + "mixture, stems = next(iter(loader))\n", + "mixture = np.array(mixture[0]) # (T,)\n", + "stems = np.array(stems[0]) # (N, T)\n", + "writer = SummaryWriter(\"samples\")\n", + "peak = np.max(np.abs(mixture))\n", + "scale = 0.99 / peak if peak > 0 else 1.0\n", + "writer.add_audio(\n", + " f\"mixture\",\n", + " mixture * scale,\n", + " sample_rate=dataset.sample_rate)\n", + "for n in range(stems.shape[0]):\n", + " writer.add_audio(\n", + " f\"stem/{n}\",\n", + " stems[n] * scale,\n", + " sample_rate=dataset.sample_rate)\n", + "writer.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0588c381", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "display(HTML(open(\"mixture.html\").read()))\n", + "display(HTML(open(\"stem.html\").read()))" + ] + }, + { + "cell_type": "markdown", + "id": "01d25e37", + "metadata": {}, + "source": [ + "## Architecture" + ] + }, + { + "cell_type": "markdown", + "id": "81382b61", + "metadata": {}, + "source": [ + "At a high level:\n", + "- The input waveform is encoded to a latent space using convolutional and transformer layers.\n", + "- The result gets split into separate pieces for each voice part\n", + "- Each piece is decoded back to a waveform by the same stack of transformer and convolutional layers.\n", + "\n", + "$$x \\xrightarrow{\\text{Conv}} h \\xrightarrow{\\text{Enc blocks}} h \\xrightarrow{\\text{Split}} \\{h_n\\} \\xrightarrow{\\text{Dec blocks}} \\xrightarrow{\\text{ConvT}} \\{\\hat{s}_n\\}$$\n", + "\n", + "## Convolutional Layers: Waveform → Latent Frames and Back\n", + "\n", + "A strided 1-D convolution converts the raw waveform into a sequence of latent\n", + "frames. 
With kernel $K$ and stride $S$:\n", + "\n", + "$$L = \\left\\lfloor \\frac{T - K}{S} \\right\\rfloor + 1$$\n", + "\n", + "The encoder applies a GELU after the convolution:\n", + "\n", + "$$h = \\text{GELU}(W_\\text{enc} * x), \\quad h \\in \\mathbb{R}^{L \\times C}$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3bcf612", + "metadata": {}, + "outputs": [], + "source": [ + "class Encoder(nnx.Module):\n", + " def __init__(self, out_channels: int, *, rngs: nnx.Rngs):\n", + " self.conv = nnx.Conv(1, out_channels, kernel_size=(16,), strides=(8,),\n", + " padding='VALID', rngs=rngs)\n", + "\n", + " def __call__(self, x: Float[Array, \"B T\"]) -> Float[Array, \"B L C\"]:\n", + " h = self.conv(x[..., None]) # (B, T, 1) -> (B, L, C)\n", + " return jax.nn.gelu(h)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "456c95e1", + "metadata": {}, + "outputs": [], + "source": [ + "class Decoder(nnx.Module):\n", + " def __init__(self, in_channels: int, *, rngs: nnx.Rngs):\n", + " self.conv_t = nnx.ConvTranspose(in_channels, 1, kernel_size=(16,), strides=(8,),\n", + " padding='VALID', rngs=rngs)\n", + "\n", + " def __call__(self, h: Float[Array, \"B L C\"]) -> Float[Array, \"B T\"]:\n", + " out = self.conv_t(h) # (B, L, C) -> (B, T, 1)\n", + " return out[..., 0] # (B, T)" + ] + }, + { + "cell_type": "markdown", + "id": "e1310de1", + "metadata": {}, + "source": [ + "Default: $K=16$, $S=8$, $C=256$. At 44.1 kHz a 2-second clip becomes\n", + "$L \\approx 11{,}025$ frames." + ] + }, + { + "cell_type": "markdown", + "id": "908b4bd2", + "metadata": {}, + "source": [ + "## SNAKE Activation\n", + "\n", + "**SNAKE** (Liu et al., 2022) adds a learnable sinusoidal term that preserves the\n", + "periodic structure present in harmonic audio:\n", + "\n", + "$$\\text{SNAKE}(x; \\alpha) = x + \\frac{1}{\\alpha}\\sin^2(\\alpha x)$$\n", + "\n", + "$\\alpha$ is initialized to $\\mathbf{1}$ (small perturbation at startup) and\n", + "learned per channel. 
SNAKE is used inside every feed-forward sub-layer in the\n", + "transformer blocks. Because $\\alpha$ broadcasts along all leading dimensions,\n", + "`Snake` works with any number of batch axes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e3f013", + "metadata": {}, + "outputs": [], + "source": [ + "class Snake(nnx.Module):\n", + " def __init__(self, features: int, *, rngs: nnx.Rngs):\n", + " self.alpha = nnx.Param(jnp.ones(features))\n", + "\n", + " def __call__(self, x: Float[Array, \"B F\"]) -> Float[Array, \"B F\"]:\n", + " a = self.alpha.value\n", + " return x + (1.0 / (a + 1e-6)) * jnp.sin(a * x) ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fe5f041", + "metadata": {}, + "outputs": [], + "source": [ + "def feedforward(dim, ff_dim, rngs: nnx.Rngs):\n", + " return nnx.Sequential(\n", + " nnx.Linear(dim, ff_dim, rngs=rngs),\n", + " Snake(ff_dim, rngs=rngs),\n", + " nnx.Linear(ff_dim, dim, rngs=rngs))" + ] + }, + { + "cell_type": "markdown", + "id": "bcb20477", + "metadata": {}, + "source": [ + "## Rotary Positional Embeddings\n", + "\n", + "Rotary positional embeddings (RoPE) encode position by rotating pairs of\n", + "features through position-dependent angles. This gives the transformer\n", + "translation-equivariant relative position information without adding\n", + "explicit position tokens.\n", + "\n", + "Flax provides `nnx.RoPE`, which precomputes cosine and sine frequency\n", + "tables once and stores them as module state. 
To use it with\n", + "`nnx.MultiHeadAttention`, pass `nnx.dot_product_attention_with_rope`\n", + "(with the `rope` argument partially applied) as the `attention_fn`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1c966a8", + "metadata": {}, + "outputs": [], + "source": [ + "import functools" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ed30a18", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerBlock(nnx.Module):\n", + " def __init__(\n", + " self, dim: int, num_heads: int, ff_dim: int,\n", + " max_seq_len: int = 2048, *, rngs: nnx.Rngs\n", + " ):\n", + " self.norm1 = nnx.LayerNorm(dim, rngs=rngs)\n", + " self.norm2 = nnx.LayerNorm(dim, rngs=rngs)\n", + " head_dim = dim // num_heads\n", + " rope = nnx.RoPE(embedding_size=head_dim, max_seq_len=max_seq_len)\n", + " self.attn = nnx.MultiHeadAttention(\n", + " num_heads=num_heads,\n", + " in_features=dim,\n", + " qkv_features=dim,\n", + " attention_fn=functools.partial(\n", + " nnx.dot_product_attention_with_rope, rope=rope),\n", + " decode=False,\n", + " rngs=rngs,\n", + " )\n", + " self.ff = feedforward(dim, ff_dim, rngs=rngs)\n", + " self.scale1 = nnx.Param(jnp.full(dim, 1e-4))\n", + " self.scale2 = nnx.Param(jnp.full(dim, 1e-4))\n", + "\n", + " def __call__(self, x: Float[Array, \"B S D\"]) -> Float[Array, \"B S D\"]:\n", + " normed = self.norm1(x)\n", + " attn_out = self.attn(normed)\n", + " x = x + self.scale1[...] * attn_out\n", + " x = x + self.scale2[...] * self.ff(self.norm2(x))\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "68ded67a", + "metadata": {}, + "source": [ + "## Stacking Transformer Layers: The Dual-Path Approach\n", + "\n", + "Full self-attention over $L \\approx 11{,}000$ frames costs $O(L^2)$. The\n", + "dual-path trick (Luo & Mesgarani, 2020) splits this into two $O(L \\cdot K)$\n", + "passes:\n", + "\n", + "1. 
**Intra-chunk** — reshape to $(\\ldots, M, K, C)$; each of the $M$ chunks\n", + " attends within itself. Captures local patterns. Cost: $O(M \\cdot K^2)$.\n", + "2. **Inter-chunk** — swap to $(\\ldots, K, M, C)$; each time-slot attends\n", + " across all $M$ chunks. Propagates global pitch/rhythm. Cost: $O(K \\cdot M^2)$.\n", + "\n", + "Because `TransformerBlock` accepts `(B, S, D)`, the extra chunk axis\n", + "becomes just another batch dimension — no explicit `vmap` is needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64b35bbd", + "metadata": {}, + "outputs": [], + "source": [ + "class DualPathBlock(nnx.Module):\n", + " def __init__(\n", + " self,\n", + " dim: int,\n", + " num_heads: int,\n", + " ff_dim: int,\n", + " chunk_size: int = 64,\n", + " *,\n", + " rngs: nnx.Rngs,\n", + " ):\n", + " self.intra_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)\n", + " self.inter_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)\n", + " self.chunk_size = chunk_size\n", + "\n", + " def __call__(self, x: Float[Array, \"B L C\"]) -> Float[Array, \"B L C\"]:\n", + " *batch_shape, L, C = x.shape # B may be several leading batch axes, e.g. (B, N)\n", + " K = self.chunk_size\n", + "\n", + " # Pad L up to a multiple of chunk_size so it splits into whole chunks\n", + " pad_len = (K - L % K) % K\n", + " if pad_len > 0:\n", + " x = jnp.pad(x, [(0, 0)] * len(batch_shape) + [(0, pad_len), (0, 0)])\n", + "\n", + " L_padded = x.shape[-2]\n", + " M = L_padded // K\n", + "\n", + " # (..., L, C) -> (..., M, K, C)\n", + " chunks = x.reshape(*batch_shape, M, K, C)\n", + "\n", + " # Intra-chunk: TransformerBlock sees (..., M) as batch, K as seq\n", + " chunks = self.intra_block(chunks)\n", + "\n", + " # Inter-chunk: swap M <-> K, attend, swap back\n", + " inter = jnp.swapaxes(chunks, -3, -2) # (..., K, M, C)\n", + " inter = self.inter_block(inter)\n", + " chunks = jnp.swapaxes(inter, -3, -2) # (..., M, K, C)\n", + "\n", + " out = chunks.reshape(*batch_shape, L_padded, C)\n", + " return out[..., :L, :] # drop the padding added above" + ] + }, + { + "cell_type": "markdown", + "id": 
"8cbd5f34", + "metadata": {}, + "source": [ + "## Splitting into Speaker Streams\n", + "\n", + "After the shared encoder blocks, a `SplitLayer` expands $(\\ldots, L, C)$ into\n", + "$(\\ldots, N, L, C)$. Splitting here lets each of the $N$ reconstruction stacks\n", + "specialize on one speaker while sharing parameters — the subsequent\n", + "`DualPathBlock` and `Decoder` layers simply treat the new $N$ axis as an\n", + "additional batch dimension.\n", + "\n", + "A GLU gate first refines the shared features before expanding:\n", + "\n", + "$$g, v = \\text{split}(W_1 h), \\quad h' = \\sigma(g) \\odot v, \\quad \\text{streams} = W_2 h' \\;\\text{reshaped to}\\; (\\ldots, N, L, C)$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47c83d2e", + "metadata": {}, + "outputs": [], + "source": [ + "class SplitLayer(nnx.Module):\n", + " def __init__(self, dim: int, num_stems: int, *, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(dim, dim * 2, rngs=rngs)\n", + " self.linear2 = nnx.Linear(dim, dim * num_stems, rngs=rngs)\n", + " self.num_stems = num_stems\n", + "\n", + " def __call__(self, x: Float[Array, \"B L C\"]) -> Float[Array, \"B N L C\"]:\n", + " h = self.linear1(x) # (B, L, 2C)\n", + " gate, val = jnp.split(h, 2, axis=-1)\n", + " h = jax.nn.sigmoid(gate) * val # (B, L, C)\n", + " h = self.linear2(h) # (B, L, N*C)\n", + " B_shape, L_dim, _ = h.shape\n", + " C = x.shape[-1]\n", + " h = h.reshape(B_shape, L_dim, self.num_stems, C) # (B, L, N, C)\n", + " return jnp.swapaxes(h, -3, -2) # (B, N, L, C)" + ] + }, + { + "cell_type": "markdown", + "id": "74259e0f", + "metadata": {}, + "source": [ + "## Full Forward Pass\n", + "\n", + "After `SplitLayer` produces $(\\ldots, N, L, C)$, the reconstruction blocks\n", + "and decoder see $(B, N)$ as batch dimensions. The entire forward pass\n", + "runs without any explicit `vmap`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "924e4396", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "class SepReformer(nnx.Module):\n", + " def __init__(self, *, rngs: nnx.Rngs):\n", + " num_sep_blocks = 2\n", + " num_rec_blocks = 2\n", + " dim = 256\n", + " num_heads = 8\n", + " ff_dim = 1024\n", + " chunk_size = 64\n", + "\n", + " self.encoder = Encoder(dim, rngs=rngs)\n", + " self.decoder = Decoder(dim, rngs=rngs)\n", + " self.split = SplitLayer(dim, 5, rngs=rngs)\n", + " self.sep_blocks = [\n", + " DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)\n", + " for _ in range(num_sep_blocks)\n", + " ]\n", + " self.rec_blocks = [\n", + " DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)\n", + " for _ in range(num_rec_blocks)\n", + " ]\n", + "\n", + " def __call__(self, x: Float[Array, \"B T\"]) -> Float[Array, \"B N T\"]:\n", + " h = self.encoder(x) # (B, L, C)\n", + " for block in self.sep_blocks:\n", + " h = block(h) # (B, L, C)\n", + " stems = self.split(h) # (B, N, L, C)\n", + " for block in self.rec_blocks:\n", + " stems = block(stems) # (B, N, L, C) — N is a batch dim\n", + " out = self.decoder(stems) # (B, N, T')\n", + " # trim / pad to original length\n", + " T = x.shape[-1]\n", + " if out.shape[-1] > T:\n", + " out = out[..., :T]\n", + " elif out.shape[-1] < T:\n", + " pad_width = [(0, 0)] * (out.ndim - 1) + [(0, T - out.shape[-1])]\n", + " out = jnp.pad(out, pad_width)\n", + " return out # (B, N, T)\n", + "\n", + "model = SepReformer(rngs=nnx.Rngs(0))" + ] + }, + { + "cell_type": "markdown", + "id": "ceb7d8e1", + "metadata": {}, + "source": [ + "## Loss Functions\n", + "\n", + "Supervising a source separator is not straightforward. 
A plain mean-squared\n", + "error (MSE) in the waveform domain penalises tiny timing offsets and global\n", + "loudness differences equally, so the model spends capacity chasing irrelevant\n", + "phase shifts rather than learning to separate voices. We instead use two\n", + "complementary objectives — one waveform-domain and one spectral — that together\n", + "give stable, perceptually meaningful gradients.\n", + "\n", + "### SI-SDR\n", + "\n", + "SI-SDR projects the estimate onto the target and reports the energy ratio in dB.\n", + "It is invariant to global loudness, which matters for a cappella where voices\n", + "differ widely in level:\n", + "\n", + "$$\\hat{s}_\\text{tgt} = \\frac{\\langle \\hat{s}, s \\rangle}{\\|s\\|^2} s, \\qquad \\text{SI-SDR} = 10\\log_{10}\\frac{\\|\\hat{s}_\\text{tgt}\\|^2}{\\|\\hat{s} - \\hat{s}_\\text{tgt}\\|^2}$$\n", + "\n", + "Both signals are mean-centred first to remove any DC offset, and the projection cancels any gain applied to the estimate, so a\n", + "perfectly separated signal that is merely scaled up or down still scores the\n", + "maximum possible value. In practice, SI-SDR values above $+10\\ \\text{dB}$\n", + "indicate clearly separated sources; below $0\\ \\text{dB}$ the estimate is\n", + "dominated by leakage from other voices. We negate it to turn maximisation into\n", + "minimisation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8724e233", + "metadata": {}, + "outputs": [], + "source": [ + "@jaxtyped(typechecker=beartype)\n", + "def si_sdr(estimate: Float[Array, \"T\"], target: Float[Array, \"T\"], eps: float = 1e-8) -> Float[Array, \"\"]:\n", + " estimate = estimate - jnp.mean(estimate)\n", + " target = target - jnp.mean(target)\n", + " dot = jnp.sum(estimate * target)\n", + " s_target = (dot / (jnp.sum(target ** 2) + eps)) * target\n", + " e_noise = estimate - s_target\n", + " return 10.0 * jnp.log10(jnp.sum(s_target ** 2) / (jnp.sum(e_noise ** 2) + eps) + eps)" + ] + }, + { + "cell_type": "markdown", + "id": "138a4a3a", + "metadata": {}, + "source": [ + "### Multi-Resolution STFT Loss\n", + "\n", + "SI-SDR is blind to spectral texture: two signals can have the same SI-SDR yet\n", + "sound very different if one has unnatural resonances or missing harmonics.\n", + "Adding a frequency-domain term at three FFT scales $\\{512, 1024, 2048\\}$\n", + "addresses this at multiple time-frequency resolutions simultaneously.\n", + "\n", + "A small FFT ($512$) gives sharp time resolution — useful for detecting onset\n", + "smearing — while a large FFT ($2048$) gives fine frequency resolution — useful\n", + "for resolving individual harmonics in a choir. Using all three averages out\n", + "the inherent time-frequency tradeoff of any single STFT.\n", + "\n", + "Each scale contributes two terms:\n", + "\n", + "- **Spectral convergence** — the Frobenius-norm distance between magnitude\n", + " spectrograms, normalised by the target energy. 
This drives the gross shape\n", + " of the spectrum towards the reference.\n", + "- **Log-magnitude distance** — the mean absolute difference on a log scale.\n", + " Because human loudness perception is roughly logarithmic, this term penalises errors in\n", + " quiet harmonics just as strongly as errors in loud ones.\n", + "\n", + "$$\\mathcal{L}_\\text{STFT} = \\frac{1}{3}\\sum_\\text{scale}\\left(\\underbrace{\\frac{\\||S| - |\\hat{S}|\\|_F}{\\||S|\\|_F}}_{\\text{spectral convergence}} + \\underbrace{\\text{mean}|\\log|S| - \\log|\\hat{S}||}_{\\text{log-magnitude}}\\right)$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c008bc7", + "metadata": {}, + "outputs": [], + "source": [ + "@jaxtyped(typechecker=beartype)\n", + "def stft_mag(x: Float[Array, \"T\"], fft_size: int, hop: int, win_size: int) -> Float[Array, \"F K\"]:\n", + " window = jnp.hanning(win_size)\n", + " x_pad = jnp.pad(x, (fft_size // 2, fft_size // 2))\n", + " n_frames = (len(x_pad) - win_size) // hop + 1\n", + " idx = jnp.arange(win_size)[None, :] + jnp.arange(n_frames)[:, None] * hop\n", + " frames = x_pad[idx] * window\n", + " return jnp.abs(jnp.fft.rfft(frames, n=fft_size, axis=-1)).T # (F, K)\n", + "\n", + "@jaxtyped(typechecker=beartype)\n", + "def stft_loss_single(est: Float[Array, \"T\"], tgt: Float[Array, \"T\"], fft_size: int, hop: int, win: int) -> Float[Array, \"\"]:\n", + " em, tm = stft_mag(est, fft_size, hop, win), stft_mag(tgt, fft_size, hop, win)\n", + " sc = jnp.linalg.norm(tm - em) / (jnp.linalg.norm(tm) + 1e-8)\n", + " lm = jnp.mean(jnp.abs(jnp.log(em + 1e-8) - jnp.log(tm + 1e-8)))\n", + " return sc + lm\n", + "\n", + "@jaxtyped(typechecker=beartype)\n", + "def mr_stft_loss(est: Float[Array, \"T\"], tgt: Float[Array, \"T\"]) -> Float[Array, \"\"]:\n", + " scales = [(512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)]\n", + " return sum(stft_loss_single(est, tgt, *s) for s in scales) / len(scales)" + ] + }, + { + "cell_type": "markdown", + "id": "4d9eccac", + 
"metadata": {}, + "source": [ + "### Composite Loss\n", + "\n", + "The final objective combines the two terms, with the STFT loss weighted at\n", + "$0.5$ so that SI-SDR — which operates directly in the waveform domain and\n", + "carries the strongest perceptual signal — dominates early in training. The\n", + "STFT term then fills in spectral detail that SI-SDR cannot see. Both terms are\n", + "averaged across the $N$ stems before being averaged across the batch.\n", + "\n", + "Because the model already handles the batch dimension, `loss_fn` calls the\n", + "model once on the full `(B, T)` mixture and then vmaps the per-stem loss\n", + "over the batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "102061d5", + "metadata": {}, + "outputs": [], + "source": [ + "@jaxtyped(typechecker=beartype)\n", + "def composite_loss(estimates: Float[Array, \"N T\"], targets: Float[Array, \"N T\"]) -> Float[Array, \"\"]:\n", + " def pair(est, tgt):\n", + " return -si_sdr(est, tgt) + 0.5 * mr_stft_loss(est, tgt)\n", + " return jnp.mean(jax.vmap(pair)(estimates, targets))\n", + "\n", + "def loss_fn(model, mixture: Float[Array, \"B T\"], targets: Float[Array, \"B N T\"]) -> Float[Array, \"\"]:\n", + " estimates = model(mixture) # (B, N, T)\n", + " return jnp.mean(jax.vmap(composite_loss)(estimates, targets))" + ] + }, + { + "cell_type": "markdown", + "id": "949c4933", + "metadata": {}, + "source": [ + "## Overfitting on JaCappella\n", + "\n", + "Before training on the full corpus, we overfit on a single batch. This is a\n", + "fast sanity check: if the model cannot memorise even one example, something is\n", + "wrong with the architecture, the loss, or the data pipeline. It is much\n", + "cheaper to discover this now than after a multi-hour training run.\n", + "\n", + "### Audio Logging\n", + "\n", + "The loss curve tells you the model is learning, but it does not tell you *what*\n", + "it is learning. 
Listening to the actual estimates at checkpoints is\n", + "irreplaceable: you can hear immediately whether the model is separating voices,\n", + "producing silence, or emitting noise. `log_audio_samples` writes one batch of\n", + "audio to TensorBoard — the raw mixture, each ground-truth stem, and the\n", + "corresponding model estimate — all normalised to a peak of $0.99$ so playback\n", + "levels are comparable across steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "486bcea2", + "metadata": {}, + "outputs": [], + "source": [ + "def log_audio_samples(model, loader, writer, global_step):\n", + " mixture, stems = next(iter(loader))\n", + " mix_np = np.array(mixture[0])\n", + " stems_np = np.array(stems[0])\n", + " est_np = np.array(model(mixture[0:1])[0]) # keep batch dim, then index out\n", + " scale = 0.99 / (np.max(np.abs(mix_np)) + 1e-8)\n", + " writer.add_audio(\"mixture\", mix_np * scale, global_step, sample_rate=SAMPLE_RATE)\n", + " for n in range(stems_np.shape[0]):\n", + " writer.add_audio(f\"true/{n}\", stems_np[n] * scale, global_step, sample_rate=SAMPLE_RATE)\n", + " for n in range(est_np.shape[0]):\n", + " est_scale = 0.99 / (np.max(np.abs(est_np[n])) + 1e-8)\n", + " writer.add_audio(f\"estimate/{n}\", est_np[n] * est_scale, global_step, sample_rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "markdown", + "id": "3f8ecc15", + "metadata": {}, + "source": [ + "### Optimizer and Training Loop\n", + "\n", + "We use AdamW with a global gradient-norm clip of $1.0$. 
Clipping is important\n", + "here because early in training the split layer and decoder produce near-random\n", + "outputs, which can generate very large gradients through the SI-SDR loss.\n", + "Weight decay of $10^{-2}$ provides mild regularisation to prevent any single\n", + "stem stream from collapsing to zero.\n", + "\n", + "The loop runs for 200 epochs over the loader's first batch, logging the scalar loss\n", + "every epoch and writing a fresh set of audio samples every 20 epochs. You can\n", + "monitor progress in TensorBoard with `tensorboard --logdir runs/overfit`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "039c4a90", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def step(model, optimizer, mixture, targets):\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model, mixture, targets)\n", + " optimizer.update(model, grads)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77d260d7", + "metadata": {}, + "outputs": [], + "source": [ + "import optax\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(3e-4, weight_decay=1e-2)))\n", + "\n", + "writer = SummaryWriter(\"runs/overfit\")\n", + "for epoch in range(200):\n", + " for mixture, targets in loader[:1]:\n", + " loss = step(model, optimizer, mixture, targets)\n", + " writer.add_scalar(\"loss\", float(loss), epoch)\n", + " if epoch % 20 == 0:\n", + " log_audio_samples(model, loader, writer, epoch)\n", + "writer.close()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md", + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/audio.md b/docs_nnx/guides/audio.md new file mode 100644 index 000000000..f1de94400 --- /dev/null +++ b/docs_nnx/guides/audio.md @@ -0,0 +1,579 @@ +--- +jupyter: + jupytext: + cell_metadata_filter: -all + formats: ipynb,md + main_language: python + 
text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 +--- + +# Building a Choral Source Separator with SepReformer in JAX + +This tutorial demonstrates how to perform audio source separation using the `SepReformer` +architecture. We'll use `flax.nnx` for the neural network, with `beartype` for runtime type checking. + +```python +import numpy as np +import jax +import jax.numpy as jnp +from flax import nnx +from jaxtyping import Float, Array, jaxtyped +from beartype import beartype +import soundfile as sf +``` + +## The Task + +Given a mono mixture waveform $x \in \mathbb{R}^T$, produce $N$ separated stem +waveforms $\hat{s}_1, \ldots, \hat{s}_N \in \mathbb{R}^T$ such that +$\sum_n \hat{s}_n \approx x$ and each $\hat{s}_n$ matches one isolated voice +track. + +We use the **JaCappella** corpus (35 a cappella songs) via Hugging Face. Each +song has 5 isolated stems — `lead_vocal`, `soprano`, `alto`, `tenor`, `bass` — +and the mixture is their sum. 
+ +```python +import librosa +from pathlib import Path + +STEM_NAMES = ("lead_vocal", "soprano", "alto", "tenor", "bass") +SAMPLE_RATE = 44100 + +class JaCappellaDataset: + def __init__(self, root) -> None: + self.root = Path(root) + self.sample_rate = 44100 + self.songs = sorted( + song + for genre in self.root.iterdir() + if genre.is_dir() and not genre.name.startswith(".") + for song in genre.iterdir() + if song.is_dir() and not song.name.startswith(".")) + + @jaxtyped(typechecker=beartype) + def _load_wav(self, path: Path) -> Float[np.ndarray, "T"]: + """Load a wav file, resample if needed, return mono float32.""" + audio, sr = sf.read(path, dtype="float32", always_2d=True) + audio = audio[:, 0] # take first channel if stereo + if sr != self.sample_rate: + audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate) + return audio + + @jaxtyped(typechecker=beartype) + def _load_stems(self, song_dir: Path) -> Float[np.ndarray, "N T"]: + """Load all available stems for a song.""" + stems = [] + for i, name in enumerate(STEM_NAMES): + path = song_dir / f"{name}.wav" + stems.append(self._load_wav(path) if path.exists() else np.zeros(0, dtype=np.float32)) + max_len = max(len(s) for s in stems) + result = np.zeros((5, max_len), dtype=np.float32) + for i, s in enumerate(stems): + result[i, :len(s)] = s + return result + + @jaxtyped(typechecker=beartype) + def __getitem__(self, idx: int) -> tuple[Float[np.ndarray, "T"], Float[np.ndarray, "N T"]]: + """Load full song.""" + song_dir = self.songs[idx] + stems = self._load_stems(song_dir) + return stems.sum(axis=0), stems + + def __len__(self) -> int: + return len(self.songs) + +dataset = JaCappellaDataset("/space/samanklesaria/data/jacappella") +``` + +We'll feed the dataset to our model using the `grain` library. 
+ +```python +import grain + +SEG_SAMPLES = SAMPLE_RATE * 2 # 2-second segments +BATCH_SIZE = 1 + +def extract_segment(item, rng): # crop (or zero-pad) one random SEG_SAMPLES-long window + mixture, stems = item + T = mixture.shape[0] + if T <= SEG_SAMPLES: + pad = SEG_SAMPLES - T + return np.pad(mixture, (0, pad)), np.pad(stems, ((0, 0), (0, pad))) + start = rng.integers(0, T - SEG_SAMPLES + 1) # high is exclusive; +1 keeps the last window reachable + return mixture[start:start + SEG_SAMPLES], stems[:, start:start + SEG_SAMPLES] + +def batch_to_jax(items): + mixtures = jnp.array(np.stack([m for m, _ in items])) # (B, T) + stems = jnp.array(np.stack([s for _, s in items])) # (B, N, T) + return mixtures, stems + +loader = (grain.MapDataset.source(dataset) + .seed(0).shuffle() + .random_map(extract_segment) + .batch(BATCH_SIZE, drop_remainder=True, batch_fn=batch_to_jax)) +``` + +To make sure the data is loading correctly, we can sample a batch and log it to Tensorboard. + +```python +from tensorboardX import SummaryWriter + +mixture, stems = next(iter(loader)) +mixture = np.array(mixture[0]) # (T,) +stems = np.array(stems[0]) # (N, T) +writer = SummaryWriter("samples") +peak = np.max(np.abs(mixture)) +scale = 0.99 / peak if peak > 0 else 1.0 +writer.add_audio( + f"mixture", + mixture * scale, + sample_rate=dataset.sample_rate) +for n in range(stems.shape[0]): + writer.add_audio( + f"stem/{n}", + stems[n] * scale, + sample_rate=dataset.sample_rate) +writer.close() +``` + +```python +from IPython.display import display, HTML +display(HTML(open("mixture.html").read())) +display(HTML(open("stem.html").read())) +``` + +## Architecture + + +At a high level: +- The input waveform is encoded to a latent space using convolutional and transformer layers. +- The result gets split into separate pieces for each voice part +- Each piece is decoded back to a waveform by the same stack of transformer and convolutional layers. 
class Encoder(nnx.Module):
    """Strided 1-D convolution followed by GELU: waveform -> latent frames.

    Args:
        out_channels: number of latent channels ``C``.
        kernel_size: convolution kernel length ``K`` (default 16).
        stride: convolution stride ``S`` (default 8).
        rngs: RNG container used to initialise the convolution.
    """

    def __init__(self, out_channels: int, *, kernel_size: int = 16,
                 stride: int = 8, rngs: nnx.Rngs):
        self.conv = nnx.Conv(1, out_channels, kernel_size=(kernel_size,),
                             strides=(stride,), padding='VALID', rngs=rngs)

    def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B L C"]:
        # Add a singleton channel axis so the conv sees (B, T, 1).
        h = self.conv(x[..., None])          # (B, T, 1) -> (B, L, C)
        return jax.nn.gelu(h)


class Decoder(nnx.Module):
    """Transposed 1-D convolution: latent frames -> single-channel waveform.

    Args:
        in_channels: number of latent channels ``C``.
        kernel_size: transposed-convolution kernel length (default 16).
        stride: transposed-convolution stride (default 8).
        rngs: RNG container used to initialise the convolution.
    """

    def __init__(self, in_channels: int, *, kernel_size: int = 16,
                 stride: int = 8, rngs: nnx.Rngs):
        self.conv_t = nnx.ConvTranspose(in_channels, 1, kernel_size=(kernel_size,),
                                        strides=(stride,), padding='VALID', rngs=rngs)

    def __call__(self, h: Float[Array, "*B L C"]) -> Float[Array, "*B T"]:
        out = self.conv_t(h)                 # (..., L, C) -> (..., T, 1)
        return out[..., 0]                   # drop the singleton channel axis
class Snake(nnx.Module):
    """SNAKE activation ``x + sin^2(alpha * x) / alpha`` with per-feature alpha."""

    def __init__(self, features: int, *, rngs: nnx.Rngs):
        # `rngs` is not consumed (alpha starts at ones); it is accepted so the
        # constructor matches the other layers in this file.
        self.alpha = nnx.Param(jnp.ones(features))

    def __call__(self, x: Float[Array, "*B F"]) -> Float[Array, "*B F"]:
        # `[...]` reads the parameter the same way the other modules here do.
        a = self.alpha[...]
        # The 1e-6 guards the division if alpha is driven to zero.
        return x + (1.0 / (a + 1e-6)) * jnp.sin(a * x) ** 2


def feedforward(dim, ff_dim, rngs: nnx.Rngs):
    """Build the position-wise MLP ``Linear -> Snake -> Linear`` (dim -> ff_dim -> dim)."""
    return nnx.Sequential(
        nnx.Linear(dim, ff_dim, rngs=rngs),
        Snake(ff_dim, rngs=rngs),
        nnx.Linear(ff_dim, dim, rngs=rngs))


class TransformerBlock(nnx.Module):
    """Pre-norm transformer block with RoPE attention and scaled residuals.

    Both residual branches are multiplied by a learnable per-channel scale
    initialised to ``1e-4``, so the block starts close to the identity.

    Args:
        dim: model width ``D``.
        num_heads: number of attention heads; ``dim // num_heads`` is the RoPE
            embedding size.
        ff_dim: hidden width of the feed-forward sub-layer.
        max_seq_len: length of the precomputed RoPE tables.
        rngs: RNG container used to initialise sub-layers.
    """

    def __init__(
        self, dim: int, num_heads: int, ff_dim: int,
        max_seq_len: int = 2048, *, rngs: nnx.Rngs
    ):
        self.norm1 = nnx.LayerNorm(dim, rngs=rngs)
        self.norm2 = nnx.LayerNorm(dim, rngs=rngs)
        head_dim = dim // num_heads
        rope = nnx.RoPE(embedding_size=head_dim, max_seq_len=max_seq_len)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=dim,
            qkv_features=dim,
            attention_fn=functools.partial(
                nnx.dot_product_attention_with_rope, rope=rope),
            decode=False,
            rngs=rngs,
        )
        self.ff = feedforward(dim, ff_dim, rngs=rngs)
        self.scale1 = nnx.Param(jnp.full(dim, 1e-4))
        self.scale2 = nnx.Param(jnp.full(dim, 1e-4))

    def __call__(self, x: Float[Array, "*B S D"]) -> Float[Array, "*B S D"]:
        normed = self.norm1(x)
        attn_out = self.attn(normed)
        x = x + self.scale1[...] * attn_out
        x = x + self.scale2[...] * self.ff(self.norm2(x))
        return x


class DualPathBlock(nnx.Module):
    """Intra-chunk then inter-chunk attention over a chunked sequence.

    The length axis is zero-padded to a multiple of ``chunk_size``, reshaped to
    ``(..., M, K, C)``, processed by one block along ``K`` and one along ``M``,
    and finally cropped back to the input length. Any number of leading batch
    axes is accepted.

    Args:
        dim: model width ``C``.
        num_heads: attention heads for both transformer blocks.
        ff_dim: feed-forward hidden width for both transformer blocks.
        chunk_size: chunk length ``K``.
        rngs: RNG container used to initialise sub-layers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ff_dim: int,
        chunk_size: int = 64,
        *,
        rngs: nnx.Rngs,
    ):
        self.intra_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)
        self.inter_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)
        self.chunk_size = chunk_size

    def __call__(self, x: Float[Array, "*B L C"]) -> Float[Array, "*B L C"]:
        # Star-unpack so (B, L, C) and (B, N, L, C) inputs are both handled.
        *batch_shape, L, C = x.shape
        K = self.chunk_size

        # Pad to multiple of chunk_size
        pad_len = (K - L % K) % K
        if pad_len > 0:
            # NOTE(review): the padded frames are zeros and no attention mask is
            # passed, so real frames can attend to them — confirm this is intended.
            x = jnp.pad(x, [(0, 0)] * len(batch_shape) + [(0, pad_len), (0, 0)])

        L_padded = x.shape[-2]
        M = L_padded // K

        # (..., L, C) -> (..., M, K, C)
        chunks = x.reshape(*batch_shape, M, K, C)

        # Intra-chunk: TransformerBlock sees (..., M) as batch, K as seq
        chunks = self.intra_block(chunks)

        # Inter-chunk: swap M <-> K, attend, swap back
        inter = jnp.swapaxes(chunks, -3, -2)     # (..., K, M, C)
        inter = self.inter_block(inter)
        chunks = jnp.swapaxes(inter, -3, -2)     # (..., M, K, C)

        out = chunks.reshape(*batch_shape, L_padded, C)
        return out[..., :L, :]
class SplitLayer(nnx.Module):
    """GLU-gated projection that expands ``(..., L, C)`` into ``(..., N, L, C)``.

    Args:
        dim: channel width ``C``.
        num_stems: number of output streams ``N``.
        rngs: RNG container used to initialise the two linear layers.
    """

    def __init__(self, dim: int, num_stems: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(dim, dim * 2, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim * num_stems, rngs=rngs)
        self.num_stems = num_stems

    def __call__(self, x: Float[Array, "*B L C"]) -> Float[Array, "*B N L C"]:
        h = self.linear1(x)                           # (..., L, 2C)
        gate, val = jnp.split(h, 2, axis=-1)
        h = jax.nn.sigmoid(gate) * val                # (..., L, C)
        h = self.linear2(h)                           # (..., L, N*C)
        C = x.shape[-1]
        # Split the last axis only, so any number of leading axes is preserved.
        h = h.reshape(*h.shape[:-1], self.num_stems, C)  # (..., L, N, C)
        return jnp.swapaxes(h, -3, -2)                # (..., N, L, C)
class SepReformer(nnx.Module):
    """Encoder -> shared dual-path blocks -> split -> per-stem blocks -> decoder.

    All hyperparameters are keyword-only and default to the values used in this
    tutorial, so ``SepReformer(rngs=...)`` builds the same model as before.

    Args:
        num_stems: number of separated outputs ``N``.
        dim: latent channel width.
        num_heads: attention heads per transformer block.
        ff_dim: feed-forward hidden width.
        chunk_size: chunk length used by every ``DualPathBlock``.
        num_sep_blocks: number of blocks applied before the split.
        num_rec_blocks: number of blocks applied after the split.
        rngs: RNG container used to initialise every sub-module.
    """

    def __init__(
        self,
        *,
        num_stems: int = 5,
        dim: int = 256,
        num_heads: int = 8,
        ff_dim: int = 1024,
        chunk_size: int = 64,
        num_sep_blocks: int = 2,
        num_rec_blocks: int = 2,
        rngs: nnx.Rngs,
    ):
        self.encoder = Encoder(dim, rngs=rngs)
        self.decoder = Decoder(dim, rngs=rngs)
        self.split = SplitLayer(dim, num_stems, rngs=rngs)
        # Sub-module containers are wrapped in nnx.List so NNX registers their
        # parameters as module state rather than treating the list as static.
        self.sep_blocks = nnx.List([
            DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)
            for _ in range(num_sep_blocks)
        ])
        self.rec_blocks = nnx.List([
            DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)
            for _ in range(num_rec_blocks)
        ])

    def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B N T"]:
        h = self.encoder(x)                    # (B, L, C)
        for block in self.sep_blocks:
            h = block(h)                       # (B, L, C)
        stems = self.split(h)                  # (B, N, L, C)
        for block in self.rec_blocks:
            stems = block(stems)               # (B, N, L, C) — N is a batch dim
        out = self.decoder(stems)              # (B, N, T')
        # The strided conv / transposed conv round trip need not reproduce the
        # input length exactly, so trim or zero-pad back to T.
        T = x.shape[-1]
        if out.shape[-1] > T:
            out = out[..., :T]
        elif out.shape[-1] < T:
            pad_width = [(0, 0)] * (out.ndim - 1) + [(0, T - out.shape[-1])]
            out = jnp.pad(out, pad_width)
        return out                             # (B, N, T)
@jaxtyped(typechecker=beartype)
def si_sdr(estimate: Float[Array, "T"], target: Float[Array, "T"], eps: float = 1e-8) -> Float[Array, ""]:
    """Scale-invariant signal-to-distortion ratio in dB (higher is better).

    Both signals are mean-centred first, which removes any DC offset. The
    estimate is then projected onto the target, and the ratio of projected
    energy to residual energy is reported on a log scale. ``eps`` keeps the
    divisions and the logarithm finite.
    """
    est_centred = estimate - estimate.mean()
    tgt_centred = target - target.mean()

    # Optimal gain that maps the target onto the estimate.
    tgt_energy = jnp.sum(tgt_centred ** 2) + eps
    gain = jnp.sum(est_centred * tgt_centred) / tgt_energy

    projection = gain * tgt_centred
    residual = est_centred - projection

    signal_power = jnp.sum(projection ** 2)
    noise_power = jnp.sum(residual ** 2) + eps
    return 10.0 * jnp.log10(signal_power / noise_power + eps)
@jaxtyped(typechecker=beartype)
def stft_mag(x: Float[Array, "T"], fft_size: int, hop: int, win_size: int) -> Float[Array, "F K"]:
    """Return the STFT magnitude of ``x`` as a ``(freq_bins, frames)`` array.

    The signal is zero-padded by ``fft_size // 2`` on each side, cut into
    Hann-windowed frames of ``win_size`` samples every ``hop`` samples, and
    transformed with a real FFT of length ``fft_size``.
    """
    half = fft_size // 2
    padded = jnp.pad(x, (half, half))
    frame_count = (len(padded) - win_size) // hop + 1

    # Gather indices: one row per frame, one column per sample in the window.
    within_frame = jnp.arange(win_size)[None, :]
    frame_starts = jnp.arange(frame_count)[:, None] * hop
    gather_idx = within_frame + frame_starts

    windowed = padded[gather_idx] * jnp.hanning(win_size)
    spectrum = jnp.fft.rfft(windowed, n=fft_size, axis=-1)
    return jnp.abs(spectrum).T  # (F, K)

@jaxtyped(typechecker=beartype)
def stft_loss_single(est: Float[Array, "T"], tgt: Float[Array, "T"], fft_size: int, hop: int, win: int) -> Float[Array, ""]:
    """Spectral-convergence plus log-magnitude loss at one STFT resolution."""
    est_mag = stft_mag(est, fft_size, hop, win)
    tgt_mag = stft_mag(tgt, fft_size, hop, win)
    # Frobenius distance between magnitudes, normalised by the target's norm.
    convergence = jnp.linalg.norm(tgt_mag - est_mag) / (jnp.linalg.norm(tgt_mag) + 1e-8)
    # Mean absolute difference of log magnitudes (1e-8 keeps the logs finite).
    log_gap = jnp.abs(jnp.log(est_mag + 1e-8) - jnp.log(tgt_mag + 1e-8))
    return convergence + jnp.mean(log_gap)

@jaxtyped(typechecker=beartype)
def mr_stft_loss(est: Float[Array, "T"], tgt: Float[Array, "T"]) -> Float[Array, ""]:
    """Average ``stft_loss_single`` over three ``(fft_size, hop, win)`` settings."""
    resolutions = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048))
    total = 0
    for fft_size, hop, win in resolutions:
        total = total + stft_loss_single(est, tgt, fft_size, hop, win)
    return total / len(resolutions)
@jaxtyped(typechecker=beartype)
def composite_loss(estimates: Float[Array, "N T"], targets: Float[Array, "N T"]) -> Float[Array, ""]:
    """Mean over stems of ``-SI-SDR + 0.5 * multi-resolution STFT loss``."""
    def stem_loss(est, tgt):
        waveform_term = -si_sdr(est, tgt)
        spectral_term = 0.5 * mr_stft_loss(est, tgt)
        return waveform_term + spectral_term

    per_stem = jax.vmap(stem_loss)(estimates, targets)
    return jnp.mean(per_stem)

def loss_fn(model, mixture: Float[Array, "B T"], targets: Float[Array, "B N T"]) -> Float[Array, ""]:
    """Run the model on a batch of mixtures and average ``composite_loss`` over it."""
    predicted = model(mixture)  # (B, N, T)
    per_example = jax.vmap(composite_loss)(predicted, targets)
    return jnp.mean(per_example)
def log_audio_samples(model, loader, writer, global_step):
    """Write one example's mixture, true stems and model estimates to ``writer``.

    The first batch of ``loader`` is drawn and only its first example is used.
    The mixture and true stems share one gain (mixture peak -> 0.99), while
    each estimate is normalised to a 0.99 peak on its own.
    """
    batch_mix, batch_stems = next(iter(loader))
    first_mix = np.array(batch_mix[0])
    first_stems = np.array(batch_stems[0])
    # Slice with 0:1 so the model still receives a batch axis, then drop it.
    first_est = np.array(model(batch_mix[0:1])[0])

    def peak_gain(signal):
        return 0.99 / (np.max(np.abs(signal)) + 1e-8)

    def emit(tag, audio):
        writer.add_audio(tag, audio, global_step, sample_rate=SAMPLE_RATE)

    shared_gain = peak_gain(first_mix)
    emit("mixture", first_mix * shared_gain)
    for n, stem in enumerate(first_stems):
        emit(f"true/{n}", stem * shared_gain)
    for n, est in enumerate(first_est):
        emit(f"estimate/{n}", est * peak_gain(est))
class RoPE(Module):
  """Rotary Position Embedding (RoPE).

  Precomputes and stores cosine and sine frequency tensors for rotary
  positional encoding as described in "RoFormer: Enhanced Transformer
  with Rotary Position Embedding" (https://arxiv.org/abs/2104.09864).

  Example usage::

    >>> from flax import nnx
    >>> import functools
    >>> rope = nnx.RoPE(embedding_size=64, max_seq_len=128)
    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=8, in_features=512, qkv_features=512,
    ...   attention_fn=functools.partial(
    ...     nnx.dot_product_attention_with_rope, rope=rope),
    ...   decode=False, rngs=nnx.Rngs(0))

  Args:
    embedding_size: Size of each embedding vector (i.e. head_dim). Must be even.
    max_seq_len: Maximum sequence length for precomputed frequencies.
    theta: Base frequency for sinusoidal functions. Default is 10000.
    dtype: dtype for the precomputed cos/sin tables.
  """

  def __init__(
    self,
    embedding_size: int,
    max_seq_len: int = 2048,
    theta: float = 10000.0,
    dtype: Dtype = jnp.float32,
  ):
    if embedding_size <= 0 or embedding_size % 2 != 0:
      raise ValueError('`embedding_size` must be positive and even.')
    self.embedding_size = embedding_size
    # Kept so `__call__` can reject inputs longer than the cached tables.
    self.max_seq_len = max_seq_len
    self.theta = theta
    freqs = 1.0 / (
      theta ** (jnp.arange(0.0, embedding_size, 2) / embedding_size)
    )
    t = jnp.arange(float(max_seq_len))
    freqs_outer = jnp.outer(t, freqs)  # (max_seq_len, embedding_size // 2)
    self.cos_cached = nnx.Variable(jnp.cos(freqs_outer).astype(dtype))
    self.sin_cached = nnx.Variable(jnp.sin(freqs_outer).astype(dtype))

  @staticmethod
  def _rotate_half(x: Array) -> Array:
    """Map the two halves ``(a, b)`` of the last axis to ``(-b, a)``."""
    d_2 = x.shape[-1] // 2
    return jnp.concatenate([-x[..., d_2:], x[..., :d_2]], axis=-1)

  def __call__(self, x: Array) -> Array:
    """Apply rotary positional encoding.

    Positions are taken to be ``0 .. seq_length - 1`` along axis ``-2``.

    Args:
      x: Input array of shape ``(..., seq_length, embedding_size)``.

    Returns:
      Array of same shape with RoPE applied.

    Raises:
      ValueError: if ``seq_length`` exceeds ``max_seq_len`` or the last axis
        does not match ``embedding_size``.
    """
    seq_len = x.shape[-2]
    # Shapes are static, so these checks also work under `jit`/`vmap`. Without
    # them an over-long input slices a too-short table and fails (or, for a
    # one-row table, silently broadcasts) at the multiply below.
    if seq_len > self.max_seq_len:
      raise ValueError(
        f'Sequence length {seq_len} exceeds `max_seq_len`={self.max_seq_len}.'
      )
    if x.shape[-1] != self.embedding_size:
      raise ValueError(
        f'Expected last dimension {self.embedding_size}, got {x.shape[-1]}.'
      )
    # Duplicate the half-width tables to match the layout of `_rotate_half`.
    cos = jnp.tile(self.cos_cached[...][:seq_len], (1, 2))
    sin = jnp.tile(self.sin_cached[...][:seq_len], (1, 2))
    return x * cos + self._rotate_half(x) * sin


def dot_product_attention_with_rope(
  query: Array,
  key: Array,
  value: Array,
  *,
  rope: RoPE,
  **kwargs,
) -> Array:
  """Dot-product attention with Rotary Position Embedding applied to query and key.

  This function has the same signature as :func:`dot_product_attention`
  (plus a ``rope`` keyword argument) and can be used as the ``attention_fn``
  of :class:`MultiHeadAttention` via ``functools.partial``::

    attention_fn=functools.partial(dot_product_attention_with_rope, rope=rope)

  Query and key positions are both counted from zero along their own length
  axis; ``value`` is passed through unrotated.

  Args:
    query: queries of shape ``[batch..., q_length, num_heads, head_dim]``.
    key: keys of shape ``[batch..., kv_length, num_heads, head_dim]``.
    value: values of shape ``[batch..., kv_length, num_heads, v_dim]``.
    rope: A :class:`RoPE` module instance.
    **kwargs: Additional keyword arguments forwarded to
      :func:`dot_product_attention`.

  Returns:
    Output of shape ``[batch..., q_length, num_heads, v_dim]``.
  """
  # query/key: [batch..., seq_length, num_heads, head_dim]
  # RoPE.__call__ expects (..., seq_length, head_dim), so vmap over heads.
  apply = jax.vmap(rope, in_axes=-2, out_axes=-2)
  query = apply(query)
  key = apply(key)

  return dot_product_attention(query, key, value, **kwargs)
+ seq = jnp.zeros((pos + 1, head_dim)).at[pos].set(vec) + return rope(seq)[pos] + + q_rot = rope_at(q_vec, i) + k_rot = rope_at(k_vec, j) + dot_original = jnp.dot(q_rot, k_rot) + + q_rot_shifted = rope_at(q_vec, i + d) + k_rot_shifted = rope_at(k_vec, j + d) + dot_shifted = jnp.dot(q_rot_shifted, k_rot_shifted) + + np.testing.assert_allclose(dot_original, dot_shifted, atol=1e-5) + + if __name__ == '__main__': absltest.main()