Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/exo/worker/engines/mlx/patches/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from exo.worker.engines.mlx.patches.cuda_compat import apply_cuda_compat_patches
from exo.worker.engines.mlx.patches.opt_batch_gen import apply_batch_gen_patch
from exo.worker.engines.mlx.patches.standard_yarn_rope import patch_yarn_rope

Expand All @@ -9,5 +10,6 @@ def apply_mlx_patches() -> None:
if _applied:
return
_applied = True
apply_cuda_compat_patches()
patch_yarn_rope()
apply_batch_gen_patch()
24 changes: 24 additions & 0 deletions src/exo/worker/engines/mlx/patches/cuda_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""CUDA compatibility patches for MLX on Linux.

MLX on Linux CUDA has some API differences from macOS Metal.
This module provides shims to bridge those gaps.
"""

import sys

import mlx.core as mx


def apply_cuda_compat_patches() -> None:
"""Apply MLX CUDA compatibility patches.

These patches are only applied on Linux systems where MLX uses the CUDA backend.
They are no-ops on macOS or CPU-only Linux.
"""
if sys.platform == "darwin":
return

# mlx-lm expects new_thread_local_stream, but Linux CUDA MLX exposes new_stream.
# Patch mx to provide the expected API.
if not hasattr(mx, "new_thread_local_stream") and hasattr(mx, "new_stream"):
mx.new_thread_local_stream = mx.new_stream # type: ignore[attr-defined]
Comment on lines +12 to +24
37 changes: 15 additions & 22 deletions src/exo/worker/engines/mlx/utils_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,29 +793,22 @@ def mlx_force_oom(size: int = 200000) -> None:
mx.eval(f)


def set_wired_limit_for_model(model_size: Memory):
"""
A context manager to temporarily change the wired limit.

Note, the wired limit should not be changed during an async eval. If an
async eval could be running pass in the streams to synchronize with prior
to exiting the context manager.
"""
if not mx.metal.is_available():
return

max_rec_size = Memory.from_bytes(
int(mx.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
logger.warning(
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
def set_wired_limit_for_model(model_size: Memory) -> None:
if mx.metal.is_available():
max_rec_size = Memory.from_bytes(
int(mx.device_info()["max_recommended_working_set_size"])
)
mx.set_wired_limit(max_rec_size.in_bytes)
logger.info(f"Wired limit set to {max_rec_size}.")
if model_size > 0.9 * max_rec_size:
logger.warning(
f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
mx.set_wired_limit(max_rec_size.in_bytes)
logger.info(f"Wired limit set to {max_rec_size}.")
elif hasattr(mx, "cuda") and mx.cuda.is_available():
logger.info("CUDA backend active — skipping Metal wired limit.")


def mlx_cleanup(
Expand Down
13 changes: 8 additions & 5 deletions src/exo/worker/runner/bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import resource
import sys
import traceback
from dataclasses import dataclass
from typing import Self, cast
Expand Down Expand Up @@ -51,12 +52,14 @@ def entrypoint(
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))

fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "false":
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
if sys.platform == "darwin":
if fast_synch_override == "false":
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "1"

logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
logger.info("Fast synch flag: skipped (non-Darwin platform)")
Comment on lines 54 to +62

# Import main after setting global logger - this lets us just import logger from this module
try:
Expand Down
Loading