diff --git a/.github/workflows/marimo_export_md.yml b/.github/workflows/marimo_export_md.yml new file mode 100644 index 00000000..130601ec --- /dev/null +++ b/.github/workflows/marimo_export_md.yml @@ -0,0 +1,76 @@ +name: marimo export markdown + +# On every push that changes a marimo notebook under examples/marimo/, export +# each notebook to a peer Markdown file (notebook.py -> notebook.md) and commit +# the result back to the branch, so the rendered Markdown always tracks the +# notebook. +# +# Loop safety (three independent guards): +# 1. Pushes made with GITHUB_TOKEN do not trigger new workflow runs — a +# GitHub Actions built-in, and the primary protection here. +# 2. The trigger watches only *.py; this job only ever commits *.md. +# 3. The commit message carries [skip ci]. +# +# marimo is pinned so exports are byte-deterministic (the front matter records +# the marimo version), which means an unchanged notebook never produces a +# spurious commit. Bump MARIMO_VERSION to refresh all exports on the next push. + +on: + push: + paths: + - 'examples/marimo/**/*.py' + +permissions: + contents: write + +concurrency: + group: marimo-export-md-${{ github.ref }} + cancel-in-progress: true + +env: + MARIMO_VERSION: "0.23.9" + +jobs: + export-md: + # Redundant with the GITHUB_TOKEN protection above, but keeps things safe + # if someone later swaps in a personal access token. + if: github.actor != 'github-actions[bot]' + runs-on: ubuntu-latest + steps: + - name: Checkout branch + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install marimo + run: python -m pip install --quiet "marimo==${MARIMO_VERSION}" + + - name: Export marimo notebooks to Markdown + run: | + shopt -s globstar nullglob + for nb in examples/marimo/**/*.py; do + # Only real marimo notebooks construct marimo.App(...). + if grep -q 'marimo\.App(' "$nb"; then + echo "Exporting $nb -> ${nb%.py}.md" + marimo export md "$nb" -o "${nb%.py}.md" -f + fi + done + + - name: Commit and push if the Markdown changed + run: | + git config user.name 'github-actions[bot]' + git config user.email '41898282+github-actions[bot]@users.noreply.github.com' + # Only Markdown peers are generated, so staging the tree captures + # exactly the exported files (the notebooks themselves are untouched). + git add -A examples/marimo + if git diff --cached --quiet; then + echo "Markdown already up to date." + else + git commit -m "docs: export marimo notebook(s) to Markdown [skip ci]" + git push origin "HEAD:${{ github.ref_name }}" + fi diff --git a/.github/workflows/marimo_molab.yml b/.github/workflows/marimo_molab.yml new file mode 100644 index 00000000..d9dbb3e9 --- /dev/null +++ b/.github/workflows/marimo_molab.yml @@ -0,0 +1,144 @@ +name: marimo molab links + +# Posts — and keeps updated — a PR comment linking each modified marimo +# notebook to molab (https://molab.marimo.io), which runs any public marimo +# notebook on GitHub in a hosted environment with no local setup. +# +# Security note: this uses `pull_request_target` so the comment can also be +# posted on PRs from forks (a plain `pull_request` event gives fork PRs a +# read-only token that cannot comment). The job NEVER checks out or executes +# PR code — it only reads changed-file metadata and file contents as text +# through the API, then posts a comment. Do not add a checkout of the PR head +# or run any PR-provided code in this workflow. + +on: + pull_request_target: + types: [opened, synchronize, reopened] + paths: + - '**/*.py' + +permissions: + contents: read + pull-requests: write + +jobs: + molab-links: + runs-on: ubuntu-latest + steps: + - name: Comment molab links for modified marimo notebooks + uses: actions/github-script@v7 + with: + script: | + const pr = context.payload.pull_request; + const headOwner = pr.head.repo.owner.login; + const headRepo = pr.head.repo.name; + const headSha = pr.head.sha; // pin content detection to this PR revision + const headRef = pr.head.ref; // branch name for the (auto-tracking) links + const marker = ''; + + // 1. List the files changed in this PR. + const files = await github.paginate(github.rest.pulls.listFiles, { + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number, + per_page: 100, + }); + + // 2. Keep added/modified .py files and decide whether each is a + // marimo notebook by inspecting its content (never executing it). + // Every marimo notebook constructs `marimo.App(...)`. + const isMarimo = /\bmarimo\.App\s*\(/; + const notebooks = []; + for (const f of files) { + if (f.status === 'removed') continue; + if (!f.filename.endsWith('.py')) continue; + try { + const res = await github.rest.repos.getContent({ + owner: headOwner, + repo: headRepo, + path: f.filename, + ref: headSha, + }); + if (!res.data.content) { + core.warning(`Skipping ${f.filename}: content not inlined (file too large?).`); + continue; + } + const content = Buffer.from(res.data.content, res.data.encoding).toString('utf8'); + if (isMarimo.test(content)) notebooks.push(f.filename); + } catch (err) { + core.warning(`Could not read ${f.filename}: ${err.message}`); + } + } + + // 3. Find any prior comment so we update it in place instead of + // posting a new one on every push. + const comments = await github.paginate(github.rest.issues.listComments, { + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + per_page: 100, + }); + const existing = comments.find(c => c.body && c.body.includes(marker)); + + // 4. No marimo notebooks: clear a stale comment if present, else exit. + if (notebooks.length === 0) { + if (existing) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: existing.id, + body: `${marker}\n_No marimo notebooks in the current changes._`, + }); + } + core.info('No marimo notebooks found; nothing to link.'); + return; + } + + // 5. Build the comment. Links use the branch ref, not a commit + // SHA, so they always point at the latest revision without the + // comment needing an update on every push. GitHub resolves + // multi-segment (slashed) branch names in `blob//`, + // and molab fetches from GitHub, so slashed branches are fine. + const rows = notebooks.map((path) => { + // The `/server` suffix opens the notebook in a hosted runtime; + // without it molab shows a static, non-runnable preview. + const url = `https://molab.marimo.io/github/${headOwner}/${headRepo}/blob/${headRef}/${path}/server`; + return `| \`${path}\` | [![Open in molab](https://marimo.io/molab-shield.svg)](${url}) |`; + }).join('\n'); + + const body = [ + marker, + '### ▶️ Run the marimo notebook(s) in this PR', + '', + '[molab](https://molab.marimo.io) launches any public marimo notebook on ' + + 'GitHub in a hosted environment — no local setup required.', + '', + '| Notebook | molab |', + '| --- | --- |', + rows, + '', + `_Links track the head of \`${headRef}\`._`, + ].join('\n'); + + // 6. Upsert the comment (skip the write when nothing changed, so + // pushes that add no new notebook don't churn the comment). + if (existing) { + if (existing.body === body) { + core.info('Comment already up to date.'); + return; + } + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: existing.id, + body, + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body, + }); + } + core.info(`Linked ${notebooks.length} marimo notebook(s).`); diff --git a/examples/marimo/mnist-registry/README.md b/examples/marimo/mnist-registry/README.md new file mode 100644 index 00000000..cb3b89db --- /dev/null +++ b/examples/marimo/mnist-registry/README.md @@ -0,0 +1,95 @@ +# MNIST -> W&B Registry (marimo) + +A [marimo](https://marimo.io) notebook that trains a small CNN on MNIST with +PyTorch, tracks the run in Weights & Biases, saves the trained model as a W&B +Artifact, and links that Artifact to a collection in the **W&B Registry**. + +The notebook is the first marimo example in this repo and is intentionally +self-contained: dependencies are declared in a [PEP 723](https://peps.python.org/pep-0723/) +inline-script block at the top of `mnist_registry.py`, so [`uv`](https://docs.astral.sh/uv/) +can resolve them automatically. + +## Prerequisites + +- Python 3.10 or newer. +- A W&B account, authenticated one of two ways: run `wandb login` in your + shell before launching the notebook, or paste your key into the **W&B API + key** field in the form. Get your key from + [wandb.ai/authorize](https://wandb.ai/authorize). +- A W&B **Registry** must exist in your org, and your account needs at least + the **Member** role on it for the final linking step (linking an artifact is + a write action). The built-in Model registry is provisioned automatically in + newer orgs. If linking fails (for example, from a view-only seat), the + notebook surfaces a remediation message in the last Registry cell instead of + crashing. See + [configuring registry access](https://docs.wandb.ai/guides/registry/configure_registry/). +- GPU is optional. Defaults are tuned to finish in roughly two minutes on CPU. + +## Run + +Use `uvx` with marimo's sandbox mode — it creates an isolated virtual +environment from the inline dependencies in the notebook: + +```bash +uvx marimo edit mnist_registry.py --sandbox +``` + +marimo opens in your browser. Adjust hyperparameters in the form, then click +**Train model** to start the run. The run URL appears inline as soon as +training begins. + +If you prefer pip: + +```bash +pip install -r requirements.txt +marimo edit mnist_registry.py +``` + +The notebook is interactive-only by design: training is gated by a button +click, so `marimo run` renders the form but never starts training without +an explicit click. + +## What you get + +After a successful run: + +- A W&B run with training and test metrics, gradient histograms (`wandb.watch`), + and up to 16 example test-set predictions logged as images. +- A model Artifact named `mnist-cnn-` of type `model` with metadata + for test accuracy, parameter count, dataset sizes, and the full + hyperparameter dict. Tagged with the `latest` alias. +- A version of that Artifact linked into the configured Registry collection + (default: `wandb-registry-model/MNIST Classifiers`). + +To consume the registered model from another script or notebook: + +```python +import wandb +api = wandb.Api() +art = api.artifact("wandb-registry-model/MNIST Classifiers:latest") +art.download() # writes mnist_cnn.pt under ./artifacts/ +``` + +## Design notes + +- **Training is gated by a button.** marimo cells re-run reactively when their + inputs change. Before the first click of **Train model**, slider changes do + not start a run. After a run completes, clicking **Train model** again + starts a new run with the current form values; the previous run finishes + cleanly first. +- **`wandb.run` finishes defensively** at the top of the training cell so + a second click of **Train model** does not nest runs in the same marimo + kernel. +- **`logged.wait()` runs** after `log_artifact` and before `link_artifact` + to avoid a race where the link tries to resolve a version that has not + finished committing server-side. +- **Registry failures soft-fail.** If `link_artifact` raises — usually + because the Registry does not exist in your org — the notebook + surfaces remediation guidance through `mo.callout` rather than aborting. + +## Reference + +The CNN architecture and training loop mirror +[`examples/pytorch/pytorch-cnn-mnist/main.py`](../../pytorch/pytorch-cnn-mnist/main.py). +The Registry linking pattern follows +[`colabs/wandb_registry/zoo_wandb.ipynb`](../../../colabs/wandb_registry/zoo_wandb.ipynb). diff --git a/examples/marimo/mnist-registry/mnist_registry.md b/examples/marimo/mnist-registry/mnist_registry.md new file mode 100644 index 00000000..b1c2077c --- /dev/null +++ b/examples/marimo/mnist-registry/mnist_registry.md @@ -0,0 +1,510 @@ +--- +title: MNIST -> W&B Registry +marimo-version: 0.23.9 +width: medium +header: |- + # /// script + # requires-python = ">=3.10" + # dependencies = [ + # "marimo>=0.9", + # "torch>=2.1", + # "torchvision>=0.16", + # "wandb>=0.18", + # "tqdm", + # ] + # /// + """Train an MNIST CNN with PyTorch, track the run with Weights & Biases, + and link the resulting model artifact to a W&B Registry collection. + + Run: + + uvx marimo edit mnist_registry.py --sandbox + + The notebook has three interactive cells: fill in the form, click **Train + model**, then read the results. Everything between the inputs and the button + runs as a single step, so one click trains, logs, saves, and registers. + """ +--- + +```python {.marimo hide_code="true"} +import marimo as mo + +mo.md( + """ + # MNIST -> W&B Run -> Registry + + ## What you will build + + - A **W&B run** with training and test metrics, gradient histograms, + and example test-set predictions logged as images. + - A **model Artifact** named `mnist-cnn-` of type `model`, + carrying metadata (test accuracy, parameter count, hyperparameters). + - A version of that Artifact **linked into a W&B Registry collection** + so it appears under registered models org-wide. + + ## Prerequisites + + - Authenticate with W&B one of two ways: run **`wandb login`** in + your shell before starting marimo, or paste your key into the + **W&B API key** field in the form below. Get your key from + [wandb.ai/authorize](https://wandb.ai/authorize). + - A W&B **team** to write the run to, set in the **W&B entity** field. + Accounts created after May 2024 have no personal entity, so the run + must go to a team — your username will not work as an entity. + - A **W&B Registry** must exist in your org, and your account needs at + least the **Member** role on it (linking an artifact is a write + action). The built-in Model registry is provisioned automatically in + newer orgs. If linking fails (for example, from a view-only seat), + the run still completes and the Registry step explains how to fix it. + - A GPU is optional. The defaults finish in about 2 minutes on CPU. + """ +) +``` + +```python {.marimo} +# Imports, device detection, and the input form all live in one cell: this +# is "everything up to collecting your inputs". It defines the form widgets +# but never reads their `.value` — marimo only makes a widget reactive when +# a *different* cell consumes it, which the training cell below does. +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + +import wandb +from tqdm.auto import tqdm + +if torch.cuda.is_available(): + device = torch.device("cuda") + device_note = "CUDA GPU detected. Training will be fast." + device_kind = "success" +elif torch.backends.mps.is_available(): + device = torch.device("mps") + device_note = "Apple MPS detected. Training will run on the GPU." + device_kind = "success" +else: + device = torch.device("cpu") + device_note = ( + "No GPU detected. Training will run on CPU. With the default " + "hyperparameters this takes about 2 minutes." + ) + device_kind = "warn" + +epochs = mo.ui.slider(start=1, stop=10, step=1, value=3, label="Epochs") +batch_size = mo.ui.dropdown( + options=["32", "64", "128", "256"], value="64", label="Batch size" +) +lr = mo.ui.slider( + start=0.001, stop=0.1, step=0.001, value=0.01, label="Learning rate", show_value=True +) +momentum = mo.ui.slider( + start=0.0, stop=0.99, step=0.01, value=0.5, label="SGD momentum", show_value=True +) +seed = mo.ui.number(start=0, stop=99999, value=42, label="Random seed") + +project = mo.ui.text(value="marimo-mnist-registry", label="W&B project") +entity = mo.ui.text( + value="", label="W&B entity — a team you belong to (blank uses your default)" +) +run_name = mo.ui.text(value="", label="Run name (blank auto-generates)") +api_key = mo.ui.text( + value="", kind="password", label="W&B API key (blank uses your shell login)" +) + +registry_name = mo.ui.text(value="model", label="W&B Registry name") +collection_name = mo.ui.text(value="MNIST Classifiers", label="Registry collection") +link_to_registry = mo.ui.checkbox(value=True, label="Link artifact to Registry") + +form = mo.vstack( + [ + mo.md("### Training"), + mo.hstack([epochs, batch_size]), + mo.hstack([lr, momentum]), + seed, + mo.md("### W&B run"), + api_key, + mo.hstack([project, entity, run_name]), + mo.md("### Registry"), + mo.hstack([registry_name, collection_name, link_to_registry]), + ] +) + +mo.vstack( + [ + mo.callout( + mo.md(f"**Device:** `{device}` — {device_note}"), kind=device_kind + ), + mo.md( + "## Configure\n\nSet the hyperparameters and W&B targets, then click " + "**Train model** below. Changing a value here never starts a run on " + "its own — only the button does." + ), + form, + ] +) +``` + +```python {.marimo} +class Net(nn.Module): + """Small CNN: 2 conv layers (10, 20 filters, 5x5) + 2 FC (50, 10). + + Defined in its own cell so the training cell and the consume cell can + share it (marimo forbids defining the same name in two cells). + """ + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) +``` + +```python {.marimo} +# Everything the Train button triggers, in one cell — no reason to make you +# advance through a chain of output-less code blocks. Each milestone is +# streamed to the cell output with `mo.output.append` as it happens. +mo.stop( + not train_button.value, + mo.md( + "This cell runs the whole pipeline — start the run, train, log " + "metrics and example predictions, save the model Artifact, and link " + "it to the Registry. Click **Train model** below to run it." + ), +) + +config = { + "epochs": epochs.value, + "batch_size": int(batch_size.value), + "lr": lr.value, + "momentum": momentum.value, + "seed": seed.value, + "architecture": "CNN", + "dataset": "MNIST", +} +registry_name_v = registry_name.value.strip() +collection_name_v = collection_name.value.strip() + +# Authenticate. Finish any prior run first (marimo keeps the kernel alive +# across re-clicks). A key pasted into the form wins; otherwise fall back to +# ambient login (shell `wandb login`, WANDB_API_KEY, or netrc). The key is +# never written to the run config. +if wandb.run is not None: + wandb.finish() +if api_key.value: + wandb.login(key=api_key.value) + +torch.manual_seed(config["seed"]) + +try: + run = wandb.init( + project=project.value or None, + entity=entity.value or None, + name=run_name.value or None, + config=config, + job_type="train", + ) +except Exception as exc: # noqa: BLE001 - turn the raw traceback into guidance + mo.stop( + True, + mo.callout( + mo.md( + f"**Could not start the run.** `{exc}`\n\n" + f"An `entity ... not found` error means the **W&B entity** is " + f"not a team you can write to. Personal-username entities were " + f"removed for accounts created after 21 May 2024, so set the " + f"**W&B entity** field to one of your teams (find them in the " + f"left sidebar at [wandb.ai](https://wandb.ai))." + ), + kind="danger", + ), + ) +# Use `epoch` as the x-axis for train/test metrics in the W&B UI. +wandb.define_metric("epoch") +wandb.define_metric("train/*", step_metric="epoch") +wandb.define_metric("test/*", step_metric="epoch") +# Surface the run link right away so you can watch metrics stream live. +mo.output.append(mo.md(f"**Run started:** [`{run.name}`]({run.url})")) + +model = Net().to(device) +# `log="gradients"` is the standard choice; `log="all"` also logs parameter +# histograms at extra cost. +wandb.watch(model, log="gradients", log_freq=100) +optimizer = optim.SGD( + model.parameters(), lr=config["lr"], momentum=config["momentum"] +) + +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] +) +train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform) +test_ds = datasets.MNIST("./data", train=False, download=True, transform=transform) +loader_kwargs = ( + {"num_workers": 2, "pin_memory": True} if device.type == "cuda" else {} +) +train_loader = DataLoader( + train_ds, batch_size=config["batch_size"], shuffle=True, **loader_kwargs +) +test_loader = DataLoader(test_ds, batch_size=1000, shuffle=False, **loader_kwargs) + +history = [] +best_acc = 0.0 +for epoch in range(1, config["epochs"] + 1): + model.train() + for batch_idx, (data, target) in enumerate( + tqdm(train_loader, desc=f"epoch {epoch}/{config['epochs']}") + ): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 50 == 0: + wandb.log({"train/loss": loss.item(), "epoch": epoch}) + + model.eval() + test_loss = 0.0 + correct = 0 + example_images = [] + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + # Pull up to 16 example predictions from the first batch. + while len(example_images) < 16 and len(example_images) < data.size(0): + j = len(example_images) + example_images.append( + wandb.Image( + data[j], + caption=f"pred={pred[j].item()} true={target[j].item()}", + ) + ) + + test_loss /= len(test_loader.dataset) + test_acc = correct / len(test_loader.dataset) + best_acc = max(best_acc, test_acc) + wandb.log( + { + "test/loss": test_loss, + "test/accuracy": test_acc, + "epoch": epoch, + "examples": example_images, + } + ) + history.append( + {"epoch": epoch, "test_loss": round(test_loss, 4), "test_acc": round(test_acc, 4)} + ) + +# Full-precision last-epoch accuracy; `history` rounds only for display. +final_acc = test_acc +mo.output.append( + mo.vstack( + [ + mo.md("### Training summary"), + mo.ui.table(history, selection=None), + mo.md(f"**Final test accuracy:** {final_acc:.2%}"), + ] + ) +) + +# Save the weights and log them as a model Artifact tagged `latest`. +model_path = "mnist_cnn.pt" +torch.save(model.state_dict(), model_path) +artifact = wandb.Artifact( + name=f"mnist-cnn-{run.id}", + type="model", + description=( + "Small CNN trained on MNIST. Architecture: 2 conv layers " + "(10 and 20 filters, 5x5 kernels) + 2 FC layers (50, 10)." + ), + metadata={ + "framework": "pytorch", + "architecture": "CNN", + "num_parameters": sum(p.numel() for p in model.parameters()), + "dataset": "MNIST", + "train_size": len(train_ds), + "test_size": len(test_ds), + "test_accuracy": final_acc, + "best_test_accuracy": best_acc, + "hyperparameters": dict(config), + }, +) +artifact.add_file(model_path) +logged = run.log_artifact(artifact, aliases=["latest"]) +# Block until the artifact has committed before linking, to avoid a race. +logged.wait() +mo.output.append(mo.md(f"**Artifact logged:** `{artifact.name}` (alias `latest`)")) + +# Link to the Registry, surfacing a remediation note instead of crashing. +if link_to_registry.value: + target_path = f"wandb-registry-{registry_name_v}/{collection_name_v}" + try: + run.link_artifact(artifact=logged, target_path=target_path) + mo.output.append( + mo.callout( + mo.md( + f"**Linked to Registry:** `{target_path}` — see " + f"[wandb.ai/registry](https://wandb.ai/registry)." + ), + kind="success", + ) + ) + except Exception as exc: # noqa: BLE001 - surface any failure to the reader + mo.output.append( + mo.callout( + mo.md( + f"**Registry link failed.** Target `{target_path}` — `{exc}`\n\n" + f"- Linking needs at least the **Member** role on the " + f"Registry. `view-only member cannot write to project` means " + f"your seat is view-only: the run and artifact succeed, but " + f"linking is blocked. An admin can grant access from the " + f"Registry **Members** settings, the Python SDK " + f"(`wandb.Api().registry(...)` then `add_member()` / " + f"`update_member()`), or SCIM (`PATCH /scim/Users/{{id}}` with " + f"`registryRoles`) — see " + f"https://docs.wandb.ai/guides/registry/configure_registry/. " + f"Or set **W&B entity** to a team in an org where you have " + f"Registry write access.\n" + f"- The Registry `{registry_name_v}` may not exist; an admin " + f"can create it from the W&B Registry UI.\n" + f"- On the legacy Model Registry, link with " + f"`target_path='model-registry/{collection_name_v}'` instead." + ), + kind="danger", + ) + ) +else: + mo.output.append( + mo.md( + "_Registry linking is disabled — the artifact is logged to the run " + "but not linked to a collection._" + ) + ) + +# Close the run so its summary and any Registry link finalize server-side. +wandb.finish() +``` + +```python {.marimo} +# Placed after the training cell on purpose: it's the explicit "run" trigger +# for the pipeline above. It must be its own cell because that cell reads +# `train_button.value`, and a widget only drives reactivity when a +# *different* cell consumes it. The gate also stops the pipeline from +# running automatically when the notebook opens; run_button's value is True +# only for the cascade a click triggers (then resets to False), so editing +# the form afterwards re-runs the training cell but it stops immediately. +train_button = mo.ui.run_button(label="Train model", kind="success") +mo.vstack( + [ + train_button, + mo.md( + "Runs the training cell above. It is gated so it does not " + "execute when the notebook opens — click to run, and click " + "again to retrain after editing the form (the previous run is " + "finished first)." + ), + ] +) +``` + +```python {.marimo} +# Consume the model: download it from W&B (preferring the registered +# version, falling back to the run's own artifact), load the weights into a +# fresh network, and classify 10 held-out test digits. +api = wandb.Api() +try: + consumed = api.artifact( + f"wandb-registry-{registry_name_v}/{collection_name_v}:latest", type="model" + ) + source = f"registry `wandb-registry-{registry_name_v}/{collection_name_v}:latest`" +except Exception: # noqa: BLE001 - registry link may be absent (e.g. a view-only seat) + consumed = api.artifact( + f"{run.entity}/{run.project}/mnist-cnn-{run.id}:latest", type="model" + ) + source = f"run artifact `mnist-cnn-{run.id}:latest`" +weights_dir = consumed.download() + +clf = Net() +clf.load_state_dict(torch.load(f"{weights_dir}/mnist_cnn.pt", map_location="cpu")) +clf.eval() + +cards = [] +n_correct = 0 +with torch.no_grad(): + for i in range(10): + image, true_label = test_ds[i] + prediction = clf(image.unsqueeze(0)).argmax(dim=1).item() + n_correct += int(prediction == true_label) + # Undo the Normalize transform so the digit renders as a clean image. + digit = (image * 0.3081 + 0.1307).clamp(0, 1).squeeze().numpy() + mark = "✅" if prediction == true_label else "❌" + cards.append( + mo.vstack( + [ + mo.image(digit, width=64, vmin=0, vmax=1), + mo.md(f"{mark} **{prediction}** · true {true_label}"), + ], + align="center", + ) + ) + +mo.vstack( + [ + mo.md( + f"## Classify 10 test digits\n\nConsumed the model from {source}, " + f"loaded the weights into a fresh network, and ran it on 10 held-out " + f"MNIST test images — **{n_correct}/10 correct**." + ), + mo.hstack(cards, wrap=True, justify="start"), + ] +) +``` + +````python {.marimo hide_code="true"} +# Renders only after a run exists (it consumes `run` from the training +# cell), so it appears once training finishes. +mo.md( + f""" + ## Verify and next steps + + 1. Open the run: [{run.name}]({run.url}) — check the **Charts**, + **System**, and **Examples** panels. + 2. In the run's **Artifacts** tab, confirm `mnist-cnn-{run.id}` is listed + with its metadata (test accuracy, parameter count, hyperparameters). + 3. At [wandb.ai/registry](https://wandb.ai/registry), open the + **{registry_name_v.title()}** registry, then the **{collection_name_v}** + collection, and confirm the linked version. + + **Consume the registered model** from any script or notebook: + + ```python + import wandb + art = wandb.Api().artifact( + "wandb-registry-{registry_name_v}/{collection_name_v}:latest" + ) + art.download() # writes mnist_cnn.pt under ./artifacts/ + ``` + + **Next steps:** promote a version by adding the `production` alias from + the Registry UI; re-run with a deeper architecture or a different + learning rate and compare runs in the W&B UI; or add a W&B Automation to + trigger evaluation when a new version is linked. + """ +) +```` \ No newline at end of file diff --git a/examples/marimo/mnist-registry/mnist_registry.py b/examples/marimo/mnist-registry/mnist_registry.py new file mode 100644 index 00000000..cb78c945 --- /dev/null +++ b/examples/marimo/mnist-registry/mnist_registry.py @@ -0,0 +1,576 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "marimo>=0.9", +# "torch>=2.1", +# "torchvision>=0.16", +# "wandb>=0.18", +# "tqdm", +# ] +# /// +"""Train an MNIST CNN with PyTorch, track the run with Weights & Biases, +and link the resulting model artifact to a W&B Registry collection. + +Run: + + uvx marimo edit mnist_registry.py --sandbox + +The notebook has three interactive cells: fill in the form, click **Train +model**, then read the results. Everything between the inputs and the button +runs as a single step, so one click trains, logs, saves, and registers. +""" + +import marimo + +__generated_with = "0.23.9" +app = marimo.App(width="medium", app_title="MNIST -> W&B Registry") + + +@app.cell(hide_code=True) +def _(): + import marimo as mo + + mo.md( + """ + # MNIST -> W&B Run -> Registry + + ## What you will build + + - A **W&B run** with training and test metrics, gradient histograms, + and example test-set predictions logged as images. + - A **model Artifact** named `mnist-cnn-` of type `model`, + carrying metadata (test accuracy, parameter count, hyperparameters). + - A version of that Artifact **linked into a W&B Registry collection** + so it appears under registered models org-wide. + + ## Prerequisites + + - Authenticate with W&B one of two ways: run **`wandb login`** in + your shell before starting marimo, or paste your key into the + **W&B API key** field in the form below. Get your key from + [wandb.ai/authorize](https://wandb.ai/authorize). + - A W&B **team** to write the run to, set in the **W&B entity** field. + Accounts created after May 2024 have no personal entity, so the run + must go to a team — your username will not work as an entity. + - A **W&B Registry** must exist in your org, and your account needs at + least the **Member** role on it (linking an artifact is a write + action). The built-in Model registry is provisioned automatically in + newer orgs. If linking fails (for example, from a view-only seat), + the run still completes and the Registry step explains how to fix it. + - A GPU is optional. The defaults finish in about 2 minutes on CPU. + """ + ) + return (mo,) + + +@app.cell +def _(mo): + # Imports, device detection, and the input form all live in one cell: this + # is "everything up to collecting your inputs". It defines the form widgets + # but never reads their `.value` — marimo only makes a widget reactive when + # a *different* cell consumes it, which the training cell below does. + import torch + import torch.nn as nn + import torch.nn.functional as F + import torch.optim as optim + from torch.utils.data import DataLoader + from torchvision import datasets, transforms + + import wandb + from tqdm.auto import tqdm + + if torch.cuda.is_available(): + device = torch.device("cuda") + device_note = "CUDA GPU detected. Training will be fast." + device_kind = "success" + elif torch.backends.mps.is_available(): + device = torch.device("mps") + device_note = "Apple MPS detected. Training will run on the GPU." + device_kind = "success" + else: + device = torch.device("cpu") + device_note = ( + "No GPU detected. Training will run on CPU. With the default " + "hyperparameters this takes about 2 minutes." + ) + device_kind = "warn" + + epochs = mo.ui.slider(start=1, stop=10, step=1, value=3, label="Epochs") + batch_size = mo.ui.dropdown( + options=["32", "64", "128", "256"], value="64", label="Batch size" + ) + lr = mo.ui.slider( + start=0.001, stop=0.1, step=0.001, value=0.01, label="Learning rate", show_value=True + ) + momentum = mo.ui.slider( + start=0.0, stop=0.99, step=0.01, value=0.5, label="SGD momentum", show_value=True + ) + seed = mo.ui.number(start=0, stop=99999, value=42, label="Random seed") + + project = mo.ui.text(value="marimo-mnist-registry", label="W&B project") + entity = mo.ui.text( + value="", label="W&B entity — a team you belong to (blank uses your default)" + ) + run_name = mo.ui.text(value="", label="Run name (blank auto-generates)") + api_key = mo.ui.text( + value="", kind="password", label="W&B API key (blank uses your shell login)" + ) + + registry_name = mo.ui.text(value="model", label="W&B Registry name") + collection_name = mo.ui.text(value="MNIST Classifiers", label="Registry collection") + link_to_registry = mo.ui.checkbox(value=True, label="Link artifact to Registry") + + form = mo.vstack( + [ + mo.md("### Training"), + mo.hstack([epochs, batch_size]), + mo.hstack([lr, momentum]), + seed, + mo.md("### W&B run"), + api_key, + mo.hstack([project, entity, run_name]), + mo.md("### Registry"), + mo.hstack([registry_name, collection_name, link_to_registry]), + ] + ) + + mo.vstack( + [ + mo.callout( + mo.md(f"**Device:** `{device}` — {device_note}"), kind=device_kind + ), + mo.md( + "## Configure\n\nSet the hyperparameters and W&B targets, then click " + "**Train model** below. Changing a value here never starts a run on " + "its own — only the button does." + ), + form, + ] + ) + return ( + DataLoader, + F, + api_key, + batch_size, + collection_name, + datasets, + device, + entity, + epochs, + link_to_registry, + lr, + momentum, + nn, + optim, + project, + registry_name, + run_name, + seed, + torch, + tqdm, + transforms, + wandb, + ) + + +@app.cell +def _(F, nn): + class Net(nn.Module): + """Small CNN: 2 conv layers (10, 20 filters, 5x5) + 2 FC (50, 10). + + Defined in its own cell so the training cell and the consume cell can + share it (marimo forbids defining the same name in two cells). + """ + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + return (Net,) + + +@app.cell +def _( + DataLoader, + F, + Net, + api_key, + batch_size, + collection_name, + datasets, + device, + entity, + epochs, + link_to_registry, + lr, + momentum, + mo, + optim, + project, + registry_name, + run_name, + seed, + torch, + tqdm, + train_button, + transforms, + wandb, +): + # Everything the Train button triggers, in one cell — no reason to make you + # advance through a chain of output-less code blocks. Each milestone is + # streamed to the cell output with `mo.output.append` as it happens. + mo.stop( + not train_button.value, + mo.md( + "This cell runs the whole pipeline — start the run, train, log " + "metrics and example predictions, save the model Artifact, and link " + "it to the Registry. Click **Train model** below to run it." + ), + ) + + config = { + "epochs": epochs.value, + "batch_size": int(batch_size.value), + "lr": lr.value, + "momentum": momentum.value, + "seed": seed.value, + "architecture": "CNN", + "dataset": "MNIST", + } + registry_name_v = registry_name.value.strip() + collection_name_v = collection_name.value.strip() + + # Authenticate. Finish any prior run first (marimo keeps the kernel alive + # across re-clicks). A key pasted into the form wins; otherwise fall back to + # ambient login (shell `wandb login`, WANDB_API_KEY, or netrc). The key is + # never written to the run config. + if wandb.run is not None: + wandb.finish() + if api_key.value: + wandb.login(key=api_key.value) + + torch.manual_seed(config["seed"]) + + try: + run = wandb.init( + project=project.value or None, + entity=entity.value or None, + name=run_name.value or None, + config=config, + job_type="train", + ) + except Exception as exc: # noqa: BLE001 - turn the raw traceback into guidance + mo.stop( + True, + mo.callout( + mo.md( + f"**Could not start the run.** `{exc}`\n\n" + f"An `entity ... not found` error means the **W&B entity** is " + f"not a team you can write to. Personal-username entities were " + f"removed for accounts created after 21 May 2024, so set the " + f"**W&B entity** field to one of your teams (find them in the " + f"left sidebar at [wandb.ai](https://wandb.ai))." + ), + kind="danger", + ), + ) + # Use `epoch` as the x-axis for train/test metrics in the W&B UI. + wandb.define_metric("epoch") + wandb.define_metric("train/*", step_metric="epoch") + wandb.define_metric("test/*", step_metric="epoch") + # Surface the run link right away so you can watch metrics stream live. + mo.output.append(mo.md(f"**Run started:** [`{run.name}`]({run.url})")) + + model = Net().to(device) + # `log="gradients"` is the standard choice; `log="all"` also logs parameter + # histograms at extra cost. + wandb.watch(model, log="gradients", log_freq=100) + optimizer = optim.SGD( + model.parameters(), lr=config["lr"], momentum=config["momentum"] + ) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform) + test_ds = datasets.MNIST("./data", train=False, download=True, transform=transform) + loader_kwargs = ( + {"num_workers": 2, "pin_memory": True} if device.type == "cuda" else {} + ) + train_loader = DataLoader( + train_ds, batch_size=config["batch_size"], shuffle=True, **loader_kwargs + ) + test_loader = DataLoader(test_ds, batch_size=1000, shuffle=False, **loader_kwargs) + + history = [] + best_acc = 0.0 + for epoch in range(1, config["epochs"] + 1): + model.train() + for batch_idx, (data, target) in enumerate( + tqdm(train_loader, desc=f"epoch {epoch}/{config['epochs']}") + ): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 50 == 0: + wandb.log({"train/loss": loss.item(), "epoch": epoch}) + + model.eval() + test_loss = 0.0 + correct = 0 + example_images = [] + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + # Pull up to 16 example predictions from the first batch. + while len(example_images) < 16 and len(example_images) < data.size(0): + j = len(example_images) + example_images.append( + wandb.Image( + data[j], + caption=f"pred={pred[j].item()} true={target[j].item()}", + ) + ) + + test_loss /= len(test_loader.dataset) + test_acc = correct / len(test_loader.dataset) + best_acc = max(best_acc, test_acc) + wandb.log( + { + "test/loss": test_loss, + "test/accuracy": test_acc, + "epoch": epoch, + "examples": example_images, + } + ) + history.append( + {"epoch": epoch, "test_loss": round(test_loss, 4), "test_acc": round(test_acc, 4)} + ) + + # Full-precision last-epoch accuracy; `history` rounds only for display. + final_acc = test_acc + mo.output.append( + mo.vstack( + [ + mo.md("### Training summary"), + mo.ui.table(history, selection=None), + mo.md(f"**Final test accuracy:** {final_acc:.2%}"), + ] + ) + ) + + # Save the weights and log them as a model Artifact tagged `latest`. + model_path = "mnist_cnn.pt" + torch.save(model.state_dict(), model_path) + artifact = wandb.Artifact( + name=f"mnist-cnn-{run.id}", + type="model", + description=( + "Small CNN trained on MNIST. Architecture: 2 conv layers " + "(10 and 20 filters, 5x5 kernels) + 2 FC layers (50, 10)." + ), + metadata={ + "framework": "pytorch", + "architecture": "CNN", + "num_parameters": sum(p.numel() for p in model.parameters()), + "dataset": "MNIST", + "train_size": len(train_ds), + "test_size": len(test_ds), + "test_accuracy": final_acc, + "best_test_accuracy": best_acc, + "hyperparameters": dict(config), + }, + ) + artifact.add_file(model_path) + logged = run.log_artifact(artifact, aliases=["latest"]) + # Block until the artifact has committed before linking, to avoid a race. + logged.wait() + mo.output.append(mo.md(f"**Artifact logged:** `{artifact.name}` (alias `latest`)")) + + # Link to the Registry, surfacing a remediation note instead of crashing. + if link_to_registry.value: + target_path = f"wandb-registry-{registry_name_v}/{collection_name_v}" + try: + run.link_artifact(artifact=logged, target_path=target_path) + mo.output.append( + mo.callout( + mo.md( + f"**Linked to Registry:** `{target_path}` — see " + f"[wandb.ai/registry](https://wandb.ai/registry)." + ), + kind="success", + ) + ) + except Exception as exc: # noqa: BLE001 - surface any failure to the reader + mo.output.append( + mo.callout( + mo.md( + f"**Registry link failed.** Target `{target_path}` — `{exc}`\n\n" + f"- Linking needs at least the **Member** role on the " + f"Registry. `view-only member cannot write to project` means " + f"your seat is view-only: the run and artifact succeed, but " + f"linking is blocked. An admin can grant access from the " + f"Registry **Members** settings, the Python SDK " + f"(`wandb.Api().registry(...)` then `add_member()` / " + f"`update_member()`), or SCIM (`PATCH /scim/Users/{{id}}` with " + f"`registryRoles`) — see " + f"https://docs.wandb.ai/guides/registry/configure_registry/. " + f"Or set **W&B entity** to a team in an org where you have " + f"Registry write access.\n" + f"- The Registry `{registry_name_v}` may not exist; an admin " + f"can create it from the W&B Registry UI.\n" + f"- On the legacy Model Registry, link with " + f"`target_path='model-registry/{collection_name_v}'` instead." + ), + kind="danger", + ) + ) + else: + mo.output.append( + mo.md( + "_Registry linking is disabled — the artifact is logged to the run " + "but not linked to a collection._" + ) + ) + + # Close the run so its summary and any Registry link finalize server-side. + wandb.finish() + return collection_name_v, registry_name_v, run, test_ds + + +@app.cell +def _(mo): + # Placed after the training cell on purpose: it's the explicit "run" trigger + # for the pipeline above. It must be its own cell because that cell reads + # `train_button.value`, and a widget only drives reactivity when a + # *different* cell consumes it. The gate also stops the pipeline from + # running automatically when the notebook opens; run_button's value is True + # only for the cascade a click triggers (then resets to False), so editing + # the form afterwards re-runs the training cell but it stops immediately. + train_button = mo.ui.run_button(label="Train model", kind="success") + mo.vstack( + [ + train_button, + mo.md( + "Runs the training cell above. It is gated so it does not " + "execute when the notebook opens — click to run, and click " + "again to retrain after editing the form (the previous run is " + "finished first)." + ), + ] + ) + return (train_button,) + + +@app.cell +def _(Net, collection_name_v, mo, registry_name_v, run, test_ds, torch, wandb): + # Consume the model: download it from W&B (preferring the registered + # version, falling back to the run's own artifact), load the weights into a + # fresh network, and classify 10 held-out test digits. + api = wandb.Api() + try: + consumed = api.artifact( + f"wandb-registry-{registry_name_v}/{collection_name_v}:latest", type="model" + ) + source = f"registry `wandb-registry-{registry_name_v}/{collection_name_v}:latest`" + except Exception: # noqa: BLE001 - registry link may be absent (e.g. a view-only seat) + consumed = api.artifact( + f"{run.entity}/{run.project}/mnist-cnn-{run.id}:latest", type="model" + ) + source = f"run artifact `mnist-cnn-{run.id}:latest`" + weights_dir = consumed.download() + + clf = Net() + clf.load_state_dict(torch.load(f"{weights_dir}/mnist_cnn.pt", map_location="cpu")) + clf.eval() + + cards = [] + n_correct = 0 + with torch.no_grad(): + for i in range(10): + image, true_label = test_ds[i] + prediction = clf(image.unsqueeze(0)).argmax(dim=1).item() + n_correct += int(prediction == true_label) + # Undo the Normalize transform so the digit renders as a clean image. + digit = (image * 0.3081 + 0.1307).clamp(0, 1).squeeze().numpy() + mark = "✅" if prediction == true_label else "❌" + cards.append( + mo.vstack( + [ + mo.image(digit, width=64, vmin=0, vmax=1), + mo.md(f"{mark} **{prediction}** · true {true_label}"), + ], + align="center", + ) + ) + + mo.vstack( + [ + mo.md( + f"## Classify 10 test digits\n\nConsumed the model from {source}, " + f"loaded the weights into a fresh network, and ran it on 10 held-out " + f"MNIST test images — **{n_correct}/10 correct**." + ), + mo.hstack(cards, wrap=True, justify="start"), + ] + ) + return + + +@app.cell(hide_code=True) +def _(collection_name_v, mo, registry_name_v, run): + # Renders only after a run exists (it consumes `run` from the training + # cell), so it appears once training finishes. + mo.md( + f""" + ## Verify and next steps + + 1. Open the run: [{run.name}]({run.url}) — check the **Charts**, + **System**, and **Examples** panels. + 2. In the run's **Artifacts** tab, confirm `mnist-cnn-{run.id}` is listed + with its metadata (test accuracy, parameter count, hyperparameters). + 3. At [wandb.ai/registry](https://wandb.ai/registry), open the + **{registry_name_v.title()}** registry, then the **{collection_name_v}** + collection, and confirm the linked version. + + **Consume the registered model** from any script or notebook: + + ```python + import wandb + art = wandb.Api().artifact( + "wandb-registry-{registry_name_v}/{collection_name_v}:latest" + ) + art.download() # writes mnist_cnn.pt under ./artifacts/ + ``` + + **Next steps:** promote a version by adding the `production` alias from + the Registry UI; re-run with a deeper architecture or a different + learning rate and compare runs in the W&B UI; or add a W&B Automation to + trigger evaluation when a new version is linked. + """ + ) + return + + +if __name__ == "__main__": + app.run() diff --git a/examples/marimo/mnist-registry/requirements.txt b/examples/marimo/mnist-registry/requirements.txt new file mode 100644 index 00000000..69a57c0a --- /dev/null +++ b/examples/marimo/mnist-registry/requirements.txt @@ -0,0 +1,7 @@ +# Mirror of the PEP 723 inline dependency block in mnist_registry.py. +# Keep these two in sync. +marimo>=0.9 +torch>=2.1 +torchvision>=0.16 +wandb>=0.18 +tqdm