diff --git a/src/exo/worker/engines/mlx/patches/__init__.py b/src/exo/worker/engines/mlx/patches/__init__.py
index 5e86c08c2e..513d7e1ff1 100644
--- a/src/exo/worker/engines/mlx/patches/__init__.py
+++ b/src/exo/worker/engines/mlx/patches/__init__.py
@@ -1,3 +1,4 @@
+from exo.worker.engines.mlx.patches.cuda_compat import apply_cuda_compat_patches
 from exo.worker.engines.mlx.patches.opt_batch_gen import apply_batch_gen_patch
 from exo.worker.engines.mlx.patches.standard_yarn_rope import patch_yarn_rope
 
@@ -9,5 +10,6 @@ def apply_mlx_patches() -> None:
     if _applied:
         return
     _applied = True
+    apply_cuda_compat_patches()
     patch_yarn_rope()
     apply_batch_gen_patch()
diff --git a/src/exo/worker/engines/mlx/patches/cuda_compat.py b/src/exo/worker/engines/mlx/patches/cuda_compat.py
new file mode 100644
index 0000000000..11ed3452f2
--- /dev/null
+++ b/src/exo/worker/engines/mlx/patches/cuda_compat.py
@@ -0,0 +1,26 @@
+"""CUDA compatibility patches for MLX on Linux.
+
+MLX on Linux CUDA has some API differences from macOS Metal.
+This module provides shims to bridge those gaps.
+"""
+
+import sys
+
+import mlx.core as mx
+
+
+def apply_cuda_compat_patches() -> None:
+    """Install the MLX API shims needed on non-macOS platforms.
+
+    Returns immediately on macOS. On any other platform the alias is installed
+    only when ``mx`` lacks ``new_thread_local_stream`` and has ``new_stream``.
+    """
+    if sys.platform == "darwin":
+        return
+
+    # mlx-lm expects new_thread_local_stream, but Linux CUDA MLX exposes new_stream.
+    # Patch mx to provide the expected API.
+    # NOTE(review): the alias presumably yields an ordinary stream rather than a
+    # thread-local one; confirm mlx-lm only uses it from the creating thread.
+    if not hasattr(mx, "new_thread_local_stream") and hasattr(mx, "new_stream"):
+        mx.new_thread_local_stream = mx.new_stream  # type: ignore[attr-defined]
diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py
index 730abf64e3..535c38d9ce 100644
--- a/src/exo/worker/engines/mlx/utils_mlx.py
+++ b/src/exo/worker/engines/mlx/utils_mlx.py
@@ -793,29 +793,29 @@ def mlx_force_oom(size: int = 200000) -> None:
     mx.eval(f)
 
 
-def set_wired_limit_for_model(model_size: Memory):
-    """
-    A context manager to temporarily change the wired limit.
-
-    Note, the wired limit should not be changed during an async eval. If an
-    async eval could be running pass in the streams to synchronize with prior
-    to exiting the context manager.
-    """
-    if not mx.metal.is_available():
-        return
-
-    max_rec_size = Memory.from_bytes(
-        int(mx.device_info()["max_recommended_working_set_size"])
-    )
-    if model_size > 0.9 * max_rec_size:
-        logger.warning(
-            f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
-            f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
-            "MB. This can be slow. See the documentation for possible work-arounds: "
-            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
+def set_wired_limit_for_model(model_size: Memory) -> None:
+    """Set the MLX wired limit when Metal is available.
+
+    On Metal the wired limit is set to the device's
+    ``max_recommended_working_set_size``, and a warning is logged first if
+    ``model_size`` exceeds 90% of that value. When CUDA is available instead,
+    only an informational message is logged; otherwise nothing happens.
+    """
+    if mx.metal.is_available():
+        max_rec_size = Memory.from_bytes(
+            int(mx.device_info()["max_recommended_working_set_size"])
         )
-    mx.set_wired_limit(max_rec_size.in_bytes)
-    logger.info(f"Wired limit set to {max_rec_size}.")
+        if model_size > 0.9 * max_rec_size:
+            logger.warning(
+                f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
+                f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
+                "MB. This can be slow. See the documentation for possible work-arounds: "
+                "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
+            )
+        mx.set_wired_limit(max_rec_size.in_bytes)
+        logger.info(f"Wired limit set to {max_rec_size}.")
+    elif hasattr(mx, "cuda") and mx.cuda.is_available():
+        logger.info("CUDA backend active — skipping Metal wired limit.")
 
 
 def mlx_cleanup(
diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py
index f981667c0a..b6d76b0cd8 100644
--- a/src/exo/worker/runner/bootstrap.py
+++ b/src/exo/worker/runner/bootstrap.py
@@ -1,5 +1,6 @@
 import os
 import resource
+import sys
 import traceback
 from dataclasses import dataclass
 from typing import Self, cast
@@ -51,12 +52,14 @@ def entrypoint(
     resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
 
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
-    if fast_synch_override == "false":
-        os.environ["MLX_METAL_FAST_SYNCH"] = "0"
+    if sys.platform == "darwin":
+        if fast_synch_override == "false":
+            os.environ["MLX_METAL_FAST_SYNCH"] = "0"
+        else:
+            os.environ["MLX_METAL_FAST_SYNCH"] = "1"
+        logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
     else:
-        os.environ["MLX_METAL_FAST_SYNCH"] = "1"
-
-    logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
+        logger.info("Fast synch flag: skipped (non-Darwin platform)")
 
     # Import main after setting global logger - this lets us just import logger from this module
     try: