Add more CUDA architectures for PyPi package (#2427)

* add cuda sm 90

* add more archs
This commit is contained in:
Awni Hannun
2025-07-28 12:35:15 -07:00
committed by GitHub
parent ab0e608862
commit 641be9463b
4 changed files with 13 additions and 30 deletions

View File

@@ -44,6 +44,8 @@ def get_version():
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
# A CMakeExtension needs a sourcedir instead of a file list.
@@ -85,6 +87,11 @@ class CMakeBuild(build_ext):
"-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
]
if build_stage == 2 and build_cuda:
# Last arch is always real and virtual for forward-compatibility
cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120"))
cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"]
# Some generators require explcitly passing config when building.
build_args = ["--config", cfg]
# Adding CMake arguments set as environment variable
@@ -95,7 +102,7 @@ class CMakeBuild(build_ext):
# Pass version to C++
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
if platform.system() == "Darwin":
if build_macos:
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
@@ -202,9 +209,6 @@ if __name__ == "__main__":
],
)
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
version = get_version()
_setup = partial(