mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
fa89f0b150
...
sdpav-back
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a22d0bf273 | ||
|
|
99d8de8445 | ||
|
|
c66b76a8c8 | ||
|
|
f81edd184f | ||
|
|
7f8ba2a003 | ||
|
|
c28249b81a | ||
|
|
e74bcdc5e3 | ||
|
|
d8ed6c1aa3 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 |
@@ -75,8 +75,13 @@ macOS, run:
|
|||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
On Linux, the same command (``pip install mlx``) will install the CUDA backend
|
To install the CUDA backend on Linux, run:
|
||||||
by default. To install a CPU-only Linux package, run:
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cuda]
|
||||||
|
```
|
||||||
|
|
||||||
|
To install a CPU-only Linux package, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install mlx[cpu]
|
pip install mlx[cpu]
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install mlx
|
pip install mlx[cuda]
|
||||||
|
|
||||||
To install the CUDA package from PyPi your system must meet the following
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
requirements:
|
requirements:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ void Matmul::run_impl(
|
|||||||
matmul_desc_,
|
matmul_desc_,
|
||||||
a_desc_,
|
a_desc_,
|
||||||
b_desc_,
|
b_desc_,
|
||||||
out_desc_, // TODO should that be c_desc is it's set?
|
c ? c_desc_ : out_desc_,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
pref_,
|
pref_,
|
||||||
1,
|
1,
|
||||||
@@ -226,8 +226,10 @@ void Matmul::run_impl(
|
|||||||
|
|
||||||
void* workspace_ptr = nullptr;
|
void* workspace_ptr = nullptr;
|
||||||
if (heuristic_.workspaceSize > 0) {
|
if (heuristic_.workspaceSize > 0) {
|
||||||
|
// Ensure workspace is 256-byte aligned
|
||||||
|
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||||
array workspace(
|
array workspace(
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
allocator::malloc(nbytes),
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
int8);
|
int8);
|
||||||
encoder.add_temporary(workspace);
|
encoder.add_temporary(workspace);
|
||||||
|
|||||||
@@ -6,17 +6,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
bool fast::ScaledDotProductAttention::use_fallback(
|
|
||||||
const array& q,
|
|
||||||
const array& k,
|
|
||||||
const array& v,
|
|
||||||
bool has_mask,
|
|
||||||
bool has_arr_mask,
|
|
||||||
bool do_causal,
|
|
||||||
Stream s) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#define NO_GPU_MULTI(func) \
|
#define NO_GPU_MULTI(func) \
|
||||||
void func::eval_gpu( \
|
void func::eval_gpu( \
|
||||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||||
@@ -53,7 +42,6 @@ NO_GPU_MULTI(Eig)
|
|||||||
NO_GPU_MULTI(Eigh)
|
NO_GPU_MULTI(Eigh)
|
||||||
|
|
||||||
namespace fast {
|
namespace fast {
|
||||||
NO_GPU(ScaledDotProductAttention)
|
|
||||||
NO_GPU_MULTI(CustomKernel)
|
NO_GPU_MULTI(CustomKernel)
|
||||||
} // namespace fast
|
} // namespace fast
|
||||||
|
|
||||||
|
|||||||
1110
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
1110
mlx/backend/cuda/scaled_dot_product_attention.cu
Normal file
File diff suppressed because it is too large
Load Diff
12
setup.py
12
setup.py
@@ -274,13 +274,11 @@ if __name__ == "__main__":
|
|||||||
# - Package name is back-end specific, e.g mlx-metal
|
# - Package name is back-end specific, e.g mlx-metal
|
||||||
if build_stage != 2:
|
if build_stage != 2:
|
||||||
if build_stage == 1:
|
if build_stage == 1:
|
||||||
install_requires += [
|
install_requires.append(
|
||||||
f'mlx-metal=={version}; platform_system == "Darwin"',
|
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||||
f'mlx-cuda=={version}; extra != "cpu" and platform_system == "Linux"',
|
)
|
||||||
]
|
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||||
extras["cpu"] = [
|
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||||
f'mlx-cpu=={version}; extra == "cpu" and platform_system == "Linux"'
|
|
||||||
]
|
|
||||||
|
|
||||||
_setup(
|
_setup(
|
||||||
name="mlx",
|
name="mlx",
|
||||||
|
|||||||
Reference in New Issue
Block a user