feat: Add Native CUDA Support for Linux GPU Inference #2103
This commit enables exo to run on NVIDIA GPUs on Linux by fixing Metal-specific assumptions in the MLX inference path.

Changes:
- Add CUDA compatibility shim for mlx-lm's new_thread_local_stream API (Linux CUDA MLX exposes new_stream instead)
- Gate MLX_METAL_FAST_SYNCH env var to macOS only, preventing warnings on Linux
- Make set_wired_limit_for_model handle CUDA backends gracefully by checking mx.metal.is_available() first
- Add automatic LD_LIBRARY_PATH setup in runner bootstrap for CUDA libraries (libcublasLt.so.13, etc.)

Compatibility:
- Zero breaking changes - all modifications are platform-gated
- macOS Metal path unchanged
- CPU-only Linux still works
- Enables heterogeneous clusters (macOS Metal + Linux CUDA)

Tested on:
- Linux: NVIDIA RTX 3090, CUDA 13.1, Driver 590.48.01
- macOS: MacBook Pro M5 Pro (Metal)
- Verified cross-platform cluster inference with Qwen3-0.6B-8bit

Refs: PR exo-explore#2053 (Docker support) - this provides the native code changes needed for Linux CUDA deployment without requiring Docker.
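The commit message mentions an automatic LD_LIBRARY_PATH setup for the runner, but the corresponding code is not reproduced on this page. A minimal sketch of how such a step might look, assuming the CUDA library directories are already known to the caller; the helper name `with_cuda_library_path` and its parameters are hypothetical and not taken from exo:

```python
import os
from pathlib import Path


def with_cuda_library_path(env: dict[str, str], lib_dirs: list[Path]) -> dict[str, str]:
    """Return a copy of env with existing lib_dirs prepended to LD_LIBRARY_PATH.

    Hypothetical helper: exo's actual bootstrap code is not shown in the source.
    """
    # Keep whatever the parent process already had, dropping empty entries.
    existing = [p for p in env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
    # Only add directories that actually exist on this machine.
    candidates = [str(d) for d in lib_dirs if d.is_dir()]

    merged: list[str] = []
    for entry in candidates + existing:
        if entry not in merged:  # de-duplicate while preserving order
            merged.append(entry)

    new_env = dict(env)
    if merged:
        new_env["LD_LIBRARY_PATH"] = os.pathsep.join(merged)
    return new_env
```

The returned mapping could then be passed as the `env` argument when spawning the runner subprocess, so the child can resolve libraries such as libcublasLt.so.13 without changing the parent's environment.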
Pull request overview
This PR aims to enable MLX-based GPU inference on Linux NVIDIA (via mlx-cuda) by removing Metal-only assumptions in the MLX runner path and introducing a small CUDA-compatibility shim so heterogeneous macOS (Metal) + Linux (CUDA) clusters can run together.
Changes:
- Gate `MLX_METAL_FAST_SYNCH` configuration to macOS-only in the runner bootstrap.
- Adjust MLX wired-limit behavior to be backend-aware (Metal vs CUDA) in `utils_mlx.py`.
- Add and register a CUDA compatibility patch that shims `mx.new_thread_local_stream` via `mx.new_stream` on non-Darwin platforms.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `src/exo/worker/runner/bootstrap.py` | Platform-gates Metal-specific env var (`MLX_METAL_FAST_SYNCH`) and adjusts logging on non-Darwin. |
| `src/exo/worker/engines/mlx/utils_mlx.py` | Makes wired-limit handling conditional on Metal availability and logs when CUDA is active. |
| `src/exo/worker/engines/mlx/patches/cuda_compat.py` | Adds CUDA-side API shim for mlx-lm stream creation expectations. |
| `src/exo/worker/engines/mlx/patches/__init__.py` | Registers the new CUDA compat patch in the MLX patch chain. |
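The wired-limit change described for `utils_mlx.py` amounts to checking for Metal before touching a Metal-only setting. A minimal sketch of that guard, using stand-in objects so it runs without MLX installed; `maybe_set_wired_limit` is a hypothetical name, and the real body of `set_wired_limit_for_model` is not shown in the source:

```python
from types import SimpleNamespace


def maybe_set_wired_limit(mx_module, limit_bytes: int) -> bool:
    """Apply a wired-memory limit only when the Metal backend is available."""
    metal = getattr(mx_module, "metal", None)
    if metal is None or not metal.is_available():
        # Non-Metal backend (for example CUDA): skip instead of failing.
        return False
    mx_module.set_wired_limit(limit_bytes)
    return True


# Stand-ins for the mlx.core module on each backend (assumptions for the demo).
calls: list[int] = []
fake_metal_mx = SimpleNamespace(
    metal=SimpleNamespace(is_available=lambda: True),
    set_wired_limit=calls.append,
)
fake_cuda_mx = SimpleNamespace(metal=SimpleNamespace(is_available=lambda: False))

applied = maybe_set_wired_limit(fake_metal_mx, 1 << 30)
```

Returning a boolean rather than raising lets the caller log which path was taken, in line with the table's note that the change logs when CUDA is active.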
src/exo/worker/runner/bootstrap.py

     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
    -if fast_synch_override == "false":
    -    os.environ["MLX_METAL_FAST_SYNCH"] = "0"
    +if sys.platform == "darwin":
    +    if fast_synch_override == "false":
    +        os.environ["MLX_METAL_FAST_SYNCH"] = "0"
    +    else:
    +        os.environ["MLX_METAL_FAST_SYNCH"] = "1"
    +    logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
     else:
    -    os.environ["MLX_METAL_FAST_SYNCH"] = "1"
    -
    -logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
    +    logger.info("Fast synch flag: skipped (non-Darwin platform)")
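The gating logic in this hunk can be isolated into a small function that takes the platform and environment as arguments, which makes both branches easy to exercise on any machine. This is a sketch only; `configure_fast_synch` is a hypothetical name and not part of the PR:

```python
import sys
from typing import MutableMapping, Optional


def configure_fast_synch(
    env: MutableMapping[str, str], platform: str = sys.platform
) -> Optional[str]:
    """Set MLX_METAL_FAST_SYNCH on macOS only; return the value set, or None."""
    if platform != "darwin":
        # Metal-specific flag: leave the environment untouched elsewhere.
        return None
    # EXO_FAST_SYNCH=false disables fast synch; anything else enables it.
    value = "0" if env.get("EXO_FAST_SYNCH") == "false" else "1"
    env["MLX_METAL_FAST_SYNCH"] = value
    return value


# Demo with plain dicts in place of os.environ.
linux_env: dict[str, str] = {}
mac_env: dict[str, str] = {"EXO_FAST_SYNCH": "false"}
linux_result = configure_fast_synch(linux_env, platform="linux")
mac_result = configure_fast_synch(mac_env, platform="darwin")
```

Passing `os.environ` as `env` would reproduce the behavior of the patched hunk, with the logging left to the caller.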
src/exo/worker/engines/mlx/patches/cuda_compat.py

    def apply_cuda_compat_patches() -> None:
        """Apply MLX CUDA compatibility patches.

        These patches are only applied on Linux systems where MLX uses the CUDA backend.
        They are no-ops on macOS or CPU-only Linux.
        """
        if sys.platform == "darwin":
            return

        # mlx-lm expects new_thread_local_stream, but Linux CUDA MLX exposes new_stream.
        # Patch mx to provide the expected API.
        if not hasattr(mx, "new_thread_local_stream") and hasattr(mx, "new_stream"):
            mx.new_thread_local_stream = mx.new_stream  # type: ignore[attr-defined]
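The effect of this shim can be illustrated without MLX by applying the same `hasattr` check to a stand-in namespace. The sketch below assumes nothing about MLX beyond the two attribute names in the patch; `apply_stream_shim` and `fake_mx` are hypothetical names for illustration:

```python
import sys
from types import SimpleNamespace


def apply_stream_shim(mx_module, platform: str = sys.platform) -> bool:
    """Alias new_thread_local_stream to new_stream when only the latter exists.

    Returns True if the alias was installed, False otherwise.
    """
    if platform == "darwin":
        return False  # macOS path is left unchanged
    if not hasattr(mx_module, "new_thread_local_stream") and hasattr(
        mx_module, "new_stream"
    ):
        mx_module.new_thread_local_stream = mx_module.new_stream
        return True
    return False


# Stand-in for an MLX build that only exposes new_stream.
fake_mx = SimpleNamespace(new_stream=lambda device: ("stream", device))
patched = apply_stream_shim(fake_mx, platform="linux")
stream = fake_mx.new_thread_local_stream("gpu")
```

A plain alias like this only satisfies the attribute lookup; it does not by itself give the aliased function thread-local behavior, so callers that rely on per-thread streams would need to check that separately.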
Acknowledgments
This PR builds upon the MLX CUDA backend recently introduced by Apple's MLX team. Special thanks to Cheng (@zcbenz) and the Apple MLX team for their extensive work on bringing CUDA support to MLX, which made this cross-platform GPU inference possible.
What This PR Adds
While the MLX CUDA backend provides the foundation for running MLX on NVIDIA GPUs, exo (the distributed inference framework) had several Metal-specific assumptions that prevented it from working with MLX CUDA. This PR fixes those assumptions, enabling:
The MLX CUDA backend is the engine; this PR is the adapter that makes exo work with that engine on Linux.
…#2103) Cherry-picked from Winston-9527/feature/linux-cuda-support (05e9811) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com>
feat: Add Native CUDA Support for Linux GPU Inference
Summary
This PR enables exo to run on NVIDIA GPUs on Linux by fixing Metal-specific assumptions in the MLX inference path. It is the native-code complement to PR #2053 (Docker support) and provides Linux CUDA deployment without requiring Docker.
Key Achievement: exo now supports heterogeneous clusters with both macOS (Metal) and Linux (CUDA) nodes working together.
Problem Statement
Currently, exo on Linux only runs on CPU. When using `mlx-cuda`, several issues prevent GPU inference:
- `set_wired_limit_for_model()` assumes Metal is always available and crashes on CUDA
- `MLX_METAL_FAST_SYNCH` is unconditionally set, causing warnings on Linux
- `mlx-lm` expects `mx.new_thread_local_stream`, which doesn't exist in Linux CUDA MLX
- The runner subprocess cannot find `libcublasLt.so.13` even when the parent process has it in `LD_LIBRARY_PATH`
Solution Overview
Changes Made
1. CUDA Compatibility Patches (`src/exo/worker/engines/mlx/patches/cuda_compat.py`)
   - `mx.new_thread_local_stream` → `mx.new_stream` (Linux CUDA exposes different API)
2. Platform-gated Metal Settings (`src/exo/worker/runner/bootstrap.py`)
   - `MLX_METAL_FAST_SYNCH` only set on macOS (`sys.platform == "darwin"`)
   - `LD_LIBRARY_PATH` setup for runner subprocess to find CUDA libraries
3. Backend-aware Wired Limit (`src/exo/worker/engines/mlx/utils_mlx.py`)
   - `set_wired_limit_for_model()` now checks `mx.metal.is_available()` first
4. Patch Registration (`src/exo/worker/engines/mlx/patches/__init__.py`)
   - `apply_cuda_compat_patches()` in the patch application chain
Compatibility Impact
Zero breaking changes:
- `MlxCuda` backend support
Testing Evidence
Environment
Test Results
Performance Benchmarks
Performance Analysis:
Screenshots
Screenshot 1: Linux CUDA Single-Device
Screenshot 2: macOS Metal Single-Device
Screenshot 3: Cross-Platform Cluster (macOS + Linux)
Usage
Prerequisites
- `uv sync --extra mlx-cuda13` (or `mlx-cuda12`)
- `CUDA_HOME` for kernel compilation
Launch
Related
Checklist
uv run ruff check)