Compare commits

...

2 Commits

Author SHA1 Message Date
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
4 changed files with 17 additions and 12 deletions

View File

@@ -75,8 +75,13 @@ macOS, run:
pip install mlx
```
On Linux, the same command (``pip install mlx``) will install the CUDA backend
by default. To install a CPU-only Linux package, run:
To install the CUDA backend on Linux, run:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]

View File

@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:

View File

@@ -213,7 +213,7 @@ void Matmul::run_impl(
matmul_desc_,
a_desc_,
b_desc_,
out_desc_, // TODO should that be c_desc is it's set?
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
@@ -226,8 +226,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(heuristic_.workspaceSize),
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);

View File

@@ -274,13 +274,11 @@ if __name__ == "__main__":
# - Package name is back-end specific, e.g mlx-metal
if build_stage != 2:
if build_stage == 1:
install_requires += [
f'mlx-metal=={version}; platform_system == "Darwin"',
f'mlx-cuda=={version}; extra != "cpu" and platform_system == "Linux"',
]
extras["cpu"] = [
f'mlx-cpu=={version}; extra == "cpu" and platform_system == "Linux"'
]
install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"'
)
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
_setup(
name="mlx",