| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | # Copyright © 2023 Apple Inc. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-04 16:04:11 -08:00
										 |  |  | import datetime | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-12-19 12:26:04 +09:00
										 |  |  | import platform | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | import re | 
					
						
							|  |  |  | import subprocess | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | from functools import partial | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2023-12-04 16:04:11 -08:00
										 |  |  | from subprocess import run | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-22 06:30:17 -07:00
										 |  |  | from setuptools import Command, Extension, find_namespace_packages, setup | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | from setuptools.command.bdist_wheel import bdist_wheel | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | from setuptools.command.build_ext import build_ext | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-24 12:47:05 -07:00
										 |  |  | def get_version(): | 
					
						
							|  |  |  |     with open("mlx/version.h", "r") as fid: | 
					
						
							|  |  |  |         for l in fid: | 
					
						
							|  |  |  |             if "#define MLX_VERSION_MAJOR" in l: | 
					
						
							|  |  |  |                 major = l.split()[-1] | 
					
						
							|  |  |  |             if "#define MLX_VERSION_MINOR" in l: | 
					
						
							|  |  |  |                 minor = l.split()[-1] | 
					
						
							|  |  |  |             if "#define MLX_VERSION_PATCH" in l: | 
					
						
							|  |  |  |                 patch = l.split()[-1] | 
					
						
							|  |  |  |     version = f"{major}.{minor}.{patch}" | 
					
						
							| 
									
										
										
										
											2023-12-04 16:04:11 -08:00
										 |  |  |     if "PYPI_RELEASE" not in os.environ: | 
					
						
							|  |  |  |         today = datetime.date.today() | 
					
						
							| 
									
										
										
										
											2024-01-18 12:00:24 -08:00
										 |  |  |         version = f"{version}.dev{today.year}{today.month:02d}{today.day:02d}" | 
					
						
							| 
									
										
										
										
											2023-12-04 16:04:11 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if "DEV_RELEASE" not in os.environ: | 
					
						
							|  |  |  |             git_hash = ( | 
					
						
							|  |  |  |                 run( | 
					
						
							|  |  |  |                     "git rev-parse --short HEAD".split(), | 
					
						
							|  |  |  |                     capture_output=True, | 
					
						
							|  |  |  |                     check=True, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 .stdout.strip() | 
					
						
							|  |  |  |                 .decode() | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             version = f"{version}+{git_hash}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return version | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | build_stage = int(os.environ.get("MLX_BUILD_STAGE", 0)) | 
					
						
							| 
									
										
										
										
											2025-07-28 12:35:15 -07:00
										 |  |  | build_macos = platform.system() == "Darwin" | 
					
						
							|  |  |  | build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | # A CMakeExtension needs a sourcedir instead of a file list. | 
					
						
							|  |  |  | # The name must be the _single_ output extension from the CMake build. | 
					
						
							|  |  |  | # If you need multiple extensions, see scikit-build. | 
					
						
							|  |  |  | class CMakeExtension(Extension): | 
					
						
							|  |  |  |     def __init__(self, name: str, sourcedir: str = "") -> None: | 
					
						
							|  |  |  |         super().__init__(name, sources=[]) | 
					
						
							|  |  |  |         self.sourcedir = os.fspath(Path(sourcedir).resolve()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CMakeBuild(build_ext): | 
					
						
							|  |  |  |     def build_extension(self, ext: CMakeExtension) -> None: | 
					
						
							|  |  |  |         # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ | 
					
						
							|  |  |  |         ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call] | 
					
						
							|  |  |  |         extdir = ext_fullpath.parent.resolve() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug | 
					
						
							|  |  |  |         cfg = "Debug" if debug else "Release" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |         build_temp = Path(self.build_temp) / ext.name | 
					
						
							|  |  |  |         if not build_temp.exists(): | 
					
						
							|  |  |  |             build_temp.mkdir(parents=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         build_python = "ON" | 
					
						
							|  |  |  |         install_prefix = f"{extdir}{os.sep}" | 
					
						
							|  |  |  |         if build_stage == 1: | 
					
						
							|  |  |  |             # Don't include MLX libraries in the wheel | 
					
						
							|  |  |  |             install_prefix = f"{build_temp}" | 
					
						
							|  |  |  |         elif build_stage == 2: | 
					
						
							|  |  |  |             # Don't include Python bindings in the wheel | 
					
						
							|  |  |  |             build_python = "OFF" | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         cmake_args = [ | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |             f"-DCMAKE_INSTALL_PREFIX={install_prefix}", | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |             f"-DCMAKE_BUILD_TYPE={cfg}", | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |             f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}", | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |             "-DMLX_BUILD_TESTS=OFF", | 
					
						
							|  |  |  |             "-DMLX_BUILD_BENCHMARKS=OFF", | 
					
						
							|  |  |  |             "-DMLX_BUILD_EXAMPLES=OFF", | 
					
						
							|  |  |  |             f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}", | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2025-07-28 12:35:15 -07:00
										 |  |  |         if build_stage == 2 and build_cuda: | 
					
						
							|  |  |  |             # Last arch is always real and virtual for forward-compatibility | 
					
						
							|  |  |  |             cuda_archs = ";".join(("70-real", "80-real", "90-real", "100-real", "120")) | 
					
						
							|  |  |  |             cmake_args += [f"-DMLX_CUDA_ARCHITECTURES={cuda_archs}"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-10 12:07:50 +09:00
										 |  |  |         # Some generators require explcitly passing config when building. | 
					
						
							|  |  |  |         build_args = ["--config", cfg] | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         # Adding CMake arguments set as environment variable | 
					
						
							|  |  |  |         # (needed e.g. to build for ARM OSx on conda-forge) | 
					
						
							|  |  |  |         if "CMAKE_ARGS" in os.environ: | 
					
						
							|  |  |  |             cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Pass version to C++ | 
					
						
							|  |  |  |         cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"]  # type: ignore[attr-defined] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-28 12:35:15 -07:00
										 |  |  |         if build_macos: | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |             # Cross-compile support for macOS - respect ARCHFLAGS if set | 
					
						
							|  |  |  |             archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) | 
					
						
							|  |  |  |             if archs: | 
					
						
							|  |  |  |                 cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] | 
					
						
							| 
									
										
										
										
											2024-12-19 12:26:04 +09:00
										 |  |  |         if platform.system() == "Windows": | 
					
						
							|  |  |  |             # On Windows DLLs must be put in the same dir with the extension | 
					
						
							|  |  |  |             # while cmake puts mlx.dll into the "bin" sub-dir. Link with mlx | 
					
						
							|  |  |  |             # statically to work around it. | 
					
						
							|  |  |  |             cmake_args += ["-DBUILD_SHARED_LIBS=OFF"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             cmake_args += ["-DBUILD_SHARED_LIBS=ON"] | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level | 
					
						
							|  |  |  |         # across all generators. | 
					
						
							|  |  |  |         if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: | 
					
						
							| 
									
										
										
										
											2025-07-07 22:06:45 +09:00
										 |  |  |             build_args += [f"-j{os.cpu_count()}"] | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-29 08:43:22 +09:00
										 |  |  |         # Avoid cache miss when building from temporary dirs. | 
					
						
							| 
									
										
										
										
											2025-09-09 07:41:05 +09:00
										 |  |  |         os.environ["CCACHE_BASEDIR"] = os.path.realpath(self.build_temp) | 
					
						
							|  |  |  |         os.environ["CCACHE_NOHASHDIR"] = "true" | 
					
						
							| 
									
										
										
										
											2025-07-29 08:43:22 +09:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         subprocess.run( | 
					
						
							|  |  |  |             ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         subprocess.run( | 
					
						
							|  |  |  |             ["cmake", "--build", ".", "--target", "install", *build_args], | 
					
						
							|  |  |  |             cwd=build_temp, | 
					
						
							|  |  |  |             check=True, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Make sure to copy mlx.metallib for inplace builds | 
					
						
							|  |  |  |     def run(self): | 
					
						
							|  |  |  |         super().run() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102 | 
					
						
							|  |  |  |         if self.inplace: | 
					
						
							|  |  |  |             for ext in self.extensions: | 
					
						
							|  |  |  |                 if ext.name == "mlx.core": | 
					
						
							|  |  |  |                     # Resolve inplace package dir | 
					
						
							|  |  |  |                     build_py = self.get_finalized_command("build_py") | 
					
						
							|  |  |  |                     inplace_file, regular_file = self._get_inplace_equivalent( | 
					
						
							|  |  |  |                         build_py, ext | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     inplace_dir = str(Path(inplace_file).parent.resolve()) | 
					
						
							|  |  |  |                     regular_dir = str(Path(regular_file).parent.resolve()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     self.copy_tree(regular_dir, inplace_dir) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-14 15:58:45 -05:00
										 |  |  | class GenerateStubs(Command): | 
					
						
							|  |  |  |     user_options = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def initialize_options(self): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def finalize_options(self): | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def run(self) -> None: | 
					
						
							| 
									
										
										
										
											2024-03-24 15:03:27 -07:00
										 |  |  |         out_path = "python/mlx/core" | 
					
						
							|  |  |  |         stub_cmd = [ | 
					
						
							|  |  |  |             "python", | 
					
						
							|  |  |  |             "-m", | 
					
						
							|  |  |  |             "nanobind.stubgen", | 
					
						
							|  |  |  |             "-m", | 
					
						
							|  |  |  |             "mlx.core", | 
					
						
							| 
									
										
										
										
											2024-10-15 16:23:37 -07:00
										 |  |  |             "-p", | 
					
						
							|  |  |  |             "python/mlx/_stub_patterns.txt", | 
					
						
							| 
									
										
										
										
											2024-03-24 15:03:27 -07:00
										 |  |  |         ] | 
					
						
							|  |  |  |         subprocess.run(stub_cmd + ["-r", "-O", out_path]) | 
					
						
							|  |  |  |         # Run again without recursive to specify output file name | 
					
						
							|  |  |  |         subprocess.run(["rm", f"{out_path}/mlx.pyi"]) | 
					
						
							|  |  |  |         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) | 
					
						
							| 
									
										
										
										
											2023-12-14 15:58:45 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | class MLXBdistWheel(bdist_wheel): | 
					
						
							|  |  |  |     def get_tag(self) -> tuple[str, str, str]: | 
					
						
							|  |  |  |         impl, abi, plat_name = super().get_tag() | 
					
						
							|  |  |  |         if build_stage == 2: | 
					
						
							|  |  |  |             impl = self.python_tag | 
					
						
							|  |  |  |             abi = "none" | 
					
						
							|  |  |  |         return (impl, abi, plat_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-08 17:03:29 +05:30
										 |  |  | # Read the content of README.md | 
					
						
							|  |  |  | with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: | 
					
						
							|  |  |  |     long_description = f.read() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     package_dir = {"": "python"} | 
					
						
							| 
									
										
										
										
											2025-07-22 06:30:17 -07:00
										 |  |  |     packages = find_namespace_packages( | 
					
						
							|  |  |  |         where="python", | 
					
						
							|  |  |  |         exclude=[ | 
					
						
							|  |  |  |             "src", | 
					
						
							|  |  |  |             "tests", | 
					
						
							|  |  |  |             "scripts", | 
					
						
							|  |  |  |             "mlx.lib", | 
					
						
							|  |  |  |             "mlx.include", | 
					
						
							|  |  |  |             "mlx.share", | 
					
						
							|  |  |  |             "mlx.share.**", | 
					
						
							|  |  |  |             "mlx.include.**", | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     version = get_version() | 
					
						
							| 
									
										
										
										
											2023-12-04 16:04:11 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |     _setup = partial( | 
					
						
							|  |  |  |         setup, | 
					
						
							|  |  |  |         version=version, | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         author="MLX Contributors", | 
					
						
							|  |  |  |         author_email="mlx@group.apple.com", | 
					
						
							| 
									
										
										
										
											2023-12-14 11:48:00 +08:00
										 |  |  |         description="A framework for machine learning on Apple silicon.", | 
					
						
							| 
									
										
										
										
											2023-12-08 17:03:29 +05:30
										 |  |  |         long_description=long_description, | 
					
						
							|  |  |  |         long_description_content_type="text/markdown", | 
					
						
							| 
									
										
										
										
											2025-06-19 15:26:36 -07:00
										 |  |  |         license="MIT", | 
					
						
							| 
									
										
										
										
											2024-01-12 13:34:16 -08:00
										 |  |  |         url="https://github.com/ml-explore/mlx", | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         include_package_data=True, | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |         package_dir=package_dir, | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |         zip_safe=False, | 
					
						
							| 
									
										
										
										
											2024-10-16 17:51:38 -07:00
										 |  |  |         python_requires=">=3.9", | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |         ext_modules=[CMakeExtension("mlx.core")], | 
					
						
							|  |  |  |         cmdclass={ | 
					
						
							|  |  |  |             "build_ext": CMakeBuild, | 
					
						
							|  |  |  |             "generate_stubs": GenerateStubs, | 
					
						
							|  |  |  |             "bdist_wheel": MLXBdistWheel, | 
					
						
							|  |  |  |         }, | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-22 06:30:17 -07:00
										 |  |  |     package_data = {"mlx.core": ["*.pyi"]} | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     extras = { | 
					
						
							|  |  |  |         "dev": [ | 
					
						
							|  |  |  |             "nanobind==2.4.0", | 
					
						
							|  |  |  |             "numpy", | 
					
						
							|  |  |  |             "pre-commit", | 
					
						
							|  |  |  |             "setuptools>=80", | 
					
						
							|  |  |  |             "torch", | 
					
						
							|  |  |  |             "typing_extensions", | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     entry_points = { | 
					
						
							|  |  |  |         "console_scripts": [ | 
					
						
							|  |  |  |             "mlx.launch = mlx.distributed_run:main", | 
					
						
							|  |  |  |             "mlx.distributed_config = mlx.distributed_run:distributed_config", | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2025-07-23 16:54:19 -07:00
										 |  |  |     install_requires = [] | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Release builds for PyPi are in two stages. | 
					
						
							|  |  |  |     # Each stage should be run from a clean build: | 
					
						
							|  |  |  |     #   python setup.py clean --all | 
					
						
							|  |  |  |     # | 
					
						
							|  |  |  |     # Stage 1: | 
					
						
							|  |  |  |     #  - Triggered with `MLX_BUILD_STAGE=1` | 
					
						
							|  |  |  |     #  - Include everything except backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) | 
					
						
							|  |  |  |     #  - Wheel has Python ABI and platform tags | 
					
						
							|  |  |  |     #  - Wheel should be built for the cross-product of python version and platforms | 
					
						
							|  |  |  |     #  - Package name is mlx and it depends on subpackage in stage 2 (e.g. mlx-metal) | 
					
						
							|  |  |  |     # Stage 2: | 
					
						
							|  |  |  |     #  - Triggered with `MLX_BUILD_STAGE=2` | 
					
						
							|  |  |  |     #  - Includes only backend-specific binaries (e.g. libmlx.so, mlx.metallib, etc) | 
					
						
							|  |  |  |     #  - Wheel has only platform tags | 
					
						
							|  |  |  |     #  - Wheel should be built only for different platforms | 
					
						
							|  |  |  |     #  - Package name is back-end specific, e.g mlx-metal | 
					
						
							|  |  |  |     if build_stage != 2: | 
					
						
							|  |  |  |         if build_stage == 1: | 
					
						
							| 
									
										
										
										
											2025-08-06 06:19:12 -07:00
										 |  |  |             install_requires.append( | 
					
						
							|  |  |  |                 f'mlx-metal=={version}; platform_system == "Darwin"' | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] | 
					
						
							|  |  |  |             extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  |         _setup( | 
					
						
							|  |  |  |             name="mlx", | 
					
						
							|  |  |  |             packages=packages, | 
					
						
							|  |  |  |             extras_require=extras, | 
					
						
							|  |  |  |             entry_points=entry_points, | 
					
						
							|  |  |  |             install_requires=install_requires, | 
					
						
							|  |  |  |             package_data=package_data, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         if build_macos: | 
					
						
							|  |  |  |             name = "mlx-metal" | 
					
						
							|  |  |  |         elif build_cuda: | 
					
						
							|  |  |  |             name = "mlx-cuda" | 
					
						
							| 
									
										
										
										
											2025-07-23 16:54:19 -07:00
										 |  |  |             install_requires += [ | 
					
						
							|  |  |  |                 "nvidia-cublas-cu12==12.9.*", | 
					
						
							|  |  |  |                 "nvidia-cuda-nvrtc-cu12==12.9.*", | 
					
						
							| 
									
										
										
										
											2025-07-25 15:20:29 -07:00
										 |  |  |                 "nvidia-cudnn-cu12==9.*", | 
					
						
							| 
									
										
										
										
											2025-08-21 17:57:49 -07:00
										 |  |  |                 "nvidia-nccl-cu12", | 
					
						
							| 
									
										
										
										
											2025-07-23 16:54:19 -07:00
										 |  |  |             ] | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |         else: | 
					
						
							|  |  |  |             name = "mlx-cpu" | 
					
						
							|  |  |  |         _setup( | 
					
						
							|  |  |  |             name=name, | 
					
						
							|  |  |  |             packages=["mlx"], | 
					
						
							| 
									
										
										
										
											2025-07-23 16:54:19 -07:00
										 |  |  |             install_requires=install_requires, | 
					
						
							| 
									
										
										
										
											2025-07-14 17:17:33 -07:00
										 |  |  |         ) |