Compare commits

...

10 Commits

Author SHA1 Message Date
Angelos Katharopoulos
a22d0bf273 Add stricter condition to matrix sdpa 2025-08-06 19:51:14 -07:00
Jagrit Digani
99d8de8445 Fix cudnn routing 2025-08-06 15:05:58 -07:00
Jagrit Digani
c66b76a8c8 Update routing 2025-08-06 15:01:15 -07:00
Jagrit Digani
f81edd184f Complete 2 pass sdpav 2025-08-06 13:57:40 -07:00
Jagrit Digani
7f8ba2a003 [WIP] 2 pass sdpav 2025-08-06 09:56:39 -07:00
Jagrit Digani
c28249b81a Add more nvtx range for debug 2025-08-06 09:56:39 -07:00
Jagrit Digani
e74bcdc5e3 Add sdpa file 2025-08-06 09:56:39 -07:00
Jagrit Digani
d8ed6c1aa3 Add base cudnn attention support 2025-08-06 09:56:39 -07:00
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
7 changed files with 1128 additions and 24 deletions

View File

@@ -75,8 +75,13 @@ macOS, run:
pip install mlx pip install mlx
``` ```
On Linux, the same command (``pip install mlx``) will install the CUDA backend To install the CUDA backend on Linux, run:
by default. To install a CPU-only Linux package, run:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash ```bash
pip install mlx[cpu] pip install mlx[cpu]

View File

@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx pip install mlx[cuda]
To install the CUDA package from PyPI your system must meet the following To install the CUDA package from PyPI your system must meet the following
requirements: requirements:

View File

@@ -39,6 +39,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu

View File

@@ -213,7 +213,7 @@ void Matmul::run_impl(
matmul_desc_, matmul_desc_,
a_desc_, a_desc_,
b_desc_, b_desc_,
out_desc_, // TODO should that be c_desc if it's set? c ? c_desc_ : out_desc_,
out_desc_, out_desc_,
pref_, pref_,
1, 1,
@@ -226,8 +226,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr; void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) { if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace( array workspace(
allocator::malloc(heuristic_.workspaceSize), allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)}, {static_cast<int>(heuristic_.workspaceSize)},
int8); int8);
encoder.add_temporary(workspace); encoder.add_temporary(workspace);

View File

@@ -6,17 +6,6 @@
namespace mlx::core { namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \ #define NO_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast { namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel) NO_GPU_MULTI(CustomKernel)
} // namespace fast } // namespace fast

File diff suppressed because it is too large Load Diff

View File

@@ -274,13 +274,11 @@ if __name__ == "__main__":
# - Package name is back-end specific, e.g mlx-metal # - Package name is back-end specific, e.g mlx-metal
if build_stage != 2: if build_stage != 2:
if build_stage == 1: if build_stage == 1:
install_requires += [ install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"', f'mlx-metal=={version}; platform_system == "Darwin"'
f'mlx-cuda=={version}; extra != "cpu" and platform_system == "Linux"', )
] extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [ extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
f'mlx-cpu=={version}; extra == "cpu" and platform_system == "Linux"'
]
_setup( _setup(
name="mlx", name="mlx",