From 39289ef0256d9a88b39c5923ce4f99e7c8757898 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 4 Dec 2025 21:42:26 -0800 Subject: [PATCH] [CUDA] Release build for cuda 13 (#2872) --- .github/workflows/release.yml | 3 ++- docs/src/install.rst | 7 ++++-- setup.py | 40 ++++++++++++++++++++++++++++------- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6087eed0b..0e8eb20c0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -131,6 +131,7 @@ jobs: strategy: matrix: arch: ['x86_64', 'aarch64'] + toolkit: ['cuda-12.9', 'cuda-13.0'] runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }} env: PYPI_RELEASE: 1 @@ -139,7 +140,7 @@ jobs: - uses: actions/checkout@v6 - uses: ./.github/actions/setup-linux with: - toolkit: 'cuda-12.9' + toolkit: ${{ matrix.toolkit }} - name: Build Python package uses: ./.github/actions/build-cuda-release with: diff --git a/docs/src/install.rst b/docs/src/install.rst index 9f0ab67bc..9c72fdb46 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with: .. code-block:: shell - pip install mlx[cuda] + pip install mlx[cuda12] + To install the CUDA package from PyPi your system must meet the following requirements: -- Nvidia architecture >= SM 7.0 (Volta) +- Nvidia architecture >= SM 7.5 - Nvidia driver >= 550.54.14 - CUDA toolkit >= 12.0 - Linux distribution with glibc >= 2.35 - Python >= 3.10 +For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires +an Nvidia driver >= 580 or an appropriate CUDA compatibility package. 
def cuda_toolkit_major_version(nvcc="nvcc"):
    """Return the major version of the CUDA toolkit reported by ``nvcc``.

    Runs ``<nvcc> --version`` and extracts the integer that follows the word
    ``release`` in its output (for example ``release 12.9`` yields ``12``).

    Args:
        nvcc: Name or path of the ``nvcc`` executable to query. The default,
            ``"nvcc"``, is resolved through ``PATH``.

    Returns:
        The major version as an ``int``, or ``None`` when it cannot be
        determined: the executable is missing or cannot be run, it exits
        with a non-zero status, or its output has no ``release <N>`` marker.
    """
    try:
        # stderr is merged into stdout so the whole report is searched.
        raw = subprocess.check_output([nvcc, "--version"], stderr=subprocess.STDOUT)
    except (OSError, subprocess.CalledProcessError):
        # Missing/unrunnable compiler or a failing invocation: report
        # "unknown" the same way as unparseable output so that callers have
        # a single failure value to check.
        return None

    # Decode leniently: only the ASCII "release <N>" marker matters, so stray
    # non-UTF-8 bytes elsewhere in the output must not abort the detection.
    report = raw.decode(errors="replace")
    found = re.search(r"release (\d+)", report)
    return int(found.group(1)) if found else None
"nvidia-cuda-nvrtc-cu12==12.9.*", - "nvidia-cudnn-cu12==9.*", - "nvidia-nccl-cu12", + f"nvidia-cudnn-cu{toolkit}==9.*", + f"nvidia-nccl-cu{toolkit}", ] + else: name = "mlx-cpu" _setup(