[CUDA] Release build for cuda 13 (#2872)

This commit is contained in:
Awni Hannun
2025-12-04 21:42:26 -08:00
committed by GitHub
parent aefc9bd3f6
commit 39289ef025
3 changed files with 39 additions and 11 deletions

View File

@@ -131,6 +131,7 @@ jobs:
strategy: strategy:
matrix: matrix:
arch: ['x86_64', 'aarch64'] arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }} runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
@@ -139,7 +140,7 @@ jobs:
- uses: actions/checkout@v6 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
toolkit: 'cuda-12.9' toolkit: ${{ matrix.toolkit }}
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with: with:

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda] pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following To install the CUDA package from PyPi your system must meet the following
requirements: requirements:
- Nvidia architecture >= SM 7.0 (Volta) - Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@@ -7,13 +7,21 @@ import re
import subprocess import subprocess
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
def cuda_toolkit_major_version():
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
text = out.decode()
m = re.search(r"release (\d+)", text)
if m:
return int(m.group(1))
return None
def get_version(): def get_version():
with open("mlx/version.h", "r") as fid: with open("mlx/version.h", "r") as fid:
for l in fid: for l in fid:
@@ -31,7 +39,7 @@ def get_version():
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
if not pypi_release and not dev_release: if not pypi_release and not dev_release:
git_hash = ( git_hash = (
run( subprocess.run(
"git rev-parse --short HEAD".split(), "git rev-parse --short HEAD".split(),
capture_output=True, capture_output=True,
check=True, check=True,
@@ -284,7 +292,11 @@ if __name__ == "__main__":
install_requires.append( install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"' f'mlx-metal=={version}; platform_system == "Darwin"'
) )
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
for toolkit in [12, 13]:
extras[f"cuda{toolkit}"] = [
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
]
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup( _setup(
@@ -299,13 +311,25 @@ if __name__ == "__main__":
if build_macos: if build_macos:
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
name = "mlx-cuda" toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
if toolkit == 12:
install_requires += [ install_requires += [
"nvidia-cublas-cu12==12.9.*", "nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*", "nvidia-cuda-nvrtc-cu12==12.9.*",
"nvidia-cudnn-cu12==9.*",
"nvidia-nccl-cu12",
] ]
elif toolkit == 13:
install_requires += [
"nvidia-cublas-cu13",
"nvidia-cuda-nvrtc-cu13",
]
else:
raise ValueError(f"Unknown toolkit {toolkit}")
install_requires += [
f"nvidia-cudnn-cu{toolkit}==9.*",
f"nvidia-nccl-cu{toolkit}",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(