Install linux with mlx[cuda] and mlx[cpu] (#2356)

* install linux with mlx[cuda] and mlx[cpu]

* temp for testing

* cleanup circle, fix cuda repair

* update circle

* update circle

* decouple python bindings from core libraries
This commit is contained in:
Awni Hannun 2025-07-14 17:17:33 -07:00 committed by GitHub
parent 49114f28ab
commit f0a0b077a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 264 additions and 212 deletions

View File

@ -7,18 +7,6 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
cuda_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@ -282,7 +270,17 @@ jobs:
name: Build Python package name: Build Python package
command: | command: |
source env/bin/activate source env/bin/activate
<< parameters.build_env >> python -m build -w << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when: - when:
condition: << parameters.build_env >> condition: << parameters.build_env >>
steps: steps:
@ -299,59 +297,70 @@ jobs:
python_version: python_version:
type: string type: string
default: "3.9" default: "3.9"
extra_env: build_env:
type: string type: string
default: "DEV_RELEASE=1" default: ""
docker: machine:
- image: ubuntu:20.04 image: ubuntu-2204:current
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
name: Build wheel name: Build wheel
command: | command: |
PYTHON=python<< parameters.python_version >> PYTHON=python<< parameters.python_version >>
apt-get update export DEBIAN_FRONTEND=noninteractive
apt-get upgrade -y export NEEDRESTART_MODE=a
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata sudo apt-get update
apt-get install -y apt-utils sudo apt-get upgrade -y
apt-get install -y software-properties-common TZ=Etc/UTC sudo apt-get -y install tzdata
add-apt-repository -y ppa:deadsnakes/ppa sudo apt-get install -y apt-utils
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full sudo apt-get install -y software-properties-common
apt-get install -y libblas-dev liblapack-dev liblapacke-dev sudo add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y build-essential git sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y build-essential git
$PYTHON -m venv env $PYTHON -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.extra_env >> pip install . -v << parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.extra_env >> python -m build --wheel MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
auditwheel show dist/* bash python/scripts/repair_linux.sh
auditwheel repair dist/* --plat manylinux_2_31_x86_64 - when:
- run: condition:
name: Upload package equal: ["3.9", << parameters.python_version >>]
command: | steps:
source env/bin/activate - run:
twine upload wheelhouse/* name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
build_cuda_release: build_cuda_release:
parameters: parameters:
python_version: build_env:
type: string type: string
default: "3.9" default: ""
extra_env:
type: string
default: "DEV_RELEASE=1"
machine: machine:
image: linux-cuda-12:default image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
@ -362,25 +371,25 @@ jobs:
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
python -m venv env python -m venv env
source env/bin/activate source env/bin/activate
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
<< parameters.extra_env >> \ << parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v python -m build -w
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh
- run: - when:
name: Upload package condition: << parameters.build_env >>
command: | steps:
source env/bin/activate - run:
twine upload wheelhouse/*.whl name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@ -392,8 +401,6 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
@ -407,8 +414,6 @@ workflows:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
filters: filters:
@ -499,7 +504,16 @@ workflows:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb: prb:
when: when:
@ -578,99 +592,8 @@ workflows:
- macosx_deployment_target: "15.0" - macosx_deployment_target: "15.0"
xcode_version: "15.0.0" xcode_version: "15.0.0"
python_version: "3.13" python_version: "3.13"
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release: - build_linux_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"] - build_cuda_release
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@ -64,10 +64,8 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif() endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------

View File

@ -23,13 +23,6 @@ To install from PyPI you must meet the following requirements:
MLX is only available on devices running macOS >= 13.5 MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma) It is highly recommended to use macOS 14 (Sonoma)
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
conda install conda-forge::mlx
CUDA CUDA
^^^^ ^^^^
@ -38,8 +31,16 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell .. code-block:: shell
pip install mlx-cuda pip install "mlx[cuda]"
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install "mlx[cpu]"
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@ -1,6 +1,6 @@
[build-system] [build-system]
requires = [ requires = [
"setuptools>=42", "setuptools>=80",
"nanobind==2.4.0", "nanobind==2.4.0",
"cmake>=3.25", "cmake>=3.25",
] ]

View File

@ -1,17 +1,23 @@
#!/bin/bash #!/bin/bash
auditwheel repair dist/* \ auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \ --plat manylinux_2_39_x86_64 \
--exclude libcublas* \ --exclude libcublas* \
--exclude libnvrtc* --exclude libnvrtc* \
-w wheel_tmp
cd wheelhouse
mkdir wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit) repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}" unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit) rm "${repaired_wheel}"
rpath=$(patchelf --print-rpath "${core_so}") mlx_so="mlx/lib/libmlx.so"
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib rpath=$(patchelf --print-rpath "${mlx_so}")
patchelf --force-rpath --set-rpath "$rpath" "$core_so" base="\$ORIGIN/../../nvidia"
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
python ../python/scripts/repair_record.py ${mlx_so}
# Re-zip the repaired wheel # Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" . zip -r -q "../wheelhouse/${repaired_wheel}" .

View File

@ -0,0 +1,19 @@
#!/bin/bash
#
# Repair the Linux `mlx` wheel found in dist/ and write the result to
# wheelhouse/.
#
# Steps:
#   1. Run `auditwheel repair` into a scratch directory (wheel_tmp), leaving
#      libmlx out of the wheel.
#   2. Unpack the repaired wheel and point the rpath of the `core*.so`
#      extension at `$ORIGIN/lib`.
#   3. Refresh the RECORD entry for the patched extension.
#   4. Re-zip the wheel into wheelhouse/.
#
# Must be run from the repository root (it relies on dist/ and
# python/scripts/repair_record.py being reachable from there).

# Abort on the first failing command, on unset variables and on pipeline
# failures, so a broken repair step cannot silently produce a bad wheel.
set -euo pipefail

# The exclude pattern is quoted so the shell hands the wildcard to auditwheel
# instead of expanding it against files in the current directory.
auditwheel repair dist/* \
  --plat manylinux_2_35_x86_64 \
  --exclude 'libmlx*' \
  -w wheel_tmp

# -p keeps the script re-runnable when wheelhouse/ already exists.
mkdir -p wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
# Remove the archive so it is not packed into the re-zipped wheel below.
rm "${repaired_wheel}"

core_so=$(find mlx -name "core*.so" -print -quit)
# \$ORIGIN is escaped so the literal token reaches patchelf; the dynamic
# loader resolves it relative to the extension's own directory at load time.
rpath="\$ORIGIN/lib"
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
# Patching changed the file's contents, so its RECORD hash/size must be redone.
python ../python/scripts/repair_record.py "${core_so}"

# Re-zip the repaired wheel
zip -r -q "../wheelhouse/${repaired_wheel}" .

View File

@ -0,0 +1,33 @@
import base64
import glob
import hashlib
import sys
# Usage: python repair_record.py <path-as-listed-in-RECORD>
#
# Run from the root of an unpacked wheel. Recomputes the hash and size of the
# given file and rewrites its entry in the wheel's ``*/RECORD`` file, which is
# needed after the file has been modified in place (e.g. by patchelf).


def urlsafe_b64encode(data: bytes) -> bytes:
    """Return the URL-safe base64 encoding of ``data`` without ``=`` padding."""
    return base64.urlsafe_b64encode(data).rstrip(b"=")


def record_hash_and_size(filename):
    """Compute the RECORD fields for ``filename``.

    Returns:
        A ``(hash_field, size)`` pair where ``hash_field`` has the form
        ``sha256=<urlsafe-base64 digest without padding>`` and ``size`` is the
        file length in bytes.
    """
    with open(filename, "rb") as f:
        data = f.read()
    digest = urlsafe_b64encode(hashlib.sha256(data).digest()).decode("ascii")
    # RECORD hash entries carry the algorithm name as a prefix.
    return f"sha256={digest}", len(data)


def update_record(record_file, filename):
    """Rewrite the hash and size columns of ``filename``'s row in ``record_file``.

    Rows whose first column is not exactly ``filename`` are written back
    unchanged. ``filename`` is also opened as a path, so it must be valid
    relative to the current working directory.
    """
    hash_field, size = record_hash_and_size(filename)

    with open(record_file, "r") as f:
        lines = [l.split(",") for l in f.readlines()]

    for l in lines:
        if filename == l[0]:
            l[1] = hash_field
            # The size is the last column, so it carries the line terminator.
            l[2] = f"{size}\n"

    with open(record_file, "w") as f:
        for l in lines:
            f.write(",".join(l))


if __name__ == "__main__":
    filename = sys.argv[1]
    record_files = glob.glob("*/RECORD")
    if not record_files:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("repair_record.py: no */RECORD file found in the current directory")
    update_record(record_files[0], filename)

156
setup.py
View File

@ -5,10 +5,12 @@ import os
import platform import platform
import re import re
import subprocess import subprocess
from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run from subprocess import run
from setuptools import Command, Extension, find_namespace_packages, setup from setuptools import Command, Extension, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
@ -41,6 +43,9 @@ def get_version():
return version return version
build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0))
# A CMakeExtension needs a sourcedir instead of a file list. # A CMakeExtension needs a sourcedir instead of a file list.
# The name must be the _single_ output extension from the CMake build. # The name must be the _single_ output extension from the CMake build.
# If you need multiple extensions, see scikit-build. # If you need multiple extensions, see scikit-build.
@ -59,13 +64,22 @@ class CMakeBuild(build_ext):
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release" cfg = "Debug" if debug else "Release"
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON build_temp = Path(self.build_temp) / ext.name
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code if not build_temp.exists():
# from Python. build_temp.mkdir(parents=True)
build_python = "ON"
install_prefix = f"{extdir}{os.sep}"
if build_stage == 1:
# Don't include MLX libraries in the wheel
install_prefix = f"{build_temp}"
elif build_stage == 2:
# Don't include Python bindings in the wheel
build_python = "OFF"
cmake_args = [ cmake_args = [
f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}", f"-DCMAKE_INSTALL_PREFIX={install_prefix}",
f"-DCMAKE_BUILD_TYPE={cfg}", f"-DCMAKE_BUILD_TYPE={cfg}",
"-DMLX_BUILD_PYTHON_BINDINGS=ON", f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}",
"-DMLX_BUILD_TESTS=OFF", "-DMLX_BUILD_TESTS=OFF",
"-DMLX_BUILD_BENCHMARKS=OFF", "-DMLX_BUILD_BENCHMARKS=OFF",
"-DMLX_BUILD_EXAMPLES=OFF", "-DMLX_BUILD_EXAMPLES=OFF",
@ -99,10 +113,6 @@ class CMakeBuild(build_ext):
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
build_args += [f"-j{os.cpu_count()}"] build_args += [f"-j{os.cpu_count()}"]
build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists():
build_temp.mkdir(parents=True)
subprocess.run( subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
) )
@ -158,26 +168,40 @@ class GenerateStubs(Command):
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
class MLXBdistWheel(bdist_wheel):
    """``bdist_wheel`` variant that adjusts the wheel tag for stage-2 builds.

    When ``build_stage`` is 2 the wheel is built with the Python bindings
    turned off, so the tag is rewritten to carry only the generic Python tag,
    an ABI of ``"none"`` and the platform. For any other build stage the tag
    computed by the base class is returned as is.
    """

    def get_tag(self) -> tuple[str, str, str]:
        """Return the ``(python, abi, platform)`` tag triple for the wheel."""
        python_tag, abi_tag, platform_tag = super().get_tag()
        if build_stage != 2:
            return (python_tag, abi_tag, platform_tag)
        # Stage 2: keep only the platform component from the base class.
        return (self.python_tag, "none", platform_tag)
# Read the content of README.md # Read the content of README.md
with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
long_description = f.read() long_description = f.read()
# The information here can also be placed in setup.cfg - better separation of
# logic and declaration, and simpler if you include description/version in a file.
if __name__ == "__main__": if __name__ == "__main__":
packages = find_namespace_packages(
where="python", exclude=["src", "tests", "tests.*"]
)
package_dir = {"": "python"} package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} packages = [
install_requires = [] "mlx",
"mlx.nn",
"mlx.nn.layers",
"mlx.optimizers",
]
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
install_requires = []
if build_cuda: if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
version = get_version()
setup( _setup = partial(
name="mlx-cuda" if build_cuda else "mlx", setup,
version=get_version(), version=version,
author="MLX Contributors", author="MLX Contributors",
author_email="mlx@group.apple.com", author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.", description="A framework for machine learning on Apple silicon.",
@ -185,29 +209,77 @@ if __name__ == "__main__":
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license="MIT", license="MIT",
url="https://github.com/ml-explore/mlx", url="https://github.com/ml-explore/mlx",
packages=packages,
package_dir=package_dir,
package_data=package_data,
include_package_data=True, include_package_data=True,
install_requires=install_requires, package_dir=package_dir,
extras_require={
"dev": [
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=42",
"torch",
"typing_extensions",
],
},
entry_points={
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
},
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
zip_safe=False, zip_safe=False,
python_requires=">=3.9", python_requires=">=3.9",
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={
"build_ext": CMakeBuild,
"generate_stubs": GenerateStubs,
"bdist_wheel": MLXBdistWheel,
},
) )
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
extras = {
"dev": [
"nanobind==2.4.0",
"numpy",
"pre-commit",
"setuptools>=80",
"torch",
"typing_extensions",
],
}
entry_points = {
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
}
# Release builds for PyPI are done in two stages.
# Each stage should be run from a clean build:
# python setup.py clean --all
#
# Stage 1:
# - Triggered with `MLX_BUILD_STAGE=1`
# - Includes everything except backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc.)
# - Wheel has Python ABI and platform tags
# - Wheel should be built for the cross-product of Python versions and platforms
# - Package name is mlx and it depends on the subpackage built in stage 2 (e.g. mlx-metal)
# Stage 2:
# - Triggered with `MLX_BUILD_STAGE=2`
# - Includes only backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc.)
# - Wheel has only platform tags
# - Wheel should be built only for different platforms
# - Package name is backend-specific, e.g. mlx-metal
if build_stage != 2:
if build_stage == 1:
if build_macos:
install_requires += [f"mlx-metal=={version}"]
else:
extras["cuda"] = [f"mlx-cuda=={version}"]
extras["cpu"] = [f"mlx-cpu=={version}"]
_setup(
name="mlx",
packages=packages,
extras_require=extras,
entry_points=entry_points,
install_requires=install_requires,
package_data=package_data,
)
else:
if build_macos:
name = "mlx-metal"
elif build_cuda:
name = "mlx-cuda"
else:
name = "mlx-cpu"
_setup(
name=name,
packages=["mlx"],
)