From 6ad0889c8ae1b74df02d0a895d1de4d364445927 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 4 Aug 2025 15:33:05 -0700 Subject: [PATCH] default install cuda on linux (#2462) --- README.md | 11 +++-------- docs/src/install.rst | 4 ++-- setup.py | 12 +++++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 87dcfd18e6..070f00d019 100644 --- a/README.md +++ b/README.md @@ -75,16 +75,11 @@ macOS, run: pip install mlx ``` -To install the CUDA backend on Linux, run: +On Linux, the same command (``pip install mlx``) will install the CUDA backend +by default. To install a CPU-only Linux package, run: ```bash -pip install "mlx[cuda]" -``` - -To install a CPU-only Linux package, run: - -```bash -pip install "mlx[cpu]" +pip install mlx[cpu] ``` Checkout the diff --git a/docs/src/install.rst b/docs/src/install.rst index 268141567e..7d9f6d90de 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with: .. code-block:: shell - pip install "mlx[cuda]" + pip install mlx To install the CUDA package from PyPi your system must meet the following requirements: @@ -49,7 +49,7 @@ For a CPU-only version of MLX that runs on Linux use: .. code-block:: shell - pip install "mlx[cpu]" + pip install mlx[cpu] To install the CPU-only package from PyPi your system must meet the following requirements: diff --git a/setup.py b/setup.py index d15f36b2b7..76d7cf6f95 100644 --- a/setup.py +++ b/setup.py @@ -274,11 +274,13 @@ if __name__ == "__main__": # - Package name is back-end specific, e.g mlx-metal if build_stage != 2: if build_stage == 1: - install_requires.append( - f'mlx-metal=={version}; platform_system == "Darwin"' - ) - extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] - extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] + install_requires += [ + f'mlx-metal=={version}; platform_system == "Darwin"', + f'mlx-cuda=={version}; extra != "cpu" and platform_system == "linux"', + ] + extras["cpu"] = [ + f'mlx-cpu=={version}; extra == "cpu" and platform_system == "linux"' + ] _setup( name="mlx",