mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
fa89f0b150
...
sdpav-back
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22d0bf273 | ||
|
|
99d8de8445 | ||
|
|
c66b76a8c8 | ||
|
|
f81edd184f | ||
|
|
7f8ba2a003 | ||
|
|
c28249b81a | ||
|
|
e74bcdc5e3 | ||
|
|
d8ed6c1aa3 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 |
@@ -75,8 +75,13 @@ macOS, run:
|
||||
pip install mlx
|
||||
```
|
||||
|
||||
On Linux, the same command (``pip install mlx``) will install the CUDA backend
|
||||
by default. To install a CPU-only Linux package, run:
|
||||
To install the CUDA backend on Linux, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cuda]
|
||||
```
|
||||
|
||||
To install a CPU-only Linux package, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cpu]
|
||||
|
||||
@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx
|
||||
pip install mlx[cuda]
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
@@ -39,6 +39,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
|
||||
@@ -213,7 +213,7 @@ void Matmul::run_impl(
|
||||
matmul_desc_,
|
||||
a_desc_,
|
||||
b_desc_,
|
||||
out_desc_, // TODO should that be c_desc is it's set?
|
||||
c ? c_desc_ : out_desc_,
|
||||
out_desc_,
|
||||
pref_,
|
||||
1,
|
||||
@@ -226,8 +226,10 @@ void Matmul::run_impl(
|
||||
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
// Ensure workspace is 256-byte aligned
|
||||
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||
array workspace(
|
||||
allocator::malloc(heuristic_.workspaceSize),
|
||||
allocator::malloc(nbytes),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
|
||||
@@ -6,17 +6,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
bool fast::ScaledDotProductAttention::use_fallback(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
bool has_mask,
|
||||
bool has_arr_mask,
|
||||
bool do_causal,
|
||||
Stream s) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
|
||||
1110
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
1110
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
File diff suppressed because it is too large
Load Diff
12
setup.py
12
setup.py
@@ -274,13 +274,11 @@ if __name__ == "__main__":
|
||||
# - Package name is back-end specific, e.g mlx-metal
|
||||
if build_stage != 2:
|
||||
if build_stage == 1:
|
||||
install_requires += [
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"',
|
||||
f'mlx-cuda=={version}; extra != "cpu" and platform_system == "Linux"',
|
||||
]
|
||||
extras["cpu"] = [
|
||||
f'mlx-cpu=={version}; extra == "cpu" and platform_system == "Linux"'
|
||||
]
|
||||
install_requires.append(
|
||||
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||
)
|
||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||
|
||||
_setup(
|
||||
name="mlx",
|
||||
|
||||
Reference in New Issue
Block a user