mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Release build for cuda 13 (#2872)
This commit is contained in:
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -131,6 +131,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ['x86_64', 'aarch64']
|
||||
toolkit: ['cuda-12.9', 'cuda-13.0']
|
||||
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
@@ -139,7 +140,7 @@ jobs:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
|
||||
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cuda]
|
||||
pip install mlx[cuda12]
|
||||
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia architecture >= SM 7.5
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
36
setup.py
36
setup.py
@@ -7,13 +7,21 @@ import re
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from subprocess import run
|
||||
|
||||
from setuptools import Command, Extension, find_namespace_packages, setup
|
||||
from setuptools.command.bdist_wheel import bdist_wheel
|
||||
from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def cuda_toolkit_major_version():
|
||||
out = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT)
|
||||
text = out.decode()
|
||||
m = re.search(r"release (\d+)", text)
|
||||
if m:
|
||||
return int(m.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def get_version():
|
||||
with open("mlx/version.h", "r") as fid:
|
||||
for l in fid:
|
||||
@@ -31,7 +39,7 @@ def get_version():
|
||||
version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}"
|
||||
if not pypi_release and not dev_release:
|
||||
git_hash = (
|
||||
run(
|
||||
subprocess.run(
|
||||
"git rev-parse --short HEAD".split(),
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@@ -284,7 +292,11 @@ if __name__ == "__main__":
|
||||
install_requires.append(
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||
)
|
||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||
extras["cuda"] = [f'mlx-cuda-12=={version}; platform_system == "Linux"']
|
||||
for toolkit in [12, 13]:
|
||||
extras[f"cuda{toolkit}"] = [
|
||||
f'mlx-cuda-{toolkit}=={version}; platform_system == "Linux"'
|
||||
]
|
||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||
|
||||
_setup(
|
||||
@@ -299,13 +311,25 @@ if __name__ == "__main__":
|
||||
if build_macos:
|
||||
name = "mlx-metal"
|
||||
elif build_cuda:
|
||||
name = "mlx-cuda"
|
||||
toolkit = cuda_toolkit_major_version()
|
||||
name = f"mlx-cuda-{toolkit}"
|
||||
if toolkit == 12:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu12==12.9.*",
|
||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||
"nvidia-cudnn-cu12==9.*",
|
||||
"nvidia-nccl-cu12",
|
||||
]
|
||||
elif toolkit == 13:
|
||||
install_requires += [
|
||||
"nvidia-cublas-cu13",
|
||||
"nvidia-cuda-nvrtc-cu13",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown toolkit {toolkit}")
|
||||
install_requires += [
|
||||
f"nvidia-cudnn-cu{toolkit}==9.*",
|
||||
f"nvidia-nccl-cu{toolkit}",
|
||||
]
|
||||
|
||||
else:
|
||||
name = "mlx-cpu"
|
||||
_setup(
|
||||
|
||||
Reference in New Issue
Block a user