Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import os
import re
from dataclasses import dataclass
from typing import NoReturn, TypedDict

Expand Down Expand Up @@ -74,13 +75,24 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str]
attachments.append(f' Directory does not exist: "{dir_path}"')


def _filename_with_sm_arch(filename: str, sm_arch: str) -> str:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small cleanup: switching to fullmatch, compile the SM arch pattern once, and reuse .pattern in the error message. That keeps the validation contract in one place instead of repeating the regex in code and text.

_SM_ARCH_PATTERN = re.compile(r"sm[0-9]+[a-z]?")


def _filename_with_sm_arch(filename: str, sm_arch: str) -> str:
    if not sm_arch:
        return filename

    if not _SM_ARCH_PATTERN.fullmatch(sm_arch):
        raise ValueError(f"Invalid sm_arch: {sm_arch!r} must match {_SM_ARCH_PATTERN.pattern!r}")

    stem, ext = os.path.splitext(filename)
    return f"{stem}_{sm_arch}{ext}"

if not sm_arch:
return filename

if not re.match(r"^sm[0-9]+[a-z]?$", sm_arch):
raise ValueError(f"Invalid sm_arch: '{sm_arch}' must match 'sm[0-9]+[a-z]?'")

stem, ext = os.path.splitext(filename)
return f"{stem}_{sm_arch}{ext}"


class _FindBitcodeLib:
def __init__(self, name: str) -> None:
def __init__(self, name: str, sm_arch: str = "") -> None:
if name not in _SUPPORTED_BITCODE_LIBS_INFO: # Updated reference
raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
self.name: str = name
self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name] # Updated reference
self.filename: str = self.config["filename"]
self.filename: str = _filename_with_sm_arch(self.config["filename"], sm_arch)
self.rel_path: str = self.config["rel_path"]
self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
self.error_messages: list[str] = []
Expand Down Expand Up @@ -130,14 +142,23 @@ def raise_not_found_error(self) -> NoReturn:
raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}')


def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
def locate_bitcode_lib(name: str, *, sm_arch: str = "") -> LocatedBitcodeLib:
"""Locate a bitcode library by name.

When ``sm_arch`` is set, locate the architecture-specific bitcode filename
with ``_{sm_arch}`` inserted before the ``.bc`` suffix.

Args:
name: Name of the supported bitcode library to locate.
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
``"sm90a"``. If set, it must match ``sm[0-9]+[a-z]?``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``name`` is not a supported bitcode library, or if
``sm_arch`` is set but does not match ``sm[0-9]+[a-z]?``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
finder = _FindBitcodeLib(name)
finder = _FindBitcodeLib(name, sm_arch)

abs_path = finder.try_site_packages()
if abs_path is not None:
Expand Down Expand Up @@ -170,11 +191,20 @@ def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:


@functools.cache
def find_bitcode_lib(name: str) -> str:
def find_bitcode_lib(name: str, sm_arch: str = "") -> str:
"""Find the absolute path to a bitcode library.

When ``sm_arch`` is set, find the architecture-specific bitcode filename
with ``_{sm_arch}`` inserted before the ``.bc`` suffix.

Args:
name: Name of the supported bitcode library to find.
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
``"sm90a"``. If set, it must match ``sm[0-9]+[a-z]?``.

Raises:
ValueError: If ``name`` is not a supported bitcode library.
ValueError: If ``name`` is not a supported bitcode library, or if
``sm_arch`` is set but does not match ``sm[0-9]+[a-z]?``.
BitcodeLibNotFoundError: If the bitcode library cannot be found.
"""
return locate_bitcode_lib(name).abs_path
return locate_bitcode_lib(name, sm_arch).abs_path
144 changes: 134 additions & 10 deletions cuda_pathfinder/tests/test_find_bitcode_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ def _bitcode_lib_info(libname: str):
return find_bitcode_lib_module._SUPPORTED_BITCODE_LIBS_INFO[libname]


def _bitcode_lib_filename(libname: str) -> str:
return _bitcode_lib_info(libname)["filename"]


@pytest.fixture
def clear_find_bitcode_lib_cache():
find_bitcode_lib_module.find_bitcode_lib.cache_clear()
Expand All @@ -36,9 +32,9 @@ def clear_find_bitcode_lib_cache():
get_cuda_path_or_home.cache_clear()


def _make_bitcode_lib_file(dir_path: Path, libname: str) -> str:
def _make_bitcode_lib_file(dir_path: Path, filename: str) -> str:
dir_path.mkdir(parents=True, exist_ok=True)
file_path = dir_path / _bitcode_lib_filename(libname)
file_path = dir_path / filename
file_path.touch()
return str(file_path)

Expand Down Expand Up @@ -92,14 +88,16 @@ def test_locate_bitcode_lib(info_summary_append, libname):
@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.parametrize("libname", SUPPORTED_BITCODE_LIBS)
def test_locate_bitcode_lib_search_order(monkeypatch, tmp_path, libname):
filename = _bitcode_lib_info(libname)["filename"]

site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), libname)
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), libname)
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
Expand Down Expand Up @@ -135,6 +133,84 @@ def find_expected_sub_dir(sub_dir):
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_locate_bitcode_lib_with_sm_arch_search_order(monkeypatch, tmp_path):
libname = "nvshmem_device"
sm_arch = "sm90"
filename = "libnvshmem_device_sm90.bc"

site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
site_packages_path = _make_bitcode_lib_file(site_packages_lib_dir, filename)

conda_prefix = tmp_path / "conda-prefix"
conda_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(_conda_anchor(conda_prefix), libname), filename)

cuda_home = tmp_path / "cuda-home"
cuda_home_path = _make_bitcode_lib_file(_bitcode_lib_dir_under(cuda_home, libname), filename)

site_packages_sub_dirs = tuple(
tuple(rel_dir.split("/")) for rel_dir in _bitcode_lib_info(libname)["site_packages_dirs"]
)

def find_expected_sub_dir(sub_dir):
assert sub_dir in site_packages_sub_dirs
if sub_dir == site_packages_sub_dirs[0]:
return [str(site_packages_lib_dir)]
return []

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
find_expected_sub_dir,
)
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == site_packages_path
assert located_lib.filename == filename
assert located_lib.found_via == "site-packages"
assert find_bitcode_lib(libname, sm_arch=sm_arch) == site_packages_path
os.remove(site_packages_path)
find_bitcode_lib_module.find_bitcode_lib.cache_clear()

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == conda_path
assert located_lib.filename == filename
assert located_lib.found_via == "conda"
os.remove(conda_path)

located_lib = locate_bitcode_lib(libname, sm_arch=sm_arch)
assert located_lib.abs_path == cuda_home_path
assert located_lib.filename == filename
assert located_lib.found_via == "CUDA_PATH"


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_find_bitcode_lib_cache_keeps_sm_arch_separate(monkeypatch, tmp_path):
libname = "nvshmem_device"
site_packages_lib_dir = _site_packages_bitcode_lib_dir_under(tmp_path / "site-packages", libname)
sm80_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm80.bc")
sm90_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90.bc")
sm90a_path = _make_bitcode_lib_file(site_packages_lib_dir, "libnvshmem_device_sm90a.bc")

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [str(site_packages_lib_dir)],
)
monkeypatch.delenv("CONDA_PREFIX", raising=False)
monkeypatch.delenv("CUDA_HOME", raising=False)
monkeypatch.delenv("CUDA_PATH", raising=False)

assert find_bitcode_lib(libname, sm_arch="sm80") == sm80_path
assert find_bitcode_lib(libname, sm_arch="sm90") == sm90_path
assert find_bitcode_lib(libname, sm_arch="sm90a") == sm90a_path


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
cuda_home = tmp_path / "cuda-home"
Expand All @@ -156,12 +232,44 @@ def test_find_bitcode_lib_not_found_error_includes_cuda_home_directory_listing(m
find_bitcode_lib("device")

message = str(exc_info.value)
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_filename("device"))
expected_missing_file = os.path.join(str(lib_dir), _bitcode_lib_info("device")["filename"])
assert f"No such file: {expected_missing_file}" in message
assert f'listdir("{lib_dir}"):' in message
assert "README.txt" in message


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
@pytest.mark.skipif("nvshmem_device" not in SUPPORTED_BITCODE_LIBS, reason="nvshmem_device is not supported")
def test_find_bitcode_lib_with_sm_arch_not_found_error_uses_arch_specific_filename(monkeypatch, tmp_path):
libname = "nvshmem_device"
sm_arch = "sm90"
expected_filename = "libnvshmem_device_sm90.bc"

cuda_home = tmp_path / "cuda-home"
lib_dir = _bitcode_lib_dir_under(cuda_home, libname)
lib_dir.mkdir(parents=True, exist_ok=True)
extra_file = lib_dir / "libnvshmem_device.bc"
extra_file.touch()

monkeypatch.setattr(
find_bitcode_lib_module,
"find_sub_dirs_all_sitepackages",
lambda _sub_dir: [],
)
monkeypatch.delenv("CONDA_PREFIX", raising=False)
monkeypatch.setenv("CUDA_HOME", str(cuda_home))
monkeypatch.delenv("CUDA_PATH", raising=False)

with pytest.raises(BitcodeLibNotFoundError, match=rf'Failure finding "{expected_filename}"') as exc_info:
find_bitcode_lib(libname, sm_arch=sm_arch)

message = str(exc_info.value)
expected_missing_file = os.path.join(str(lib_dir), expected_filename)
assert f"No such file: {expected_missing_file}" in message
assert f'listdir("{lib_dir}"):' in message
assert "libnvshmem_device.bc" in message


@pytest.mark.usefixtures("clear_find_bitcode_lib_cache")
def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
monkeypatch.setattr(
Expand All @@ -183,3 +291,19 @@ def test_find_bitcode_lib_not_found_error_without_cuda_home(monkeypatch):
def test_find_bitcode_lib_invalid_name():
with pytest.raises(ValueError, match="Unknown bitcode library"):
find_bitcode_lib_module.locate_bitcode_lib("invalid")


@pytest.mark.parametrize(
"sm_arch",
[
"../sm90",
"compute90",
"sm_90",
"sm",
"sm90/extra",
"sm90A",
],
)
def test_find_bitcode_lib_invalid_sm_arch(sm_arch):
with pytest.raises(ValueError, match="must match"):
find_bitcode_lib_module.locate_bitcode_lib("device", sm_arch=sm_arch)
Loading