diff --git a/.circleci/config.yml b/.circleci/config.yml index be5f7aac5..341655c93 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -336,10 +336,11 @@ jobs: pip install typing_extensions python setup.py generate_stubs << parameters.extra_env >> python -m build --wheel + << parameters.extra_env >> MLX_BUILD_COMMON=1 python -m build --wheel auditwheel show dist/* auditwheel repair dist/* --plat manylinux_2_31_x86_64 - run: - name: Upload package + name: Upload packages command: | source env/bin/activate twine upload wheelhouse/* diff --git a/docs/src/install.rst b/docs/src/install.rst index a50b6a71d..7c1a02b62 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -38,8 +38,16 @@ and SM 7.0 (Volta) and up. To install MLX with CUDA support, run: .. code-block:: shell - pip install mlx-cuda + pip install "mlx[cuda]" +CPU only (Linux) +^^^^^^^^^^^^^^^^ + +For a CPU-only version of MLX that runs on Linux use: + +.. code-block:: shell + + pip install "mlx[cpu]" Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 770718e25..9d2e44416 100644 --- a/setup.py +++ b/setup.py @@ -5,10 +5,11 @@ import os import platform import re import subprocess +from functools import partial from pathlib import Path from subprocess import run -from setuptools import Command, Extension, find_namespace_packages, setup +from setuptools import Command, Extension, setup from setuptools.command.build_ext import build_ext @@ -165,19 +166,27 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: # The information here can also be placed in setup.cfg - better separation of # logic and declaration, and simpler if you include description/version in a file. if __name__ == "__main__": - packages = find_namespace_packages( - where="python", exclude=["src", "tests", "tests.*"] - ) package_dir = {"": "python"} package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} - install_requires = [] + packages = [ + "mlx", + "mlx.nn", + "mlx.optimizers", + ] + + is_release = "PYPI_RELEASE" in os.environ + build_macos = platform.system() == "Darwin" build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") + build_common = "MLX_BUILD_COMMON" in os.environ + + install_requires = [] if build_cuda: install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"] + version = get_version() - setup( - name="mlx-cuda" if build_cuda else "mlx", - version=get_version(), + _setup = partial( + setup, + version=version, author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.", @@ -185,29 +194,56 @@ if __name__ == "__main__": long_description_content_type="text/markdown", license="MIT", url="https://github.com/ml-explore/mlx", - packages=packages, package_dir=package_dir, package_data=package_data, - include_package_data=True, - install_requires=install_requires, - extras_require={ - "dev": [ - "nanobind==2.4.0", - "numpy", - "pre-commit", - "setuptools>=42", - "torch", - "typing_extensions", - ], - }, - entry_points={ - "console_scripts": [ - "mlx.launch = mlx.distributed_run:main", - "mlx.distributed_config = mlx.distributed_run:distributed_config", - ] - }, - ext_modules=[CMakeExtension("mlx.core")], - cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, zip_safe=False, python_requires=">=3.9", + install_requires=install_requires, ) + + extras = { + "dev": [ + "nanobind==2.4.0", + "numpy", + "pre-commit", + "setuptools>=42", + "torch", + "typing_extensions", + ], + } + entry_points = { + "console_scripts": [ + "mlx.launch = mlx.distributed_run:main", + "mlx.distributed_config = mlx.distributed_run:distributed_config", + ] + } + + if not is_release or build_macos: + _setup( + name="mlx", + include_package_data=True, + packages=packages, + extras_require=extras, + entry_points=entry_points, + ext_modules=[CMakeExtension("mlx.core")], + cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, + ) + elif build_common: + extras["cpu"] = [f"mlx-cpu=={version}"] + extras["cuda"] = [f"mlx-cuda=={version}"] + _setup( + name="mlx", + include_package_data=False, + packages=["mlx"], + extras_require=extras, + entry_points=entry_points, + ) + else: + _setup( + name="mlx-cuda" if build_cuda else "mlx-cpu", + include_package_data=True, + packages=packages, + extras_require=extras, + ext_modules=[CMakeExtension("mlx.core")], + cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, + )