diff --git a/docs_nnx/guides/audio.ipynb b/docs_nnx/guides/audio.ipynb new file mode 100644 index 000000000..dc9c2b1e4 --- /dev/null +++ b/docs_nnx/guides/audio.ipynb @@ -0,0 +1,807 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1794712e", + "metadata": {}, + "source": [ + "# Building a Choral Source Separator with SepReformer in JAX\n", + "\n", + "This tutorial demonstrates how to perform audio source separation using the `SepReformer`\n", + "architecture. The original paper can be found at https://arxiv.org/abs/2406.05983; the authors' full source code is at https://github.com/dmlguq456/SepReformer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5b2848", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import nnx\n", + "from jaxtyping import Float, Array\n", + "import soundfile as sf" + ] + }, + { + "cell_type": "markdown", + "id": "e011d75a", + "metadata": {}, + "source": [ + "## The Task\n", + "\n", + "Given a mono mixture waveform $x \\in \\mathbb{R}^T$, produce $N$ separated stem\n", + "waveforms $\\hat{s}_1, \\ldots, \\hat{s}_N \\in \\mathbb{R}^T$ such that\n", + "$\\sum_n \\hat{s}_n \\approx x$ and each $\\hat{s}_n$ matches one isolated voice\n", + "track.\n", + "\n", + "We use the **JaCappella** corpus (35 a cappella songs) via Hugging Face. Each\n", + "song has 5 isolated stems — `lead_vocal`, `soprano`, `alto`, `tenor`, `bass` —\n", + "and the mixture is their sum." 
STEM_NAMES = ("lead_vocal", "soprano", "alto", "tenor", "bass")
SAMPLE_RATE = 44100


class JaCappellaDataset:
    """Random-access dataset yielding ``(mixture, stems)`` for each song.

    Songs are discovered as ``root/<genre>/<song>/`` directories; entries whose
    name starts with ``.`` are skipped. Each song directory is expected to hold
    one ``<stem>.wav`` per entry of ``stem_names``.

    Args:
        root: Directory that contains the per-genre sub-directories.
        sample_rate: Rate that files are resampled to when theirs differs.
        stem_names: Stem file names (without ``.wav``) in output row order.
    """

    def __init__(
        self,
        root,
        sample_rate: int = SAMPLE_RATE,
        stem_names: tuple[str, ...] = STEM_NAMES,
    ) -> None:
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.stem_names = tuple(stem_names)
        self.songs = sorted(
            song
            for genre in self.root.iterdir()
            if genre.is_dir() and not genre.name.startswith(".")
            for song in genre.iterdir()
            if song.is_dir() and not song.name.startswith("."))

    def _load_wav(self, path: Path) -> Float[np.ndarray, "T"]:
        """Load a wav file, resample if needed, return mono float32."""
        audio, sr = sf.read(path, dtype="float32", always_2d=True)
        audio = audio[:, 0]  # keep only the first channel of multi-channel files
        if sr != self.sample_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
        return audio

    def _load_stems(self, song_dir: Path) -> Float[np.ndarray, "N T"]:
        """Load every stem of a song into one zero-padded ``(N, T)`` array.

        A stem whose file is missing becomes an all-zero row; shorter stems are
        right-padded with zeros up to the longest stem.
        """
        stems = []
        for name in self.stem_names:
            path = song_dir / f"{name}.wav"
            stems.append(
                self._load_wav(path) if path.exists()
                else np.zeros(0, dtype=np.float32))
        max_len = max((len(s) for s in stems), default=0)
        # Row count follows the configured stems rather than a literal 5.
        result = np.zeros((len(stems), max_len), dtype=np.float32)
        for row, s in enumerate(stems):
            result[row, :len(s)] = s
        return result

    def __getitem__(self, idx: int) -> tuple[Float[np.ndarray, "T"], Float[np.ndarray, "N T"]]:
        """Load a full song; the mixture is the sample-wise sum of its stems."""
        stems = self._load_stems(self.songs[idx])
        return stems.sum(axis=0), stems

    def __len__(self) -> int:
        return len(self.songs)
def extract_segment(item, rng, seg_samples=None):
    """Cut one fixed-length training window out of a ``(mixture, stems)`` pair.

    Args:
        item: ``(mixture, stems)`` with shapes ``(T,)`` and ``(N, T)``.
        rng: NumPy random ``Generator`` used to draw the window start.
        seg_samples: Window length in samples; defaults to ``SEG_SAMPLES``.

    Returns:
        ``(mixture, stems)`` cropped (or right-padded with zeros when the song
        is not longer than the window) to exactly ``seg_samples`` samples.
    """
    if seg_samples is None:
        seg_samples = SEG_SAMPLES
    mixture, stems = item
    total = mixture.shape[0]
    if total <= seg_samples:
        pad = seg_samples - total
        return np.pad(mixture, (0, pad)), np.pad(stems, ((0, 0), (0, pad)))
    # `Generator.integers` excludes the upper bound, so +1 is needed for the
    # final full window (start == total - seg_samples) to be reachable.
    start = rng.integers(0, total - seg_samples + 1)
    stop = start + seg_samples
    return mixture[start:stop], stems[:, start:stop]


def batch_to_jax(items):
    """Stack a list of ``(mixture, stems)`` pairs into two JAX arrays."""
    mixtures = jnp.array(np.stack([m for m, _ in items]))  # (B, T)
    stems = jnp.array(np.stack([s for _, s in items]))     # (B, N, T)
    return mixtures, stems
from tensorboardX import SummaryWriter

# Pull one batch and drop the batch axis so the audio can be logged directly.
mixture, stems = next(iter(loader))
mixture = np.array(mixture[0])  # (T,)
stems = np.array(stems[0])      # (N, T)

# One shared gain (derived from the mixture peak) keeps the relative stem
# levels intact; a silent mixture is left unscaled to avoid dividing by zero.
peak = np.max(np.abs(mixture))
scale = 0.99 / peak if peak > 0 else 1.0

writer = SummaryWriter("samples")
try:
    writer.add_audio(
        "mixture",
        mixture * scale,
        sample_rate=dataset.sample_rate)
    for n in range(stems.shape[0]):
        writer.add_audio(
            f"stem/{n}",
            stems[n] * scale,
            sample_rate=dataset.sample_rate)
finally:
    # Close the writer even if logging one of the clips fails.
    writer.close()

from IPython.display import display, HTML

# Read each page inside a context manager so the file handle is released.
for page in ("mixture.html", "stem.html"):
    with open(page) as fh:
        display(HTML(fh.read()))
class Encoder(nnx.Module):
    """Strided 1-D convolution followed by GELU: waveform -> latent frames.

    Args:
        out_channels: Number of latent channels ``C``.
        kernel_size: Convolution kernel length ``K`` in samples.
        stride: Hop ``S`` between successive frames in samples.
        rngs: Flax RNG container used to initialise the convolution.
    """

    def __init__(self, out_channels: int, *, kernel_size: int = 16,
                 stride: int = 8, rngs: nnx.Rngs):
        self.conv = nnx.Conv(1, out_channels, kernel_size=(kernel_size,),
                             strides=(stride,), padding='VALID', rngs=rngs)

    def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B L C"]:
        # A trailing singleton channel axis is added for the convolution.
        h = self.conv(x[..., None])  # (B, T, 1) -> (B, L, C)
        return jax.nn.gelu(h)


class Decoder(nnx.Module):
    """Transposed 1-D convolution: latent frames -> single-channel waveform.

    Args:
        in_channels: Number of latent channels ``C``.
        kernel_size: Kernel length, matching the encoder's.
        stride: Hop between frames, matching the encoder's.
        rngs: Flax RNG container used to initialise the convolution.
    """

    def __init__(self, in_channels: int, *, kernel_size: int = 16,
                 stride: int = 8, rngs: nnx.Rngs):
        self.conv_t = nnx.ConvTranspose(in_channels, 1, kernel_size=(kernel_size,),
                                        strides=(stride,), padding='VALID', rngs=rngs)

    def __call__(self, h: Float[Array, "*B L C"]) -> Float[Array, "*B T"]:
        # SepReformer calls this with an extra stem axis, i.e. (B, N, L, C).
        out = self.conv_t(h)  # (..., L, C) -> (..., T, 1)
        return out[..., 0]    # drop the singleton channel axis
class Snake(nnx.Module):
    """Snake activation ``x + sin^2(alpha * x) / alpha`` with learnable ``alpha``.

    ``alpha`` has shape ``(features,)`` and broadcasts against the last axis of
    the input, so any number of leading axes is accepted.
    """

    def __init__(self, features: int, *, rngs: nnx.Rngs):
        # `rngs` is accepted for a uniform constructor signature; alpha starts
        # at the constant 1, so no random key is drawn here.
        self.alpha = nnx.Param(jnp.ones(features))

    def __call__(self, x: Float[Array, "*B F"]) -> Float[Array, "*B F"]:
        # `[...]` matches how TransformerBlock reads its own parameters.
        alpha = self.alpha[...]
        # The 1e-6 offset keeps the reciprocal finite if alpha reaches zero.
        return x + (1.0 / (alpha + 1e-6)) * jnp.sin(alpha * x) ** 2


def feedforward(dim, ff_dim, rngs: nnx.Rngs):
    """Build the feed-forward stack Linear(dim, ff_dim) -> Snake -> Linear(ff_dim, dim)."""
    return nnx.Sequential(
        nnx.Linear(dim, ff_dim, rngs=rngs),
        Snake(ff_dim, rngs=rngs),
        nnx.Linear(ff_dim, dim, rngs=rngs))
import functools


class TransformerBlock(nnx.Module):
    """Pre-norm transformer layer: RoPE self-attention, then a Snake feed-forward.

    Each residual branch is multiplied by a learnable per-channel scale that
    starts at ``1e-4``.

    Args:
        dim: Model width ``D``.
        num_heads: Number of attention heads; the RoPE width is ``dim // num_heads``.
        ff_dim: Hidden width of the feed-forward branch.
        max_seq_len: Longest sequence the RoPE tables are built for.
        rngs: Flax RNG container used to initialise the sub-modules.
    """

    def __init__(
        self, dim: int, num_heads: int, ff_dim: int,
        max_seq_len: int = 2048, *, rngs: nnx.Rngs
    ):
        # Sub-modules that take `rngs` are created in a fixed order
        # (norm1, norm2, attn, ff) so parameter initialisation is reproducible.
        self.norm1 = nnx.LayerNorm(dim, rngs=rngs)
        self.norm2 = nnx.LayerNorm(dim, rngs=rngs)
        rope = nnx.RoPE(embedding_size=dim // num_heads, max_seq_len=max_seq_len)
        rope_attention = functools.partial(
            nnx.dot_product_attention_with_rope, rope=rope)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=dim,
            qkv_features=dim,
            attention_fn=rope_attention,
            decode=False,
            rngs=rngs,
        )
        self.ff = feedforward(dim, ff_dim, rngs=rngs)
        self.scale1 = nnx.Param(jnp.full(dim, 1e-4))
        self.scale2 = nnx.Param(jnp.full(dim, 1e-4))

    def __call__(self, x: Float[Array, "B S D"]) -> Float[Array, "B S D"]:
        # Attention branch.
        attended = self.attn(self.norm1(x))
        x = x + self.scale1[...] * attended
        # Feed-forward branch, computed on the updated residual stream.
        transformed = self.ff(self.norm2(x))
        return x + self.scale2[...] * transformed
class DualPathBlock(nnx.Module):
    """Dual-path attention: within each chunk, then across chunks.

    The frame axis of length ``L`` is zero-padded to a multiple of
    ``chunk_size`` and folded into ``(M, K)``. ``intra_block`` attends over
    ``K`` inside every chunk; ``inter_block`` attends over ``M`` for every
    in-chunk position. Any number of leading axes is supported, so both
    ``(B, L, C)`` and ``(B, N, L, C)`` inputs work.

    Args:
        dim: Channel width ``C``.
        num_heads: Attention heads in each transformer block.
        ff_dim: Hidden width of each feed-forward branch.
        chunk_size: Chunk length ``K`` in frames.
        rngs: Flax RNG container used to initialise both blocks.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ff_dim: int,
        chunk_size: int = 64,
        *,
        rngs: nnx.Rngs,
    ):
        self.intra_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)
        self.inter_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs)
        self.chunk_size = chunk_size

    def __call__(self, x: Float[Array, "*B L C"]) -> Float[Array, "*B L C"]:
        # Star-unpacking keeps every leading axis; the previous 3-way unpack
        # failed on (B, N, L, C) inputs and referenced an undefined name.
        *batch_shape, L, C = x.shape
        K = self.chunk_size

        # Right-pad the frame axis with zeros up to a multiple of K.
        pad_len = (K - L % K) % K
        if pad_len > 0:
            x = jnp.pad(x, [(0, 0)] * len(batch_shape) + [(0, pad_len), (0, 0)])

        L_padded = L + pad_len
        M = L_padded // K

        # (..., L_padded, C) -> (..., M, K, C)
        chunks = x.reshape(*batch_shape, M, K, C)

        # Intra-chunk: the block sees (..., M) as batch and K as sequence.
        chunks = self.intra_block(chunks)

        # Inter-chunk: swap M <-> K so the sequence axis is M, then swap back.
        inter = jnp.swapaxes(chunks, -3, -2)   # (..., K, M, C)
        inter = self.inter_block(inter)
        chunks = jnp.swapaxes(inter, -3, -2)   # (..., M, K, C)

        out = chunks.reshape(*batch_shape, L_padded, C)
        return out[..., :L, :]                 # drop the padded frames
class SplitLayer(nnx.Module):
    """GLU-gated projection that fans shared features out into per-stem streams.

    Maps ``(..., L, C)`` to ``(..., N, L, C)`` where ``N == num_stems``.

    Args:
        dim: Channel width ``C``.
        num_stems: Number of output streams ``N``.
        rngs: Flax RNG container used to initialise both linear layers.
    """

    def __init__(self, dim: int, num_stems: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(dim, dim * 2, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim * num_stems, rngs=rngs)
        self.num_stems = num_stems

    def __call__(self, x: Float[Array, "*B L C"]) -> Float[Array, "*B N L C"]:
        C = x.shape[-1]
        h = self.linear1(x)                               # (..., L, 2C)
        gate, val = jnp.split(h, 2, axis=-1)
        h = jax.nn.sigmoid(gate) * val                    # (..., L, C)
        h = self.linear2(h)                               # (..., L, N*C)
        # Only the last axis is unfolded, so every leading axis is preserved.
        h = h.reshape(*h.shape[:-1], self.num_stems, C)   # (..., L, N, C)
        return jnp.swapaxes(h, -3, -2)                    # (..., N, L, C)
class SepReformer(nnx.Module):
    """Encoder -> separation blocks -> split -> reconstruction blocks -> decoder.

    All hyper-parameters are keyword-only and default to the values used in
    this tutorial, so ``SepReformer(rngs=...)`` builds the same model as before.

    Args:
        num_stems: Number of output stems ``N``.
        dim: Latent channel width ``C``.
        num_heads: Attention heads per transformer block.
        ff_dim: Hidden width of each feed-forward branch.
        chunk_size: Dual-path chunk length in frames.
        num_sep_blocks: Dual-path blocks applied before the split.
        num_rec_blocks: Dual-path blocks applied after the split.
        rngs: Flax RNG container used to initialise every sub-module.
    """

    def __init__(
        self,
        *,
        num_stems: int = 5,
        dim: int = 256,
        num_heads: int = 8,
        ff_dim: int = 1024,
        chunk_size: int = 64,
        num_sep_blocks: int = 2,
        num_rec_blocks: int = 2,
        rngs: nnx.Rngs,
    ):
        self.encoder = Encoder(dim, rngs=rngs)
        self.decoder = Decoder(dim, rngs=rngs)
        self.split = SplitLayer(dim, num_stems, rngs=rngs)
        # NOTE(review): recent Flax releases may require module lists to be
        # wrapped in nnx.List; confirm against the Flax version targeted here.
        self.sep_blocks = [
            DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)
            for _ in range(num_sep_blocks)
        ]
        self.rec_blocks = [
            DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs)
            for _ in range(num_rec_blocks)
        ]

    def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B N T"]:
        h = self.encoder(x)              # (B, L, C)
        for block in self.sep_blocks:
            h = block(h)                 # (B, L, C)
        stems = self.split(h)            # (B, N, L, C)
        for block in self.rec_blocks:
            stems = block(stems)         # (B, N, L, C); N acts as a batch axis
        out = self.decoder(stems)        # (B, N, T')

        # Match the input length: zero-pad if short, then trim any excess.
        T = x.shape[-1]
        missing = T - out.shape[-1]
        if missing > 0:
            out = jnp.pad(out, [(0, 0)] * (out.ndim - 1) + [(0, missing)])
        return out[..., :T]              # (B, N, T)
A plain mean-squared\n", + "error (MSE) in the waveform domain penalises tiny timing offsets and global\n", + "loudness differences equally, so the model spends capacity chasing irrelevant\n", + "phase shifts rather than learning to separate voices. We instead use two\n", + "complementary objectives — one waveform-domain and one spectral — that together\n", + "give stable, perceptually meaningful gradients.\n", + "\n", + "### SI-SDR\n", + "\n", + "SI-SDR projects the estimate onto the target and reports the energy ratio in dB.\n", + "It is invariant to global loudness, which matters for a cappella where voices\n", + "differ widely in level:\n", + "\n", + "$$\\hat{s}_\\text{tgt} = \\frac{\\langle \\hat{s}, s \\rangle}{\\|s\\|^2} s, \\qquad \\text{SI-SDR} = 10\\log_{10}\\frac{\\|\\hat{s}_\\text{tgt}\\|^2}{\\|\\hat{s} - \\hat{s}_\\text{tgt}\\|^2}$$\n", + "\n", + "Both signals are first mean-centred, which removes any DC offset. The projection\n", + "step then makes the ratio independent of gain, so a\n", + "perfectly separated signal that is merely scaled up or down still scores the\n", + "maximum possible value. In practice, SI-SDR values above $+10\\ \\text{dB}$\n", + "indicate clearly separated sources; below $0\\ \\text{dB}$ the estimate is\n", + "dominated by leakage from other voices. We negate it to turn maximisation into\n", + "minimisation." 
def si_sdr(estimate: Float[Array, "T"], target: Float[Array, "T"], eps: float = 1e-8) -> Float[Array, ""]:
    """Return the scale-invariant signal-to-distortion ratio in dB.

    Both signals are mean-centred, the estimate is projected onto the target,
    and the energy of that projection is compared with the energy of the
    residual. ``eps`` guards the two divisions and the logarithm.
    """
    centred_est = estimate - jnp.mean(estimate)
    centred_tgt = target - jnp.mean(target)

    # Least-squares gain that best maps the target onto the estimate.
    gain = jnp.sum(centred_est * centred_tgt) / (jnp.sum(centred_tgt ** 2) + eps)
    projection = gain * centred_tgt
    residual = centred_est - projection

    signal_energy = jnp.sum(projection ** 2)
    noise_energy = jnp.sum(residual ** 2)
    return 10.0 * jnp.log10(signal_energy / (noise_energy + eps) + eps)
def stft_mag(x: Float[Array, "T"], fft_size: int, hop: int, win_size: int) -> Float[Array, "F K"]:
    """Return the magnitude STFT of ``x`` with frequency on the first axis.

    The signal is zero-padded by ``fft_size // 2`` on both sides, cut into
    Hann-windowed frames of ``win_size`` samples spaced ``hop`` apart, and
    transformed with a real FFT of length ``fft_size``.
    """
    half = fft_size // 2
    padded = jnp.pad(x, (half, half))
    frame_count = (len(padded) - win_size) // hop + 1

    # Row i of `gather` holds the sample indices of frame i.
    frame_starts = jnp.arange(frame_count)[:, None] * hop
    gather = jnp.arange(win_size)[None, :] + frame_starts

    windowed = padded[gather] * jnp.hanning(win_size)
    spectrum = jnp.fft.rfft(windowed, n=fft_size, axis=-1)
    return jnp.abs(spectrum).T  # (F, K)


def stft_loss_single(est: Float[Array, "T"], tgt: Float[Array, "T"], fft_size: int, hop: int, win: int) -> Float[Array, ""]:
    """Spectral-convergence plus log-magnitude loss at one STFT resolution."""
    est_mag = stft_mag(est, fft_size, hop, win)
    tgt_mag = stft_mag(tgt, fft_size, hop, win)

    # Norm of the magnitude error, relative to the target's norm.
    convergence = jnp.linalg.norm(tgt_mag - est_mag) / (jnp.linalg.norm(tgt_mag) + 1e-8)
    # Mean absolute log-magnitude difference; 1e-8 avoids log(0).
    log_gap = jnp.log(est_mag + 1e-8) - jnp.log(tgt_mag + 1e-8)
    log_distance = jnp.mean(jnp.abs(log_gap))
    return convergence + log_distance


def mr_stft_loss(est: Float[Array, "T"], tgt: Float[Array, "T"]) -> Float[Array, ""]:
    """Average ``stft_loss_single`` over three (fft_size, hop, window) settings."""
    resolutions = [(512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)]
    total = 0
    for fft_size, hop, win in resolutions:
        total = total + stft_loss_single(est, tgt, fft_size, hop, win)
    return total / len(resolutions)
def composite_loss(estimates: Float[Array, "N T"], targets: Float[Array, "N T"]) -> Float[Array, ""]:
    """Mean over stems of ``-SI-SDR + 0.5 * multi-resolution STFT loss``."""
    def stem_loss(est, tgt):
        waveform_term = -si_sdr(est, tgt)
        spectral_term = 0.5 * mr_stft_loss(est, tgt)
        return waveform_term + spectral_term

    per_stem = jax.vmap(stem_loss)(estimates, targets)  # (N,)
    return jnp.mean(per_stem)


def loss_fn(model, mixture: Float[Array, "B T"], targets: Float[Array, "B N T"]) -> Float[Array, ""]:
    """Run the model once on the batch and average ``composite_loss`` over it."""
    estimates = model(mixture)  # (B, N, T)
    per_example = jax.vmap(composite_loss)(estimates, targets)  # (B,)
    return jnp.mean(per_example)
`log_audio_samples` writes one batch of\n", + "audio to TensorBoard — the raw mixture, each ground-truth stem, and the\n", + "corresponding model estimate — all normalised to a peak of $0.99$ so playback\n", + "levels are comparable across steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "486bcea2", + "metadata": {}, + "outputs": [], + "source": [ + "def log_audio_samples(model, loader, writer, global_step):\n", + " mixture, stems = next(iter(loader))\n", + " mix_np = np.array(mixture[0])\n", + " stems_np = np.array(stems[0])\n", + " est_np = np.array(model(mixture[0:1])[0]) # keep batch dim, then index out\n", + " scale = 0.99 / (np.max(np.abs(mix_np)) + 1e-8)\n", + " writer.add_audio(\"mixture\", mix_np * scale, global_step, sample_rate=SAMPLE_RATE)\n", + " for n in range(stems_np.shape[0]):\n", + " writer.add_audio(f\"true/{n}\", stems_np[n] * scale, global_step, sample_rate=SAMPLE_RATE)\n", + " for n in range(est_np.shape[0]):\n", + " est_scale = 0.99 / (np.max(np.abs(est_np[n])) + 1e-8)\n", + " writer.add_audio(f\"estimate/{n}\", est_np[n] * est_scale, global_step, sample_rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "markdown", + "id": "3f8ecc15", + "metadata": {}, + "source": [ + "### Optimizer and Training Loop\n", + "\n", + "We use AdamW with a global gradient-norm clip of $1.0$. Clipping is important\n", + "here because early in training the split layer and decoder produce near-random\n", + "outputs, which can generate very large gradients through the SI-SDR loss.\n", + "Weight decay of $10^{-2}$ provides mild regularisation to prevent any single\n", + "stem stream from collapsing to zero.\n", + "\n", + "The loop runs for 200 epochs over the same single batch, logging the scalar loss\n", + "every epoch and uploading a fresh set of audio samples every 20 epochs. You can\n", + "monitor progress in TensorBoard with `tensorboard --logdir runs/overfit`." 
@nnx.jit
def step(model, optimizer, mixture, targets):
    """Run one optimisation step and return the loss computed before the update."""
    loss, grads = nnx.value_and_grad(loss_fn)(model, mixture, targets)
    optimizer.update(model, grads)
    return loss


import optax

# Clip the global gradient norm first, then apply AdamW.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(3e-4, weight_decay=1e-2))
# `wrt` is passed explicitly so the optimiser tracks exactly the nnx.Param
# variables. NOTE(review): recent Flax releases appear to require this
# keyword alongside `update(model, grads)`; confirm for the targeted version.
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

writer = SummaryWriter("runs/overfit")
try:
    for epoch in range(200):
        # `loader[:1]` yields only the first batch, so every epoch is a
        # single step on the same data (deliberate overfitting).
        for mixture, targets in loader[:1]:
            loss = step(model, optimizer, mixture, targets)
        writer.add_scalar("loss", float(loss), epoch)
        if epoch % 20 == 0:
            log_audio_samples(model, loader, writer, epoch)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()
We can also adjust how loud each voice part is, or add in convolutional reverb to mimic different acoustics.\n", + "- Our model dimensions are simplified compared to the real SepReformer. The paper uses a stride of 4 (vs. our 8), giving twice as many latent frames and finer time resolution. It projects the 256-channel encoder output down to 128 dimensions before the transformer blocks, and runs 4 separator stages (vs. our 2). It also interleaves local convolutional blocks (kernel size 65) with global multi-head attention rather than using pure dual-path transformers, and adds dropout (0.05) throughout. Scaling up to these settings would improve separation quality at the cost of more compute." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,md", + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/audio.md b/docs_nnx/guides/audio.md new file mode 100644 index 000000000..f2e7b16b1 --- /dev/null +++ b/docs_nnx/guides/audio.md @@ -0,0 +1,579 @@ +--- +jupyter: + jupytext: + cell_metadata_filter: -all + formats: ipynb,md + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 +--- + +# Building a Choral Source Separator with SepReformer in JAX + +This tutorial demonstrates how to perform audio source separation using the `SepReformer` +architecture. 
STEM_NAMES = ("lead_vocal", "soprano", "alto", "tenor", "bass")
SAMPLE_RATE = 44100


class JaCappellaDataset:
    """Random-access dataset yielding ``(mixture, stems)`` for each song.

    Songs are discovered as ``root/<genre>/<song>/`` directories; entries whose
    name starts with ``.`` are skipped. Each song directory is expected to hold
    one ``<stem>.wav`` per entry of ``stem_names``.

    Args:
        root: Directory that contains the per-genre sub-directories.
        sample_rate: Rate that files are resampled to when theirs differs.
        stem_names: Stem file names (without ``.wav``) in output row order.
    """

    def __init__(
        self,
        root,
        sample_rate: int = SAMPLE_RATE,
        stem_names: tuple[str, ...] = STEM_NAMES,
    ) -> None:
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.stem_names = tuple(stem_names)
        self.songs = sorted(
            song
            for genre in self.root.iterdir()
            if genre.is_dir() and not genre.name.startswith(".")
            for song in genre.iterdir()
            if song.is_dir() and not song.name.startswith("."))

    def _load_wav(self, path: Path) -> Float[np.ndarray, "T"]:
        """Load a wav file, resample if needed, return mono float32."""
        audio, sr = sf.read(path, dtype="float32", always_2d=True)
        audio = audio[:, 0]  # keep only the first channel of multi-channel files
        if sr != self.sample_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
        return audio

    def _load_stems(self, song_dir: Path) -> Float[np.ndarray, "N T"]:
        """Load every stem of a song into one zero-padded ``(N, T)`` array.

        A stem whose file is missing becomes an all-zero row; shorter stems are
        right-padded with zeros up to the longest stem.
        """
        stems = []
        for name in self.stem_names:
            path = song_dir / f"{name}.wav"
            stems.append(
                self._load_wav(path) if path.exists()
                else np.zeros(0, dtype=np.float32))
        max_len = max((len(s) for s in stems), default=0)
        # Row count follows the configured stems rather than a literal 5.
        result = np.zeros((len(stems), max_len), dtype=np.float32)
        for row, s in enumerate(stems):
            result[row, :len(s)] = s
        return result

    def __getitem__(self, idx: int) -> tuple[Float[np.ndarray, "T"], Float[np.ndarray, "N T"]]:
        """Load a full song; the mixture is the sample-wise sum of its stems."""
        stems = self._load_stems(self.songs[idx])
        return stems.sum(axis=0), stems

    def __len__(self) -> int:
        return len(self.songs)


def extract_segment(item, rng, seg_samples=None):
    """Cut one fixed-length training window out of a ``(mixture, stems)`` pair.

    Args:
        item: ``(mixture, stems)`` with shapes ``(T,)`` and ``(N, T)``.
        rng: NumPy random ``Generator`` used to draw the window start.
        seg_samples: Window length in samples; defaults to ``SEG_SAMPLES``.

    Returns:
        ``(mixture, stems)`` cropped (or right-padded with zeros when the song
        is not longer than the window) to exactly ``seg_samples`` samples.
    """
    if seg_samples is None:
        seg_samples = SEG_SAMPLES
    mixture, stems = item
    total = mixture.shape[0]
    if total <= seg_samples:
        pad = seg_samples - total
        return np.pad(mixture, (0, pad)), np.pad(stems, ((0, 0), (0, pad)))
    # `Generator.integers` excludes the upper bound, so +1 is needed for the
    # final full window (start == total - seg_samples) to be reachable.
    start = rng.integers(0, total - seg_samples + 1)
    stop = start + seg_samples
    return mixture[start:stop], stems[:, start:stop]


def batch_to_jax(items):
    """Stack a list of ``(mixture, stems)`` pairs into two JAX arrays."""
    mixtures = jnp.array(np.stack([m for m, _ in items]))  # (B, T)
    stems = jnp.array(np.stack([s for _, s in items]))     # (B, N, T)
    return mixtures, stems
+ +```python +from tensorboardX import SummaryWriter + +mixture, stems = next(iter(loader)) +mixture = np.array(mixture[0]) # (T,) +stems = np.array(stems[0]) # (N, T) +writer = SummaryWriter("samples") +peak = np.max(np.abs(mixture)) +scale = 0.99 / peak if peak > 0 else 1.0 +writer.add_audio( + f"mixture", + mixture * scale, + sample_rate=dataset.sample_rate) +for n in range(stems.shape[0]): + writer.add_audio( + f"stem/{n}", + stems[n] * scale, + sample_rate=dataset.sample_rate) +writer.close() +``` + +```python +from IPython.display import display, HTML +display(HTML(open("mixture.html").read())) +display(HTML(open("stem.html").read())) +``` + +## Architecture + + +At a high level: +- The input waveform is encoded to a latent space using convolutional and transformer layers. +- The result gets split into separate pieces for each voice part +- Each piece is decoded back to a waveform by the same stack of transformer and convolutional layers. + +$$x \xrightarrow{\text{Conv}} h \xrightarrow{\text{Enc blocks}} h \xrightarrow{\text{Split}} \{h_n\} \xrightarrow{\text{Dec blocks}} \xrightarrow{\text{ConvT}} \{\hat{s}_n\}$$ + +## Convolutional Layers: Waveform → Latent Frames and Back + +A strided 1-D convolution converts the raw waveform into a sequence of latent +frames. 
With kernel $K$ and stride $S$: + +$$L = \left\lfloor \frac{T - K}{S} \right\rfloor + 1$$ + +The encoder applies a GELU after the convolution: + +$$h = \text{GELU}(W_\text{enc} * x), \quad h \in \mathbb{R}^{L \times C}$$ + +```python +class Encoder(nnx.Module): + def __init__(self, out_channels: int, *, rngs: nnx.Rngs): + self.conv = nnx.Conv(1, out_channels, kernel_size=(16,), strides=(8,), + padding='VALID', rngs=rngs) + + def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B L C"]: + h = self.conv(x[..., None]) # (B, T, 1) -> (B, L, C) + return jax.nn.gelu(h) +``` + +```python +class Decoder(nnx.Module): + def __init__(self, in_channels: int, *, rngs: nnx.Rngs): + self.conv_t = nnx.ConvTranspose(in_channels, 1, kernel_size=(16,), strides=(8,), + padding='VALID', rngs=rngs) + + def __call__(self, h: Float[Array, "B L C"]) -> Float[Array, "B T"]: + out = self.conv_t(h) # (B, L, C) -> (B, T, 1) + return out[..., 0] # (B, T) +``` + +Default: $K=16$, $S=8$, $C=256$. At 44.1 kHz a 2-second clip becomes +$L \approx 11{,}025$ frames. + + +## SNAKE Activation + +**SNAKE** (Liu et al., 2022) adds a learnable sinusoidal term that preserves the +periodic structure present in harmonic audio: + +$$\text{SNAKE}(x; \alpha) = x + \frac{1}{\alpha}\sin^2(\alpha x)$$ + +$\alpha$ is initialized to $\mathbf{1}$ (small perturbation at startup) and +learned per channel. SNAKE is used inside every feed-forward sub-layer in the +transformer blocks. Because $\alpha$ broadcasts along all leading dimensions, +`Snake` works with any number of batch axes. 
+ +```python +class Snake(nnx.Module): + def __init__(self, features: int, *, rngs: nnx.Rngs): + self.alpha = nnx.Param(jnp.ones(features)) + + def __call__(self, x: Float[Array, "B F"]) -> Float[Array, "B F"]: + a = self.alpha.value + return x + (1.0 / (a + 1e-6)) * jnp.sin(a * x) ** 2 +``` + +```python +def feedforward(dim, ff_dim, rngs: nnx.Rngs): + return nnx.Sequential( + nnx.Linear(dim, ff_dim, rngs=rngs), + Snake(ff_dim, rngs=rngs), + nnx.Linear(ff_dim, dim, rngs=rngs)) +``` + +## Rotary Positional Embeddings + +Rotary positional embeddings (RoPE) encode position by rotating pairs of +features through position-dependent angles. This gives the transformer +translation-equivariant relative position information without adding +explicit position tokens. + +Flax provides `nnx.RoPE`, which precomputes cosine and sine frequency +tables once and stores them as module state. To use it with +`nnx.MultiHeadAttention`, pass `nnx.dot_product_attention_with_rope` +(with the `rope` argument partially applied) as the `attention_fn`: + +```python +import functools +``` + +```python +class TransformerBlock(nnx.Module): + def __init__( + self, dim: int, num_heads: int, ff_dim: int, + max_seq_len: int = 2048, *, rngs: nnx.Rngs + ): + self.norm1 = nnx.LayerNorm(dim, rngs=rngs) + self.norm2 = nnx.LayerNorm(dim, rngs=rngs) + head_dim = dim // num_heads + rope = nnx.RoPE(embedding_size=head_dim, max_seq_len=max_seq_len) + self.attn = nnx.MultiHeadAttention( + num_heads=num_heads, + in_features=dim, + qkv_features=dim, + attention_fn=functools.partial( + nnx.dot_product_attention_with_rope, rope=rope), + decode=False, + rngs=rngs, + ) + self.ff = feedforward(dim, ff_dim, rngs=rngs) + self.scale1 = nnx.Param(jnp.full(dim, 1e-4)) + self.scale2 = nnx.Param(jnp.full(dim, 1e-4)) + + def __call__(self, x: Float[Array, "B S D"]) -> Float[Array, "B S D"]: + normed = self.norm1(x) + attn_out = self.attn(normed) + x = x + self.scale1[...] * attn_out + x = x + self.scale2[...] 
* self.ff(self.norm2(x)) + return x +``` + +## Stacking Transformer Layers: The Dual-Path Approach + +Full self-attention over $L \approx 11{,}000$ frames costs $O(L^2)$. The +dual-path trick (Luo & Mesgarani, 2020) splits this into two $O(L \cdot K)$ +passes: + +1. **Intra-chunk** — reshape to $(\ldots, M, K, C)$; each of the $M$ chunks + attends within itself. Captures local patterns. Cost: $O(M \cdot K^2)$. +2. **Inter-chunk** — swap to $(\ldots, K, M, C)$; each time-slot attends + across all $M$ chunks. Propagates global pitch/rhythm. Cost: $O(K \cdot M^2)$. + +Because `TransformerBlock` accepts `(B, S, D)`, the extra chunk axis +becomes just another batch dimension — no explicit `vmap` is needed. + +```python +class DualPathBlock(nnx.Module): + def __init__( + self, + dim: int, + num_heads: int, + ff_dim: int, + chunk_size: int = 64, + *, + rngs: nnx.Rngs, + ): + self.intra_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs) + self.inter_block = TransformerBlock(dim, num_heads, ff_dim, rngs=rngs) + self.chunk_size = chunk_size + + def __call__(self, x: Float[Array, "B L C"]) -> Float[Array, "B L C"]: + B_shape, L, C = x.shape + K = self.chunk_size + + # Pad to multiple of chunk_size + pad_len = (K - L % K) % K + if pad_len > 0: + x = jnp.pad(x, [(0, 0)] * len(batch_shape) + [(0, pad_len), (0, 0)]) + + L_padded = x.shape[-2] + M = L_padded // K + + # (B, L, C) -> (B, M, K, C) + chunks = x.reshape(B_shape, M, K, C) + + # Intra-chunk: TransformerBlock sees (B, M) as batch, K as seq + chunks = self.intra_block(chunks) + + # Inter-chunk: swap M <-> K, attend, swap back + inter = jnp.swapaxes(chunks, -3, -2) # (B, K, M, C) + inter = self.inter_block(inter) + chunks = jnp.swapaxes(inter, -3, -2) # (B, M, K, C) + + out = chunks.reshape(B_shape, L_padded, C) + return out[..., :L, :] +``` + +## Splitting into Speaker Streams + +After the shared encoder blocks, a `SplitLayer` expands $(\ldots, L, C)$ into +$(\ldots, N, L, C)$. 
Splitting here lets each of the $N$ reconstruction stacks +specialize on one speaker while sharing parameters — the subsequent +`DualPathBlock` and `Decoder` layers simply treat the new $N$ axis as an +additional batch dimension. + +A GLU gate first refines the shared features before expanding: + +$$g, v = \text{split}(W_1 h), \quad h' = \sigma(g) \odot v, \quad \text{streams} = W_2 h' \;\text{reshaped to}\; (\ldots, N, L, C)$$ + +```python +class SplitLayer(nnx.Module): + def __init__(self, dim: int, num_stems: int, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(dim, dim * 2, rngs=rngs) + self.linear2 = nnx.Linear(dim, dim * num_stems, rngs=rngs) + self.num_stems = num_stems + + def __call__(self, x: Float[Array, "B L C"]) -> Float[Array, "B N L C"]: + h = self.linear1(x) # (B, L, 2C) + gate, val = jnp.split(h, 2, axis=-1) + h = jax.nn.sigmoid(gate) * val # (B, L, C) + h = self.linear2(h) # (B, L, N*C) + B_shape, L_dim, _ = h.shape + C = x.shape[-1] + h = h.reshape(B_shape, L_dim, self.num_stems, C) # (B, L, N, C) + return jnp.swapaxes(h, -3, -2) # (B, N, L, C) +``` + +## Full Forward Pass + +After `SplitLayer` produces $(\ldots, N, L, C)$, the reconstruction blocks +and decoder see $(B, N)$ as batch dimensions. The entire forward pass +runs without any explicit `vmap`. 
+ +```python +class SepReformer(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + num_sep_blocks = 2 + num_rec_blocks = 2 + dim = 256 + num_heads = 8 + ff_dim = 1024 + chunk_size = 64 + + self.encoder = Encoder(dim, rngs=rngs) + self.decoder = Decoder(dim, rngs=rngs) + self.split = SplitLayer(dim, 5, rngs=rngs) + self.sep_blocks = [ + DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs) + for _ in range(num_sep_blocks) + ] + self.rec_blocks = [ + DualPathBlock(dim, num_heads, ff_dim, chunk_size, rngs=rngs) + for _ in range(num_rec_blocks) + ] + + def __call__(self, x: Float[Array, "B T"]) -> Float[Array, "B N T"]: + h = self.encoder(x) # (B, L, C) + for block in self.sep_blocks: + h = block(h) # (B, L, C) + stems = self.split(h) # (B, N, L, C) + for block in self.rec_blocks: + stems = block(stems) # (B, N, L, C) — N is a batch dim + out = self.decoder(stems) # (B, N, T') + # trim / pad to original length + T = x.shape[-1] + if out.shape[-1] > T: + out = out[..., :T] + elif out.shape[-1] < T: + pad_width = [(0, 0)] * (out.ndim - 1) + [(0, T - out.shape[-1])] + out = jnp.pad(out, pad_width) + return out # (B, N, T) + +model = SepReformer(rngs=nnx.Rngs(0)) +``` + + +## Loss Functions + +Supervising a source separator is not straightforward. A plain mean-squared +error (MSE) in the waveform domain penalises tiny timing offsets and global +loudness differences equally, so the model spends capacity chasing irrelevant +phase shifts rather than learning to separate voices. We instead use two +complementary objectives — one waveform-domain and one spectral — that together +give stable, perceptually meaningful gradients. + +### SI-SDR + +SI-SDR projects the estimate onto the target and reports the energy ratio in dB. 
+It is invariant to global loudness, which matters for a cappella where voices +differ widely in level: + +$$\hat{s}_\text{tgt} = \frac{\langle \hat{s}, s \rangle}{\|s\|^2} s, \qquad \text{SI-SDR} = 10\log_{10}\frac{\|\hat{s}_\text{tgt}\|^2}{\|\hat{s} - \hat{s}_\text{tgt}\|^2}$$ + +Both signals are mean-centred first, which removes any DC offset; the +projection then makes the ratio independent of scale, so a +perfectly separated signal that is merely scaled up or down still scores the +maximum possible value. In practice, SI-SDR values above $+10\ \text{dB}$ +indicate clearly separated sources; below $0\ \text{dB}$ the estimate is +dominated by leakage from other voices. We negate it to turn maximisation into +minimisation. + +```python +def si_sdr(estimate: Float[Array, "T"], target: Float[Array, "T"], eps: float = 1e-8) -> Float[Array, ""]: + estimate = estimate - jnp.mean(estimate) + target = target - jnp.mean(target) + dot = jnp.sum(estimate * target) + s_target = (dot / (jnp.sum(target ** 2) + eps)) * target + e_noise = estimate - s_target + return 10.0 * jnp.log10(jnp.sum(s_target ** 2) / (jnp.sum(e_noise ** 2) + eps) + eps) +``` + +### Multi-Resolution STFT Loss + +SI-SDR is blind to spectral texture: two signals can have the same SI-SDR yet +sound very different if one has unnatural resonances or missing harmonics. +Adding a frequency-domain term at three FFT scales $\{512, 1024, 2048\}$ +addresses this at multiple time-frequency resolutions simultaneously. + +A small FFT ($512$) gives sharp time resolution — useful for detecting onset +smearing — while a large FFT ($2048$) gives fine frequency resolution — useful +for resolving individual harmonics in a choir. Using all three averages out +the inherent time-frequency tradeoff of any single STFT. + +Each scale contributes two terms: + +- **Spectral convergence** — the Frobenius-norm distance between magnitude + spectrograms, normalised by the target energy. This drives the gross shape + of the spectrum towards the reference. 
+- **Log-magnitude distance** — the mean absolute difference on a log scale. + Because human pitch perception is logarithmic, this term penalises errors in + quiet harmonics just as strongly as errors in loud ones. + +$$\mathcal{L}_\text{STFT} = \frac{1}{3}\sum_\text{scale}\left(\underbrace{\frac{\||S| - |\hat{S}|\|_F}{\||S|\|_F}}_{\text{spectral convergence}} + \underbrace{\text{mean}|\log|S| - \log|\hat{S}||}_{\text{log-magnitude}}\right)$$ + +```python +def stft_mag(x: Float[Array, "T"], fft_size: int, hop: int, win_size: int) -> Float[Array, "F K"]: + window = jnp.hanning(win_size) + x_pad = jnp.pad(x, (fft_size // 2, fft_size // 2)) + n_frames = (len(x_pad) - win_size) // hop + 1 + idx = jnp.arange(win_size)[None, :] + jnp.arange(n_frames)[:, None] * hop + frames = x_pad[idx] * window + return jnp.abs(jnp.fft.rfft(frames, n=fft_size, axis=-1)).T # (F, K) + +def stft_loss_single(est: Float[Array, "T"], tgt: Float[Array, "T"], fft_size: int, hop: int, win: int) -> Float[Array, ""]: + em, tm = stft_mag(est, fft_size, hop, win), stft_mag(tgt, fft_size, hop, win) + sc = jnp.linalg.norm(tm - em) / (jnp.linalg.norm(tm) + 1e-8) + lm = jnp.mean(jnp.abs(jnp.log(em + 1e-8) - jnp.log(tm + 1e-8))) + return sc + lm + +def mr_stft_loss(est: Float[Array, "T"], tgt: Float[Array, "T"]) -> Float[Array, ""]: + scales = [(512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)] + return sum(stft_loss_single(est, tgt, *s) for s in scales) / len(scales) +``` + +### Composite Loss + +The final objective combines the two terms, with the STFT loss weighted at +$0.5$ so that SI-SDR — which operates directly in the waveform domain and +carries the strongest perceptual signal — dominates early in training. The +STFT term then fills in spectral detail that SI-SDR cannot see. Both terms are +averaged across the $N$ stems before being averaged across the batch. 
+ +Because the model already handles the batch dimension, `loss_fn` calls the +model once on the full `(B, T)` mixture and then vmaps the per-stem loss +over the batch. + +```python +def composite_loss(estimates: Float[Array, "N T"], targets: Float[Array, "N T"]) -> Float[Array, ""]: + def pair(est, tgt): + return -si_sdr(est, tgt) + 0.5 * mr_stft_loss(est, tgt) + return jnp.mean(jax.vmap(pair)(estimates, targets)) + +def loss_fn(model, mixture: Float[Array, "B T"], targets: Float[Array, "B N T"]) -> Float[Array, ""]: + estimates = model(mixture) # (B, N, T) + return jnp.mean(jax.vmap(composite_loss)(estimates, targets)) +``` + +## Overfitting on JaCappella + +Before training on the full corpus, we overfit on a single batch. This is a +fast sanity check: if the model cannot memorise even one example, something is +wrong with the architecture, the loss, or the data pipeline. It is much +cheaper to discover this now than after a multi-hour training run. + +### Audio Logging + +The loss curve tells you the model is learning, but it does not tell you *what* +it is learning. Listening to the actual estimates at checkpoints is +irreplaceable: you can hear immediately whether the model is separating voices, +producing silence, or emitting noise. `log_audio_samples` writes one batch of +audio to TensorBoard — the raw mixture, each ground-truth stem, and the +corresponding model estimate — all normalised to a peak of $0.99$ so playback +levels are comparable across steps. 
+ +```python +def log_audio_samples(model, loader, writer, global_step): + mixture, stems = next(iter(loader)) + mix_np = np.array(mixture[0]) + stems_np = np.array(stems[0]) + est_np = np.array(model(mixture[0:1])[0]) # keep batch dim, then index out + scale = 0.99 / (np.max(np.abs(mix_np)) + 1e-8) + writer.add_audio("mixture", mix_np * scale, global_step, sample_rate=SAMPLE_RATE) + for n in range(stems_np.shape[0]): + writer.add_audio(f"true/{n}", stems_np[n] * scale, global_step, sample_rate=SAMPLE_RATE) + for n in range(est_np.shape[0]): + est_scale = 0.99 / (np.max(np.abs(est_np[n])) + 1e-8) + writer.add_audio(f"estimate/{n}", est_np[n] * est_scale, global_step, sample_rate=SAMPLE_RATE) +``` + +### Optimizer and Training Loop + +We use AdamW with a global gradient-norm clip of $1.0$. Clipping is important +here because early in training the split layer and decoder produce near-random +outputs, which can generate very large gradients through the SI-SDR loss. +Weight decay of $10^{-2}$ provides mild regularisation to prevent any single +stem stream from collapsing to zero. + +The loop runs for 200 epochs over a single batch, logging the scalar loss every +epoch and uploading a fresh set of audio samples every 20 epochs. You can +monitor progress in TensorBoard with `tensorboard --logdir runs/overfit`. 
+ +```python +@nnx.jit +def step(model, optimizer, mixture, targets): + loss, grads = nnx.value_and_grad(loss_fn)(model, mixture, targets) + optimizer.update(model, grads) + return loss +``` + +```python +import optax + +optimizer = nnx.Optimizer(model, optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(3e-4, weight_decay=1e-2))) + +writer = SummaryWriter("runs/overfit") +for epoch in range(200): + for mixture, targets in loader[:1]: + loss = step(model, optimizer, mixture, targets) + writer.add_scalar("loss", float(loss), epoch) + if epoch % 20 == 0: + log_audio_samples(model, loader, writer, epoch) +writer.close() +``` + +## Summary + +This tutorial walked through building a choral source separator from scratch in Flax NNX. We covered the key ingredients: a convolutional encoder/decoder pair for moving between waveforms and latent frames, SNAKE activations for preserving harmonic structure, dual-path transformer blocks for efficient long-sequence attention, and a GLU-gated split layer for dividing shared representations into per-voice streams. On the training side, we combined SI-SDR and multi-resolution STFT losses to give the model both waveform-level and spectral supervision. + +To go beyond overfitting, we'd need to make a few adjustments: +- 2 second segments don't give quite enough context: the original paper uses 4 second audio segments. +- Data augmentation is essential. In each batch, we can pick a subset of the voice parts to include, and learn to separate just their sum rather than the full mixture. We can also adjust how load each voice part is, or add in convolutional reverb to mimick different acoustics. +- Our model dimensions are simplified compared to the real SepReformer. The paper uses a stride of 4 (vs. our 8), giving twice as many latent frames and finer time resolution. It projects the 256-channel encoder output down to 128 dimensions before the transformer blocks, and runs 4 separator stages (vs. our 2). 
It also interleaves local convolutional blocks (kernel size 65) with global multi-head attention rather than using pure dual-path transformers, and adds dropout (0.05) throughout. Scaling up to these settings would improve separation quality at the cost of more compute. diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 78a550411..fd02e8e39 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -115,6 +115,8 @@ from .nn.attention import MultiHeadAttention as MultiHeadAttention from .nn.attention import combine_masks as combine_masks from .nn.attention import dot_product_attention as dot_product_attention +from .nn.attention import dot_product_attention_with_rope as dot_product_attention_with_rope +from .nn.attention import RoPE as RoPE from .nn.attention import make_attention_mask as make_attention_mask from .nn.attention import make_causal_mask as make_causal_mask from .nn.recurrent import RNNCellBase as RNNCellBase diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 4f0c4f0cd..01dbe56a5 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -319,6 +319,128 @@ def reshape_4d(x): return out +class RoPE(Module): + """Rotary Position Embedding (RoPE). + + Precomputes and stores cosine and sine frequency tensors for rotary + positional encoding as described in "RoFormer: Enhanced Transformer + with Rotary Position Embedding" (https://arxiv.org/abs/2104.09864). + + Example usage:: + + >>> from flax import nnx + >>> import functools + >>> rope = nnx.RoPE(embedding_size=64, max_seq_len=128) + >>> layer = nnx.MultiHeadAttention( + ... num_heads=8, in_features=512, qkv_features=512, + ... attention_fn=functools.partial( + ... nnx.dot_product_attention_with_rope, rope=rope), + ... decode=False, rngs=nnx.Rngs(0)) + + Args: + theta: Base frequency for sinusoidal functions. Default is 10000. + embedding_size: Optional. Size of each embedding vector (i.e. head_dim). Must be + even. 
If provided together with ``max_seq_len``, sin/cos tables are + precomputed and cached as module Variables. + max_seq_len: Optional. Maximum sequence length for precomputed frequencies. Must + be provided together with ``embedding_size``. + """ + + def __init__( + self, + theta: float = 10000.0, + embedding_size: int | None = None, + max_seq_len: int | None = None, + ): + self.theta = theta + if sum([embedding_size is None, max_seq_len is None]) % 2 != 0: + raise ValueError('Either both `embedding_size` and `max_seq_len` or none of them must be provided.') + if embedding_size is not None and max_seq_len is not None: + positions = jnp.arange(float(max_seq_len)) + sin, cos = self._rotation_for(embedding_size, positions) + self.cached_sin = nnx.Variable(sin) + self.cached_cos = nnx.Variable(cos) + + @staticmethod + def _rotate_half(x: Array) -> Array: + d_2 = x.shape[-1] // 2 + return jnp.concatenate([-x[..., d_2:], x[..., :d_2]], axis=-1) + + def _rotation_for(self, embedding_size: int, input_positions: Array): + seq_len = input_positions.shape[-1] + if embedding_size <= 0 or embedding_size % 2 != 0: + raise ValueError('`embedding_size` must be positive and even.') + freqs = 1.0 / ( + self.theta ** (jnp.arange(0.0, embedding_size, 2) / embedding_size) + ) + freqs_outer = jnp.outer(input_positions, freqs) + return jnp.sin(freqs_outer), jnp.cos(freqs_outer) + + def __call__(self, x: Array, input_positions: Array | None = None) -> Array: + """Apply rotary positional encoding. + + Args: + x: Input array of shape ``(..., seq_length, embedding_size)``. + input_positions: Optional position array of shape ``(seq_len,)``. + If ``None`` and precomputed tables exist (i.e. ``embedding_size`` + and ``max_seq_len`` were given at init), the cached tables are + used. Otherwise positions default to ``jnp.arange(seq_len)``. + + Returns: + Array of same shape with RoPE applied. 
+ """ + seq_len, embedding_size = x.shape[-2:] + if input_positions is None and hasattr(self, 'cached_sin'): + sin = self.cached_sin.value[:seq_len] + cos = self.cached_cos.value[:seq_len] + else: + if input_positions is None: + input_positions = jnp.arange(float(seq_len)) + elif input_positions.shape[-1] != seq_len: + raise ValueError('`input_positions` must have the same seq_len as `x`.') + sin, cos = self._rotation_for(embedding_size, input_positions) + return x * jnp.tile(cos, (1, 2)) + self._rotate_half(x) * jnp.tile(sin, (1, 2)) + + +def dot_product_attention_with_rope( + query: Array, + key: Array, + value: Array, + *, + rope: RoPE, + input_positions: Array | None = None, + **kwargs, +) -> Array: + """Dot-product attention with Rotary Position Embedding applied to query and key. + + This function has the same signature as :func:`dot_product_attention` + (plus a ``rope`` keyword argument) and can be used as the ``attention_fn`` + of :class:`MultiHeadAttention` via ``functools.partial``:: + + attention_fn=functools.partial(dot_product_attention_with_rope, rope=rope) + + Args: + query: queries of shape ``[batch..., q_length, num_heads, head_dim]``. + key: keys of shape ``[batch..., kv_length, num_heads, head_dim]``. + value: values of shape ``[batch..., kv_length, num_heads, v_dim]``. + rope: A :class:`RoPE` module instance. + **kwargs: Additional keyword arguments forwarded to + :func:`dot_product_attention`. + + Returns: + Output of shape ``[batch..., q_length, num_heads, v_dim]``. + """ + # query/key: [batch..., seq_length, num_heads, head_dim] + # RoPE.__call__ expects (..., seq_length, head_dim), so vmap over heads. 
+ if input_positions is None: + apply = jax.vmap(rope, in_axes=-2, out_axes=-2) + else: + apply = jax.vmap(rope, in_axes=(-2, -1), out_axes=(-2, -1)) + query = apply(query, input_positions) + key = apply(key, input_positions) + return dot_product_attention(query, key, value, **kwargs) + + class MultiHeadAttention(Module): """Multi-head attention. @@ -588,6 +710,7 @@ def __call__( is_causal=False, out_sharding = None, qkv_sharding = None, + input_positions: Array | None = None ): """Applies multi-head dot product attention on the input data. @@ -624,6 +747,9 @@ def __call__( the output linear layer for the output arrays. qkv_sharding: Optional sharding specification to pass to the QKV linear layers for the output arrays. + input_positions: Optional position indices of shape + ``[batch_sizes..., length]`` forwarded to the ``attention_fn``. + Used by :func:`dot_product_attention_with_rope` for RoPE. Returns: output of shape `[batch_sizes..., length, features]`. @@ -736,6 +862,9 @@ def __call__( dropout_rng = None # apply attention + attn_kwargs = {} + if input_positions is not None: + attn_kwargs["input_positions"] = input_positions x = self.attention_fn( query, key, @@ -748,7 +877,8 @@ def __call__( dtype=self.dtype, precision=self.precision, module=self if sow_weights else None, - is_causal=is_causal + is_causal=is_causal, + **attn_kwargs ) # back to the original inputs dimensions out = self.out(x, out_sharding=out_sharding) diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index 167fbf8cf..9d3eac4b9 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -383,6 +383,42 @@ def _run(m): np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3) +class TestRoPE(absltest.TestCase): + + def test_relative_position_invariance(self): + """Dot product of RoPE-rotated vectors depends only on relative position. + + For any offset d, should equal + . 
+ """ + head_dim = 64 + max_len = 128 + rope = nnx.RoPE(embedding_size=head_dim, max_seq_len=max_len) + + k1, k2 = jax.random.split(jax.random.key(7)) + q_vec = jax.random.normal(k1, (head_dim,)) + k_vec = jax.random.normal(k2, (head_dim,)) + + # Place q at position i and k at position j, then shift both by d. + i, j = 5, 12 + d = 37 + + def rope_at(vec, pos): + # Build a dummy sequence long enough, apply RoPE, extract one position. + seq = jnp.zeros((pos + 1, head_dim)).at[pos].set(vec) + return rope(seq)[pos] + + q_rot = rope_at(q_vec, i) + k_rot = rope_at(k_vec, j) + dot_original = jnp.dot(q_rot, k_rot) + + q_rot_shifted = rope_at(q_vec, i + d) + k_rot_shifted = rope_at(k_vec, j + d) + dot_shifted = jnp.dot(q_rot_shifted, k_rot_shifted) + + np.testing.assert_allclose(dot_original, dot_shifted, atol=1e-5) + + if __name__ == '__main__': absltest.main()