Compare commits

..

20 Commits

Author SHA1 Message Date
Awni Hannun
dcb8319f3d update install docs and requirements (#2419) 2025-07-25 12:13:19 -07:00
Awni Hannun
5597fa089c Fix qvm splitk (#2415) 2025-07-25 11:50:24 -07:00
Awni Hannun
9acec364c2 [CUDA] Always use batched matmul (#2404)
* cuda batched mm

* addmm as well

* comment
2025-07-24 20:46:02 -07:00
Skonor
7d9d6ef456 docs: fix adam and adamw eps placement (#2416)
Co-authored-by: Mikhail Gorbunov <m_gorbunov@apple.com>
2025-07-24 16:40:45 -07:00
Cheng
6f5874a2f2 [CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN

* Initial implementation

* Remove backend apis

* Fix recording cudnn conv

* More unused backend apis

* Fix C++ conv tests

* include cudnn as python dep

* Install libcudnn9-dev-cuda-12 in CI

* cudnn only accepts contiguous inputs

* Switch to backend apis

* Plan needs to be kept alive

* Turn off tf32

* Add cache

* Test the native cuda graph api

* Set cudnn stream before execution

* Make LRUCache more like a normal container

* Do error check for cublas handle

* Zero-initilizing array

* Use tf32 for conv

* Skip TestConv.test_torch_conv_2D test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-25 08:12:10 +09:00
Awni Hannun
70dc336785 Test on cuda 12.2 and 12.9 (#2413) 2025-07-24 06:06:15 -07:00
Awni Hannun
4e504039f5 [Metal] Release metal events (#2412)
* release metal events

* fix

* fix
2025-07-23 19:53:42 -07:00
Awni Hannun
d1f4d291e8 Fix uv install and add dev release (#2411)
* fix uv install and add dev release

* fix docstring

* pin cuda deps

* cuda release on cpu-only machine
2025-07-23 16:54:19 -07:00
Awni Hannun
e1840853ce full row mask in sdpa consistently gives nan (#2406) 2025-07-23 16:37:03 -07:00
Cheng
0f5ce173da [CUDA] --compress-mode requires CUDA 12.8 (#2407) 2025-07-23 06:11:11 -07:00
Cheng
588854195f Remove unused code in Convolution::vjp (#2408) 2025-07-23 06:11:00 -07:00
Fangjun Kuang
28d068bce6 Fix an error in the comment for mx.dequantize (#2409) 2025-07-23 06:10:50 -07:00
Awni Hannun
d107d8d495 add cuda gemv (#2400) 2025-07-22 08:24:13 -07:00
Awni Hannun
1e496ddb82 [CUDA] Simplify allocator (#2392)
* simplify allocator and fixe race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
2025-07-22 08:24:01 -07:00
Awni Hannun
74eccbf3fa use size option in binary (#2399) 2025-07-22 07:00:53 -07:00
Awni Hannun
08638223ca Fix including stubs in wheel (#2398)
* fix including stubs in wheel

* fix bool_
2025-07-22 06:30:17 -07:00
Cheng
56cc858af9 Add contiguous_copy_cpu util for copying array (#2397) 2025-07-21 07:30:35 -07:00
Cheng
f55c4ed1d6 Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Awni Hannun
93d70419e7 [CUDA] speedup handling scalars (#2389)
* speedup scalars in cuda

* comment
2025-07-18 21:47:31 -07:00
Awni Hannun
63f663d9c6 fix cuda manylinux version to match others (#2388) 2025-07-18 21:02:16 -07:00
60 changed files with 1938 additions and 816 deletions

View File

@@ -7,6 +7,9 @@ parameters:
nightly_build: nightly_build:
type: boolean type: boolean
default: false default: false
test_release:
type: boolean
default: false
jobs: jobs:
build_documentation: build_documentation:
@@ -200,8 +203,12 @@ jobs:
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test: cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine: machine:
image: linux-cuda-12:2023.11.1 image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
steps: steps:
- checkout - checkout
@@ -209,6 +216,7 @@ jobs:
name: Install Python package name: Install Python package
command: | command: |
sudo apt-get update sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python3 -m venv env python3 -m venv env
source env/bin/activate source env/bin/activate
@@ -366,22 +374,27 @@ jobs:
type: string type: string
default: "" default: ""
machine: machine:
image: linux-cuda-12:default image: ubuntu-2204:current
resource_class: gpu.nvidia.small.gen2 resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
name: Build wheel name: Build wheel
command: | command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip sudo apt-get install zip
python -m venv env
source env/bin/activate
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \ << parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w python -m build -w
@@ -392,7 +405,6 @@ jobs:
- run: - run:
name: Upload package name: Upload package
command: | command: |
source env/bin/activate
twine upload wheelhouse/*.whl twine upload wheelhouse/*.whl
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -405,19 +417,24 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$" pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >> value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
macosx_deployment_target: ["13.5", "14.0"] macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test - linux_build_and_test
- cuda_build_and_test - cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation - build_documentation
build_pypi_release: build_pypi_release:
when: when:
and: and:
- not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs: jobs:
- build_release: - build_release:
filters: filters:
@@ -601,3 +618,87 @@ workflows:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release - build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include: Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models. more complex models.
@@ -68,18 +68,23 @@ in the documentation.
## Installation ## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run: MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
**With `pip`**: ```bash
```
pip install mlx pip install mlx
``` ```
**With `conda`**: To install the CUDA backend on Linux, run:
```bash
pip install "mlx[cuda]"
``` ```
conda install -c conda-forge mlx
To install a CPU-only Linux package, run:
```bash
pip install "mlx[cpu]"
``` ```
Checkout the Checkout the
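
As a quick sanity check of the install commands documented in the README change above, here is a minimal sketch assuming the standard `mlx.core` Python API; it is illustrative only and not part of the diff. The device printed depends on the wheel installed (Metal on macOS, CUDA or CPU on Linux).

```python
# Post-install check: confirms MLX imports, runs a trivial op, and reports
# the default device for the installed backend.
import mlx.core as mx

x = mx.arange(8)            # small array to force backend initialization
mx.eval(x * 2)              # evaluate a trivial op
print(mx.default_device())  # e.g. Device(gpu, 0) on macOS or with the CUDA wheel
```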

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx pip install mlx
To install from PyPI you must meet the following requirements: To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.9 - Using a native Python >= 3.9
@@ -26,13 +26,22 @@ To install from PyPI you must meet the following requirements:
CUDA CUDA
^^^^ ^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12 MLX has a CUDA backend which you can install with:
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell .. code-block:: shell
pip install "mlx[cuda]" pip install "mlx[cuda]"
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^
@@ -42,6 +51,13 @@ For a CPU-only version of MLX that runs on Linux use:
pip install "mlx[cpu]" pip install "mlx[cpu]"
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -377,4 +377,10 @@ void copy_cpu_inplace(
}); });
} }
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -30,4 +30,7 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt, const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt); const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return {arr, false}; return {arr, false};
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return {contiguous_copy_cpu(arr, stream), true};
copy_cpu(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
} }
}; };
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
} }
return in; return in;
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy = contiguous_copy_cpu(in, s);
copy_cpu(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy); out.copy_shared_buffer(arr_copy);
return arr_copy; return arr_copy;
} }

View File

@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@@ -136,9 +136,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return std::make_tuple(true, sty, arr, false); return std::make_tuple(true, sty, arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true); return std::make_tuple(false, stx, arr_copy, true);
} }
}; };

View File

@@ -712,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
if (arr.flags().row_contiguous) { if (arr.flags().row_contiguous) {
return std::make_pair(arr, false); return std::make_pair(arr, false);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); return std::make_pair(contiguous_copy_cpu(arr, s), true);
copy_cpu(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
} }
}; };
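
For context, the `AffineQuantize::eval_cpu` path refactored above (including the row-contiguity check) backs `mx.quantize` / `mx.dequantize`. A small illustrative round-trip, assuming the default affine scheme with `group_size=64` and `bits=4`; this is a sketch, not code from the diff:

```python
import mlx.core as mx

w = mx.random.normal((128, 256))
# Affine quantization returns packed weights plus per-group scales and biases.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
print(mx.abs(w - w_hat).max())  # reconstruction error from 4-bit quantization
```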

View File

@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity // Ensure contiguity
auto in = inputs[0]; auto in = inputs[0];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_cpu(in, stream());
copy_cpu(in, arr_copy, CopyType::General, stream()); encoder.add_temporary(in);
in = arr_copy;
encoder.add_temporary(arr_copy);
} }
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));

View File

@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_cpu(x, s);
copy_cpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
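
The call sites above (distributed primitives, logsumexp, masked matmul, quantize, scan, softmax) all follow the same pattern the new `contiguous_copy_cpu` helper captures: copy a non-row-contiguous input before running the kernel. From Python that path is hit whenever the input is, for example, a transposed view; a hedged illustration, not part of the diff:

```python
import mlx.core as mx

x = mx.random.normal((16, 32))
xt = x.T                       # transposed view: not row-contiguous
y = mx.softmax(xt, axis=-1)    # backend makes a contiguous copy internally
z = mx.cumsum(xt, axis=-1)     # scan ops take the same copy path
mx.eval(y, z)
```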

View File

@@ -15,11 +15,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -45,6 +48,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation. # Embed kernel sources in binary for JIT compilation.
@@ -87,6 +98,13 @@ endif()
target_compile_options( target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>") mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with # Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain. # managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES set(MLX_CUDA_ARCHITECTURES
@@ -123,6 +141,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs. # Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Use the frontend APIs of cuDNN.
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -17,14 +16,66 @@ namespace cu {
constexpr int page_size = 16384; constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
}
CudaBuffer* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
return &b->buf;
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
}
CudaAllocator::CudaAllocator() CudaAllocator::CudaAllocator()
: buffer_cache_( : buffer_cache_(
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device. // TODO: Set memory limit for multi-device.
size_t free, total; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -36,7 +87,9 @@ Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size; auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size < page_size) { if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size); size = next_power_of_2(size);
} else { } else {
size = page_size * ((size + page_size - 1) / page_size); size = page_size * ((size + page_size - 1) / page_size);
@@ -44,19 +97,25 @@ Buffer CudaAllocator::malloc(size_t size) {
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) { if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure try to reclaim memory from the cache.
// try to reclaim memory from the cache. int64_t mem_to_free =
size_t mem_required = get_active_memory() + get_cache_memory() + size; get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_required >= memory_limit_) { if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_); buffer_cache_.release_cached_buffers(mem_to_free);
} }
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
}
lock.unlock(); lock.unlock();
buf = new CudaBuffer{nullptr, size}; if (!buf) {
cudaError_t err = cudaMallocManaged(&buf->data, size); buf = new CudaBuffer{nullptr, size};
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { cudaError_t err = cudaMallocManaged(&buf->data, size);
throw std::runtime_error(fmt::format( if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
} }
lock.lock(); lock.lock();
} }
@@ -67,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) { if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
return Buffer{buf}; return Buffer{buf};
} }
@@ -82,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
lock.unlock(); cuda_free(buf);
cuda_free(buf->data);
delete buf;
} }
} }
@@ -96,27 +152,14 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size; return buf->size;
} }
void CudaAllocator::register_this_thread() { // This must be called with mutex_ aquired
std::lock_guard lock(worker_mutex_); void CudaAllocator::cuda_free(CudaBuffer* buf) {
allowed_threads_.insert(std::this_thread::get_id()); if (scalar_pool_.in_pool(buf)) {
} scalar_pool_.free(buf);
} else {
void CudaAllocator::cuda_free(void* buf) { cudaFree(buf->data);
// If cuda_free() is called from a unregistered thread, reschedule the call to delete buf;
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
} }
cudaFree(buf);
} }
size_t CudaAllocator::get_active_memory() const { size_t CudaAllocator::get_active_memory() const {
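
The buffer cache and small scalar pool above are what back MLX's memory introspection calls. A hedged sketch, assuming the top-level memory API (in some versions these functions live under a backend submodule instead); illustrative only:

```python
import mlx.core as mx

x = mx.random.normal((1024, 1024))
mx.eval(x)
print(mx.get_active_memory())  # bytes held by live arrays
print(mx.get_cache_memory())   # bytes held in the allocator's buffer cache
mx.clear_cache()               # release cached buffers back to the system
```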

View File

@@ -7,13 +7,10 @@
#include <mutex> #include <mutex>
#include <set> #include <set>
#include <thread>
#include <utility> #include <utility>
namespace mlx::core::cu { namespace mlx::core::cu {
class Worker;
using allocator::Buffer; using allocator::Buffer;
// Stores cuda-managed unified memory. // Stores cuda-managed unified memory.
@@ -22,21 +19,35 @@ struct CudaBuffer {
size_t size; size_t size;
}; };
class SmallSizePool {
private:
union Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator { class CudaAllocator : public allocator::Allocator {
public: public:
Buffer malloc(size_t size) override; Buffer malloc(size_t size) override;
void free(Buffer buffer) override; void free(Buffer buffer) override;
size_t size(Buffer buffer) const override; size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const; size_t get_active_memory() const;
size_t get_peak_memory() const; size_t get_peak_memory() const;
void reset_peak_memory(); void reset_peak_memory();
@@ -47,19 +58,18 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache(); void clear_cache();
private: private:
void cuda_free(CudaBuffer* buf);
CudaAllocator(); CudaAllocator();
friend CudaAllocator& allocator(); friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t max_pool_size_; size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
SmallSizePool scalar_pool_;
}; };
CudaAllocator& allocator(); CudaAllocator& allocator();

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -115,7 +115,7 @@ __global__ void arg_reduce_general(
T vals[N_READS]; T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x; auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init); tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS); best = op.reduce_many(best, vals, tid * N_READS);
} }

View File

@@ -128,7 +128,7 @@ __global__ void binary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d( auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim); index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]); out[index] = Op{}(a[a_idx], b[b_idx]);
} }

View File

@@ -160,7 +160,7 @@ __global__ void binary_two_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d( auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim); index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]); auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0]; out_a[index] = out[0];

mlx/backend/cuda/conv.cpp (new file, 340 lines)
View File

@@ -0,0 +1,340 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
#include <numeric>
namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
struct ConvCacheKey {
int device_id;
cudnnBackendDescriptorType_t backend_type;
cudnnDataType_t cudnn_type;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> filter_shape;
std::array<int, MAX_NDIM> padding_lo;
std::array<int, MAX_NDIM> padding_hi;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> dilation;
int groups;
uint8_t input_alignment;
uint8_t filter_alignment;
uint8_t output_alignment;
};
auto& conv_cache() {
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
/* capacity */ 128);
return cache;
}
template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
template <typename T>
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(shape, strides);
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
const array& in,
const array& wt,
array& out) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
const_cast<void*>(in.data<void>()),
const_cast<void*>(wt.data<void>()),
out.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
cudnn_frontend::EngineConfigList& configs,
const ConvCacheKey& cache_key,
const std::string& op_graph_tag,
const array& in,
const array& wt,
array& out) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, in, wt, out)) {
conv_cache().emplace(cache_key, std::move(plan));
return true;
}
} catch (cudnn_frontend::cudnnException&) {
}
}
return false;
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out.size() == 0) {
return;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// cuDNN requires contiguous input.
// TODO: Handle NCHW format specially.
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
encoder.set_input_array(in);
encoder.set_input_array(wt);
encoder.set_output_array(out);
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),
backend_type,
cudnn_type,
fixed_vector(in.shape()),
fixed_vector(wt.shape()),
fixed_vector(padding_lo_),
fixed_vector(padding_hi_),
fixed_vector(kernel_strides_),
fixed_vector(kernel_dilation_),
groups_,
get_alignment(in),
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
if (!execute_plan(encoder, it->second, in, wt, out)) {
throw std::runtime_error("Cached convolution plan failed to execute.");
}
return;
}
// Build operation graph.
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
? CUDNN_DATA_FLOAT
: cudnn_type;
auto stride = convert_vector<int64_t>(kernel_strides_);
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
auto dilation = convert_vector<int64_t>(kernel_dilation_);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_data_type)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', in))
.setwDesc(build_tensor('w', wt))
.setyDesc(build_tensor('y', out))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
auto op_graph_tag = op_graph.getTag();
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
return;
}
throw std::runtime_error("Unable to find an engine for convolution.");
}
} // namespace mlx::core
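
The new cuDNN-backed `Convolution::eval_gpu` above is reached from the standard convolution ops. A minimal sketch of a call that would route through it, assuming the CUDA backend is active and NHWC inputs with `(C_out, kH, kW, C_in)` weights (per the TODO, NCHW is not handled specially yet); illustrative only:

```python
import mlx.core as mx

x = mx.random.normal((1, 32, 32, 3))   # NHWC input
w = mx.random.normal((8, 3, 3, 3))     # (C_out, kH, kW, C_in) weights
y = mx.conv2d(x, w, stride=1, padding=1)
mx.eval(y)
print(y.shape)  # (1, 32, 32, 8)
```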

View File

@@ -37,7 +37,7 @@ __global__ void copy_gg(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d( auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim); index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]); out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
} }

View File

@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
const int64_t* offset_out) { const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_4d( auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim); index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]); out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
} }

View File

@@ -34,7 +34,7 @@ __global__ void copy_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim); IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]); out[index] = CastOp<In, Out>{}(in[idx_in]);
} }
} }

View File

@@ -9,12 +9,23 @@
#include <future> #include <future>
#include <unordered_set> #include <unordered_set>
namespace mlx::core { namespace mlx::core::cu {
namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER // Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255 // This should be less than 255
constexpr int default_max_nodes_per_graph = 20; constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
int cuda_graph_cache_size() { int cuda_graph_cache_size() {
static int cache_size = []() { static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
return cache_size; return cache_size;
} }
namespace cu { } // namespace
Device::Device(int device) : device_(device) { Device::Device(int device) : device_(device) {
CHECK_CUDA_ERROR(cudaDeviceGetAttribute( CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -40,11 +51,14 @@ Device::Device(int device) : device_(device) {
} }
// The cublasLt handle is used by matmul. // The cublasLt handle is used by matmul.
make_current(); make_current();
cublasLtCreate(&lt_); CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
} }
Device::~Device() { Device::~Device() {
cublasLtDestroy(lt_); CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
} }
void Device::make_current() { void Device::make_current() {
@@ -66,29 +80,36 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
} }
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
} }
CommandEncoder::CaptureContext::~CaptureContext() { CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
if (discard) {
return;
}
// Extract and add as single kernel node when possible.
size_t num_nodes; size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) { if (num_nodes == 1) {
cudaGraphNode_t captured_node; cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes)); CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
CUDA_KERNEL_NODE_PARAMS params; cudaGraphNodeType type;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params)); CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
cudaGraphNode_t node; if (type == cudaGraphNodeTypeKernel) {
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params)); CUDA_KERNEL_NODE_PARAMS params;
enc.insert_graph_dependencies(GraphNode{node, 'K'}); CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
} else { enc.add_kernel_node(params);
cudaGraphNode_t node; return;
CHECK_CUDA_ERROR( }
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
enc.insert_graph_dependencies(GraphNode{node, 'G'});
} }
CHECK_CUDA_ERROR(cudaGraphDestroy(graph)); // Otherwise add the captured graph as subgraph.
enc.add_graph_node(graph);
} }
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc) CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
@@ -221,10 +242,7 @@ void CommandEncoder::add_kernel_node(
kernel_params.gridDim = grid_dim; kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim; kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params; kernel_params.kernelParams = params;
cudaGraphNode_t node; add_kernel_node(kernel_params);
CHECK_CUDA_ERROR(
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
insert_graph_dependencies(GraphNode{node, 'K'});
} }
void CommandEncoder::add_kernel_node( void CommandEncoder::add_kernel_node(
@@ -241,12 +259,27 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimY = block_dim.y; kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z; kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params; kernel_params.kernelParams = params;
CUgraphNode node; add_kernel_node(kernel_params);
CHECK_CUDA_ERROR( }
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'}); insert_graph_dependencies(GraphNode{node, 'K'});
} }
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
void CommandEncoder::commit() { void CommandEncoder::commit() {
if (!temporaries_.empty()) { if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {}); add_completed_handler([temporaries = std::move(temporaries_)]() {});
@@ -306,7 +339,6 @@ void CommandEncoder::commit() {
} }
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.end_batch();
worker_.commit(stream_); worker_.commit(stream_);
} }
@@ -315,7 +347,6 @@ void CommandEncoder::synchronize() {
auto p = std::make_shared<std::promise<void>>(); auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future(); std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); }); add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
commit(); commit();
f.wait(); f.wait();
} }
@@ -333,6 +364,4 @@ CommandEncoder& get_command_encoder(Stream s) {
return device(s.device).get_command_encoder(s); return device(s.device).get_command_encoder(s);
} }
} // namespace cu } // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -8,6 +8,7 @@
#include <cublasLt.h> #include <cublasLt.h>
#include <cuda.h> #include <cuda.h>
#include <cudnn.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <unordered_map> #include <unordered_map>
@@ -21,6 +22,7 @@ class CommandEncoder {
~CaptureContext(); ~CaptureContext();
cudaGraph_t graph; cudaGraph_t graph;
CommandEncoder& enc; CommandEncoder& enc;
bool discard{false};
}; };
struct ConcurrentContext { struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc); ConcurrentContext(CommandEncoder& enc);
@@ -65,6 +67,11 @@ class CommandEncoder {
void void
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) { void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr()); temporaries_.push_back(arr.data_shared_ptr());
} }
@@ -73,6 +80,10 @@ class CommandEncoder {
void maybe_commit(); void maybe_commit();
void commit(); void commit();
Device& device() {
return device_;
}
CudaStream& stream() { CudaStream& stream() {
return stream_; return stream_;
} }
@@ -137,12 +148,16 @@ class Device {
cublasLtHandle_t lt_handle() const { cublasLtHandle_t lt_handle() const {
return lt_; return lt_;
} }
cudnnHandle_t cudnn_handle() const {
return cudnn_;
}
private: private:
int device_; int device_;
int compute_capability_major_; int compute_capability_major_;
int compute_capability_minor_; int compute_capability_minor_;
cublasLtHandle_t lt_; cublasLtHandle_t lt_;
cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_; std::unordered_map<int, CommandEncoder> encoders_;
}; };

View File

@@ -49,6 +49,20 @@ store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
to[offset] = vec; to[offset] = vec;
} }
// Helper for accessing strided data.
template <typename T>
struct StridedIterator {
T it;
int64_t stride;
__host__ __device__ StridedIterator(T it, int64_t stride)
: it(it), stride(stride) {}
__host__ __device__ auto operator[](int i) const {
return it[i * stride];
}
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Type limits utils // Type limits utils
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -204,20 +218,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
return cuda::std::make_tuple(a_loc, b_loc, c_loc); return cuda::std::make_tuple(a_loc, b_loc, c_loc);
} }
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t> template <typename IdxT = int64_t>
inline __host__ __device__ IdxT inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem, IdxT elem,
const int* shape, const int* shape,
const int64_t* a_strides, const int64_t* a_strides,
@@ -235,7 +237,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
} }
template <typename IdxT = int64_t> template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d( inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
IdxT elem, IdxT elem,
const int* shape, const int* shape,
const int64_t* a_strides, const int64_t* a_strides,

View File

@@ -19,8 +19,6 @@ void new_stream(Stream s) {
cudaFree(nullptr); cudaFree(nullptr);
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
} }
void eval(array& arr) { void eval(array& arr) {

View File

@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() { SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory. buf_ = std::shared_ptr<Buffer>(
Atomic* ac; new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic))); allocator().free(*ptr);
new (ac) Atomic(0); delete ptr;
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) { });
ptr->~Atomic(); *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
allocator().cuda_free(ptr);
});
} }
void SharedEvent::wait(uint64_t value) { void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait"); nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value); event_wait(to_atomic(buf_), value);
} }
void SharedEvent::wait(cudaStream_t stream, uint64_t value) { void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::wait(Stream s, uint64_t value) { void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
wait(encoder.stream(), value); wait(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
void SharedEvent::signal(uint64_t value) { void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal"); nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value); event_signal(to_atomic(buf_), value);
} }
void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::signal(Stream s, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
signal(encoder.stream(), value); signal(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
bool SharedEvent::is_signaled(uint64_t value) const { bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value; return to_atomic(buf_)->load() >= value;
} }
uint64_t SharedEvent::value() const { uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value"); nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load(); return to_atomic(buf_)->load();
} }
} // namespace cu } // namespace cu

View File

@@ -2,6 +2,7 @@
#pragma once #pragma once
#include "mlx/allocator.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
bool is_signaled(uint64_t value) const; bool is_signaled(uint64_t value) const;
uint64_t value() const; uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private: private:
std::shared_ptr<Atomic> ac_; std::shared_ptr<mlx::core::allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -0,0 +1,73 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu {
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
a_it.step();
b_it.step();
}
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core::cu
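
The fallback above steps a ContiguousIterator per input over the batch dimensions and feeds byte offsets into run_impl. Below is a minimal standalone sketch of that stepping logic; the BatchStepper name and the printed values are illustrative, and the real iterator additionally caps the number of iterated dimensions so the innermost batch axis can be left to cuBLAS strided batching.

// Sketch: step a multi-dimensional index over `shape` while keeping a running
// element offset computed from `strides`, the way ContiguousIterator does.
#include <cstdint>
#include <cstdio>
#include <vector>

struct BatchStepper {
  std::vector<int> shape;       // batch shape (innermost last)
  std::vector<int64_t> strides; // element strides per batch dim
  std::vector<int> pos;         // current multi-dimensional index
  int64_t loc = 0;              // current element offset

  BatchStepper(std::vector<int> s, std::vector<int64_t> st)
      : shape(std::move(s)), strides(std::move(st)), pos(shape.size(), 0) {}

  void step() {
    for (int i = (int)shape.size() - 1; i >= 0; --i) {
      loc += strides[i];
      if (++pos[i] < shape[i]) {
        return; // no carry into the next outer dimension
      }
      // Carry: rewind this dimension and continue with the next outer one.
      loc -= int64_t(shape[i]) * strides[i];
      pos[i] = 0;
    }
  }
};

int main() {
  // Two batch dims of a broadcast matmul: a has stride 0 on the outer dim.
  BatchStepper a_it({2, 3}, {0, 16});
  BatchStepper b_it({2, 3}, {48, 16});
  for (int i = 0; i < 2 * 3; ++i) {
    std::printf("batch %d: a offset %lld, b offset %lld\n",
                i, (long long)a_it.loc, (long long)b_it.loc);
    a_it.step();
    b_it.step();
  }
}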

View File

@@ -0,0 +1,206 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu
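
set_mm_device_pointers and set_addmm_device_pointers fill one flat pointer table, grouped as all A pointers, then all B (then C) pointers, then the output pointers, which is the layout CUBLASLT_BATCH_MODE_POINTER_ARRAY consumes. A host-side sketch of the same layout follows; the helper name and the element offsets are made up and stand in for the device-side elem_to_loc computation.

#include <cstdint>
#include <vector>

// Build the flat pointer table [A0..A(n-1) | B0..B(n-1) | Out0..Out(n-1)]
// that a pointer-array batched GEMM consumes. Offsets are in elements.
std::vector<int8_t*> build_pointer_table(
    int8_t* a, int8_t* b, int8_t* out,
    int item_size,
    const std::vector<int64_t>& a_offsets, // per-batch element offset into a
    const std::vector<int64_t>& b_offsets, // per-batch element offset into b
    int64_t out_batch_stride) {            // M * N for a contiguous output
  int batch_count = static_cast<int>(a_offsets.size());
  std::vector<int8_t*> table(3 * batch_count);
  for (int i = 0; i < batch_count; ++i) {
    table[i] = a + item_size * a_offsets[i];
    table[i + batch_count] = b + item_size * b_offsets[i];
    table[i + 2 * batch_count] = out + item_size * i * out_batch_stride;
  }
  return table;
}

int main() {
  // Four 2x2 float32 GEMMs (item_size = 4, M * N = 4) with b broadcast
  // over pairs of batches.
  std::vector<int8_t> a(4 * 4 * 4), b(2 * 4 * 4), out(4 * 4 * 4);
  auto table = build_pointer_table(
      a.data(), b.data(), out.data(), 4, {0, 4, 8, 12}, {0, 0, 4, 4}, 4);
  return table.size() == 12 ? 0 : 1;
}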

View File

@@ -0,0 +1,282 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
#include <fmt/format.h>
namespace mlx::core::cu {
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()),
pref_(cublas_preference(device)),
M_(a_rows),
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: Matmul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void Matmul::run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha /* = 1 */,
float beta /* = 0 */) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
handle_,
matmul_desc_,
a_desc_,
b_desc_,
out_desc_, // TODO: should this be c_desc_ if it's set?
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
encoder.stream()));
}
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu
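
For reference, the call sequence the Matmul class wraps is short: create a matmul descriptor and matrix layouts, ask the heuristic API for an algorithm, then call cublasLtMatmul. The sketch below is a standalone column-major FP32 example, not MLX code; it skips the row-major order attribute, TF32 selection, and resource cleanup that the real implementation handles.

#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

#define CHECK(cmd)                                  \
  do {                                              \
    auto e = (cmd);                                 \
    if (e) {                                        \
      std::printf("%s failed: %d\n", #cmd, int(e)); \
      return 1;                                     \
    }                                               \
  } while (0)

int main() {
  const int n = 4; // C = A * B with all matrices n x n, column-major
  std::vector<float> ha(n * n, 1.0f), hb(n * n, 2.0f), hc(n * n, 0.0f);
  float *a, *b, *c;
  CHECK(cudaMalloc(&a, sizeof(float) * n * n));
  CHECK(cudaMalloc(&b, sizeof(float) * n * n));
  CHECK(cudaMalloc(&c, sizeof(float) * n * n));
  CHECK(cudaMemcpy(a, ha.data(), sizeof(float) * n * n, cudaMemcpyHostToDevice));
  CHECK(cudaMemcpy(b, hb.data(), sizeof(float) * n * n, cudaMemcpyHostToDevice));

  cublasLtHandle_t handle;
  CHECK(cublasLtCreate(&handle));

  cublasLtMatmulDesc_t op;
  CHECK(cublasLtMatmulDescCreate(&op, CUBLAS_COMPUTE_32F, CUDA_R_32F));

  cublasLtMatrixLayout_t a_desc, b_desc, c_desc;
  CHECK(cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_32F, n, n, n));
  CHECK(cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_32F, n, n, n));
  CHECK(cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32F, n, n, n));

  cublasLtMatmulPreference_t pref;
  CHECK(cublasLtMatmulPreferenceCreate(&pref));
  uint64_t workspace_size = 4 * 1024 * 1024; // 4 MiB, as for pre-Hopper GPUs
  CHECK(cublasLtMatmulPreferenceSetAttribute(
      pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspace_size, sizeof(workspace_size)));

  cublasLtMatmulHeuristicResult_t heuristic;
  int found = 0;
  CHECK(cublasLtMatmulAlgoGetHeuristic(
      handle, op, a_desc, b_desc, c_desc, c_desc, pref, 1, &heuristic, &found));
  if (found == 0) {
    std::printf("no algorithm found\n");
    return 1;
  }

  void* workspace = nullptr;
  CHECK(cudaMalloc(&workspace, heuristic.workspaceSize));

  float alpha = 1.0f, beta = 0.0f;
  CHECK(cublasLtMatmul(
      handle, op, &alpha, a, a_desc, b, b_desc, &beta, c, c_desc, c, c_desc,
      &heuristic.algo, workspace, heuristic.workspaceSize, /* stream */ 0));
  CHECK(cudaDeviceSynchronize());

  CHECK(cudaMemcpy(hc.data(), c, sizeof(float) * n * n, cudaMemcpyDeviceToHost));
  std::printf("c[0] = %f (expect %d)\n", hc[0], 2 * n);
  // Cleanup of handles, layouts, and device memory omitted for brevity.
  return 0;
}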

View File

@@ -0,0 +1,100 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
public:
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha = 1,
float beta = 0);
uint64_t M_;
uint64_t N_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu

View File

@@ -0,0 +1,147 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int n_per_thread = 4;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto g_idx = block.group_index();
auto t_idx = block.thread_index();
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<float>(local_mat.val[j]) *
static_cast<float>(local_vec.val[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
}
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void gemv_batched(
const T* mat,
const T* vec,
T* out,
int rows,
int cols,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides mat_batch_strides,
const __grid_constant__ Strides vec_batch_strides,
int batch_ndim) {
auto block = cg::this_thread_block();
auto batch_idx = block.group_index().y;
auto [vec_offset, mat_offset] = elem_to_loc(
batch_idx,
batch_shape.data(),
vec_batch_strides.data(),
mat_batch_strides.data(),
batch_ndim);
gemv_impl<T, rows_per_block, n_per_thread>(
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % (WARP_SIZE * n_per_thread) == 0 &&
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;
const DataType* vec;
int rows;
int cols = K;
auto mat_strides = const_param(a_batch_strides);
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
}
} // namespace mlx::core::cu
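
gemv_impl assigns one warp per output row, reads n_per_thread elements per lane via load_vector, and finishes with a warp reduction. A stripped-down standalone CUDA version of the same idea follows, using plain loads and shuffle reductions; the kernel name, sizes, and launch shape are illustrative.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// One warp per row: each lane accumulates a strided slice of the dot product,
// then the warp reduces with shuffles and lane 0 writes the result.
__global__ void gemv_warp_per_row(
    const float* mat, const float* vec, float* out, int rows, int cols) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= rows) {
    return;
  }
  float sum = 0.0f;
  for (int col = threadIdx.x; col < cols; col += warpSize) {
    sum += mat[row * cols + col] * vec[col];
  }
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (threadIdx.x == 0) {
    out[row] = sum;
  }
}

int main() {
  const int rows = 64, cols = 256;
  std::vector<float> hm(rows * cols, 1.0f), hv(cols, 2.0f), ho(rows);
  float *m, *v, *o;
  cudaMalloc(&m, sizeof(float) * rows * cols);
  cudaMalloc(&v, sizeof(float) * cols);
  cudaMalloc(&o, sizeof(float) * rows);
  cudaMemcpy(m, hm.data(), sizeof(float) * rows * cols, cudaMemcpyHostToDevice);
  cudaMemcpy(v, hv.data(), sizeof(float) * cols, cudaMemcpyHostToDevice);

  dim3 block(32, 8);                  // 8 rows per block, one warp each
  dim3 grid((rows + 8 - 1) / 8);
  gemv_warp_per_row<<<grid, block>>>(m, v, o, rows, cols);
  cudaMemcpy(ho.data(), o, sizeof(float) * rows, cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f (expect %d)\n", ho[0], 2 * cols);
  return 0;
}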

View File

@@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -1,121 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidental running in host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu
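
With this iterator removed, the general (non-contiguous) kernels compute offsets directly through elem_to_loc, which peels a flat element index into per-dimension coordinates and multiplies by the strides. A host-side sketch of that conversion is shown below; the function name is illustrative, and the device versions also provide multi-operand overloads.

#include <cstdint>
#include <cstdio>

// Convert a flat row-major element index into a memory offset for an array
// with the given shape/strides, peeling dimensions from the innermost out.
int64_t elem_to_loc_sketch(
    int64_t index, const int* shape, const int64_t* strides, int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: shape (2, 3), strides (1, 2).
  int shape[] = {2, 3};
  int64_t strides[] = {1, 2};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf(
        "element %lld -> offset %lld\n",
        (long long)i,
        (long long)elem_to_loc_sketch(i, shape, strides, 2));
  }
  return 0;
}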

View File

@@ -1,60 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -105,8 +104,8 @@ __global__ void layer_norm(
T wn[N_READS]; T wn[N_READS];
T bn[N_READS]; T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer; float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i]; xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
@@ -162,7 +161,7 @@ __global__ void layer_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean); cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean; float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i]; float wi = wn[i];
@@ -185,7 +184,7 @@ __global__ void layer_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer; float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i]; float wi = wn[i];

View File

@@ -0,0 +1,146 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <list>
#include <unordered_map>
#include <utility>
namespace mlx::core {
template <
typename K,
typename V,
template <typename...> typename M = std::unordered_map>
class LRUCache {
public:
using value_type = std::pair<K, V>;
using list_type = std::list<value_type>;
using iterator = typename list_type::iterator;
using const_iterator = typename list_type::const_iterator;
using map_type = M<K, iterator>;
explicit LRUCache(size_t capacity) : capacity_(capacity) {}
size_t size() const {
return map_.size();
}
size_t capacity() const {
return capacity_;
}
bool empty() const {
return vlist_.empty();
}
void resize(size_t new_capacity) {
capacity_ = new_capacity;
trim();
}
iterator begin() {
return vlist_.begin();
}
const_iterator begin() const {
return vlist_.begin();
}
iterator end() {
return vlist_.end();
}
const_iterator end() const {
return vlist_.end();
}
void clear() {
map_.clear();
vlist_.clear();
}
iterator find(const K& key) {
auto it = map_.find(key);
if (it == map_.end())
return end();
vlist_.splice(vlist_.begin(), vlist_, it->second);
return it->second;
}
template <typename U>
std::pair<iterator, bool> emplace(const K& key, U&& value) {
auto it = map_.find(key);
if (it != map_.end()) {
vlist_.splice(vlist_.begin(), vlist_, it->second);
return {it->second, false};
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
trim();
return {vlist_.begin(), true};
}
iterator erase(iterator pos) {
map_.erase(pos->first);
return vlist_.erase(pos);
}
private:
void trim() {
while (map_.size() > capacity_) {
auto last = std::prev(vlist_.end());
map_.erase(last->first);
vlist_.pop_back();
}
}
list_type vlist_;
map_type map_;
size_t capacity_;
};
// Turn a POD struct into a container key by comparing its raw bytes.
template <typename T>
struct BytesKey {
T pod;
static_assert(std::is_standard_layout_v<T>, "T is not POD");
BytesKey(T pod) : pod(std::move(pod)) {}
BytesKey(const BytesKey& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
BytesKey(BytesKey&& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
bool operator==(const BytesKey& other) const {
auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
return memcmp(ptr1, ptr2, sizeof(T)) == 0;
}
};
// Compute a hash from the raw bytes of T.
template <typename T>
struct BytesHash {
static_assert(std::is_standard_layout_v<T>, "T is not POD");
size_t operator()(const T& pod) const {
auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < sizeof(T); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return value;
}
};
template <typename K, typename V>
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
template <typename K, typename V>
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
} // namespace mlx::core
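
A small usage sketch of the cache above, keyed on a POD descriptor the way a backend might cache vendor-library plans. The ConvKey struct, the string payload, and the include path are illustrative, but the calls match the interface shown.

#include <cstdio>
#include <string>

#include "lru_cache.h" // the header shown above; include path is illustrative

// Hypothetical plan-cache key: a POD descriptor compared and hashed by bytes.
struct ConvKey {
  int n, c, h, w;
  int kernel, stride, padding;
};

int main() {
  // At most two cached plans; the least recently used entry is evicted.
  mlx::core::LRUBytesKeyCache<ConvKey, std::string> cache(2);

  cache.emplace(ConvKey{1, 3, 224, 224, 3, 1, 1}, "plan-a");
  cache.emplace(ConvKey{1, 3, 112, 112, 3, 2, 1}, "plan-b");

  // A hit splices the entry to the front of the LRU list.
  auto it = cache.find(ConvKey{1, 3, 224, 224, 3, 1, 1});
  std::printf("lookup: %s\n", it != cache.end() ? it->second.c_str() : "miss");

  // A third insert evicts "plan-b", which is now the least recently used.
  cache.emplace(ConvKey{8, 64, 56, 56, 1, 1, 0}, "plan-c");
  std::printf("entries: %zu\n", cache.size());
  return 0;
}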

View File

@@ -2,289 +2,15 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: MatMul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
~MatMul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void run(
cu::CommandEncoder& encoder,
void* out,
void* a,
void* b,
void* c = nullptr,
float alpha = 1,
float beta = 0) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
handle_,
matmul_desc_,
a_desc_,
b_desc_,
out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
encoder.stream()));
}
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order =
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace cu
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
@@ -353,10 +79,25 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1}; batch_shape = {1};
} }
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul(
cu::MatMul matmul(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -371,27 +112,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
encoder.set_input_array(a); if ((batch_count / batch_shape.back()) == 1) {
encoder.set_input_array(b); matmul.run(encoder, out, a, b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return; return;
} }
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); matmul.run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc);
a_it.step();
b_it.step();
}
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -459,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::MatMul matmul( cu::Matmul matmul(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -476,41 +203,22 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
encoder.set_input_array(a); if ((batch_count / batch_shape.back()) == 1) {
encoder.set_input_array(b); matmul.run(encoder, out, a, b, c, alpha_, beta_);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return; return;
} }
matmul.run_batched(
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); encoder,
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); out,
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); a,
auto concurrent = encoder.concurrent_context(); b,
for (size_t i = 0; i < nbatch; ++i) { c,
matmul.run( batch_shape,
encoder, a_batch_strides,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, b_batch_strides,
a.data<int8_t>() + a.itemsize() * a_it.loc, c_batch_strides,
b.data<int8_t>() + b.itemsize() * b_it.loc, alpha_,
c.data<int8_t>() + c.itemsize() * c_it.loc, beta_);
alpha_,
beta_);
a_it.step();
b_it.step();
c_it.step();
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
} }
NO_GPU(BlockMaskedMM) NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU(DynamicSlice) NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate) NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT) NO_GPU(FFT)

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
@@ -89,7 +88,7 @@ __global__ void rms_norm(
T xn[N_READS]; T xn[N_READS];
T wn[N_READS]; T wn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) { for (int i = 0; i < N_READS; ++i) {
float norm = static_cast<float>(xn[i]) * normalizer; float norm = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm); xn[i] = wn[i] * static_cast<T>(norm);
@@ -132,7 +131,7 @@ __global__ void rms_norm_vjp(
auto index = r * BLOCK_DIM + block.thread_rank(); auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0)); cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]); float t = static_cast<float>(xn[i]);
float wi = wn[i]; float wi = wn[i];
@@ -154,7 +153,7 @@ __global__ void rms_norm_vjp(
T gn[N_READS]; T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size); cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size); cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size); cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
float xi = xn[i]; float xi = xn[i];
float wi = wn[i]; float wi = wn[i];

View File

@@ -76,7 +76,7 @@ __global__ void ternary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d( auto [a_idx, b_idx, c_idx] = elem_to_loc(
index, index,
shape.data(), shape.data(),
a_strides.data(), a_strides.data(),

View File

@@ -3,7 +3,6 @@
#include "mlx/backend/common/unary.h" #include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -48,7 +47,7 @@ __global__ void unary_g(
int ndim) { int ndim) {
IdxT index = cg::this_grid().thread_rank(); IdxT index = cg::this_grid().thread_rank();
if (index < size) { if (index < size) {
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim); auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]); out[index] = Op{}(in[idx]);
} }
} }

View File

@@ -17,6 +17,14 @@ CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
} }
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
void check_cuda_error(const char* name, cudaError_t err) { void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) { if (err != cudaSuccess) {
throw std::runtime_error( throw std::runtime_error(

View File

@@ -4,6 +4,7 @@
#pragma once #pragma once
#include <cublasLt.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -33,10 +34,12 @@ class CudaStream {
}; };
// Throw exception if the cuda API does not succeed. // Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err); void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed. // The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Convert Dtype to CUDA C++ types. // Convert Dtype to CUDA C++ types.

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()
Worker::~Worker() { Worker::~Worker() {
{ {
std::lock_guard lock(worker_mutex_); std::lock_guard lock(mtx_);
stop_ = true; stop_ = true;
} }
worker_event_.signal(batch_ + 1); cond_.notify_one();
worker_.join(); worker_.join();
} }
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task)); pending_tasks_.push_back(std::move(task));
} }
void Worker::consume_in_this_thread() { void Worker::signal(void* data) {
for (auto& task : pending_tasks_) { auto w = static_cast<Worker*>(data);
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{ {
std::lock_guard lock(worker_mutex_); std::lock_guard lock(w->mtx_);
worker_tasks_[batch_] = std::move(pending_tasks_); w->signaled_batch_++;
} }
uncommited_batches_++; w->cond_.notify_one();
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
} }
void Worker::commit(cudaStream_t stream) { void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) { // Move pending tasks into tasks
if (pending_tasks_.empty()) {
return; return;
} }
uncommited_batches_ = 0; {
// Signal the |worker_event_| in |signal_stream_| after the kernels in std::lock_guard lock(mtx_);
// |stream_| finish running. // Move pending tasks into ready tasks
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
}
signal_event_.record(stream); signal_event_.record(stream);
signal_event_.wait(signal_stream_); signal_event_.wait(signal_stream_);
worker_event_.signal(signal_stream_, batch_); cudaLaunchHostFunc(signal_stream_, signal, this);
} }
void Worker::thread_fn() { void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) { while (!stop_) {
uint64_t batch = worker_event_.value(); uint64_t current_batch = 0;
Tasks tasks; Tasks tasks;
{ {
std::lock_guard lock(worker_mutex_); std::unique_lock<std::mutex> lk(mtx_);
// Move tasks in signaled batches. cond_.wait(lk, [this, &current_batch] {
auto end = worker_tasks_.upper_bound(batch); return this->signaled_batch_ > current_batch || this->stop_;
});
current_batch = signaled_batch_;
auto end = worker_tasks_.upper_bound(current_batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) { for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) { if (tasks.empty()) {
tasks = std::move(it->second); tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
auto task = std::move(tasks[i]); auto task = std::move(tasks[i]);
task(); task();
} }
worker_event_.wait(batch + 1);
} }
} }
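
The reworked Worker replaces the SharedEvent handshake with cudaLaunchHostFunc: once the signal stream reaches the callback, it runs on a CUDA-internal thread, bumps a counter under the mutex, and notifies the condition variable the worker sleeps on. Below is a self-contained sketch of that handshake, reduced to a single counter with no task batching.

#include <cuda_runtime.h>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <thread>

// Host-side state shared between the CUDA host callback and the worker thread.
struct Waiter {
  std::mutex mtx;
  std::condition_variable cond;
  uint64_t signaled = 0;
  bool stop = false;
};

// Runs on a CUDA-internal thread after prior work in the stream completes.
// Host callbacks must not call CUDA APIs, so it only updates host state.
void signal(void* data) {
  auto* w = static_cast<Waiter*>(data);
  {
    std::lock_guard<std::mutex> lock(w->mtx);
    w->signaled++;
  }
  w->cond.notify_one();
}

int main() {
  Waiter w;
  std::thread worker([&w] {
    uint64_t seen = 0;
    while (true) {
      std::unique_lock<std::mutex> lk(w.mtx);
      w.cond.wait(lk, [&] { return w.signaled > seen || w.stop; });
      if (w.stop) {
        break;
      }
      seen = w.signaled;
      std::printf("worker woke up for batch %llu\n", (unsigned long long)seen);
    }
  });

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // In MLX this follows signal_event_.record()/wait(); here the stream is
  // empty, so the callback fires as soon as it is reached.
  cudaLaunchHostFunc(stream, signal, &w);
  cudaStreamSynchronize(stream);

  {
    std::lock_guard<std::mutex> lock(w.mtx);
    w.stop = true;
  }
  w.cond.notify_one();
  worker.join();
  cudaStreamDestroy(stream);
  return 0;
}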

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include <condition_variable>
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex> #include <mutex>
@@ -24,38 +25,24 @@ class Worker {
// Add a pending |task| that will run when consumed or committed. // Add a pending |task| that will run when consumed or committed.
void add_task(std::function<void()> task); void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream| // Inform worker thread to run current batches after kernels in |stream|
// finish running. // finish running.
void commit(cudaStream_t stream); void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private: private:
void thread_fn(); static void signal(void*);
uint64_t batch_{0}; void thread_fn();
size_t uncommited_batches_{0}; std::mutex mtx_;
std::condition_variable cond_;
uint64_t committed_batch_{0};
uint64_t signaled_batch_{0};
// Cuda stream and event for signaling kernel completion. // Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_; CudaStream signal_stream_;
CudaEvent signal_event_; CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false}; bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to // Tasks are put in |pending_tasks_| first, and then moved to
@@ -63,6 +50,7 @@ class Worker {
using Tasks = std::vector<std::function<void()>>; using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_; Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_; std::map<uint64_t, Tasks> worker_tasks_;
std::thread worker_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure try to reclaim memory from the cache
// try to reclaim memory from the cache
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) { if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
num_resources_ -= num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_); buffer_cache_.release_cached_buffers(mem_required - gc_limit_);

View File

@@ -14,6 +14,10 @@ Event::Event(Stream stream) : stream_(stream) {
auto p = metal::new_scoped_memory_pool(); auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>( event_ = std::shared_ptr<void>(
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
if (event_ == nullptr) {
throw std::runtime_error(
"[Event::Event] Failed to create Metal shared event.");
}
} }
void Event::wait() { void Event::wait() {

View File

@@ -265,9 +265,15 @@ void qvm_split_k(
MTL::Size group_dims = MTL::Size(bk, 2, 1); MTL::Size group_dims = MTL::Size(bk, 2, 1);
MTL::Size grid_dims = MTL::Size(M, N / bn, B); MTL::Size grid_dims = MTL::Size(M, N / bn, B);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape(); auto x_shape = x.shape();
auto x_strides = x.strides(); auto x_strides = x.strides();
if (x_shape.size() == 1) {
x_shape.insert(x_shape.begin(), 1);
x_strides.insert(x_strides.begin(), 0);
}
int x_ndim = x_shape.size();
int x_batch_ndims = x_ndim - 2;
int w_batch_ndims = w.ndim() - 2; int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape(); auto w_shape = w.shape();
auto w_strides = w.strides(); auto w_strides = w.strides();
@@ -278,7 +284,7 @@ void qvm_split_k(
x_shape.insert(x_shape.end() - 2, split_k); x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k; x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D); x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D; x_strides[x_ndim - 1] = split_D;
x_batch_ndims += 1; x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k); w_shape.insert(w_shape.end() - 2, split_k);
@@ -291,6 +297,9 @@ void qvm_split_k(
int final_block_size = K - (split_k - 1) * split_D; int final_block_size = K - (split_k - 1) * split_D;
auto temp_shape = out.shape(); auto temp_shape = out.shape();
if (temp_shape.size() == 1) {
temp_shape.insert(temp_shape.begin(), 1);
}
temp_shape.insert(temp_shape.end() - 2, split_k); temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {}); array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes())); intermediate.set_data(allocator::malloc(intermediate.nbytes()));
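
The fix promotes a 1-D x to shape (1, K) with a broadcast stride before the split-k bookkeeping, because the code then inserts the split dimension two positions from the end and rewrites a stride indexed by the pre-split ndim. The standalone sketch below just replays those shape/stride edits on illustrative values.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // A 1-D activation of length K = 1024, split into split_k = 4 chunks.
  int K = 1024, split_k = 4, split_D = K / split_k;

  std::vector<int> x_shape = {K};
  std::vector<int64_t> x_strides = {1};

  // Promote to 2-D: (1, K) with a broadcast (0) stride on the new axis.
  x_shape.insert(x_shape.begin(), 1);
  x_strides.insert(x_strides.begin(), 0);
  int x_ndim = x_shape.size(); // captured after promotion, before the split

  // Insert the split dimension two positions from the end and shrink the
  // reduction axis: (1, K) -> (split_k, 1, K / split_k).
  x_shape.insert(x_shape.end() - 2, split_k);
  x_shape.back() /= split_k;
  x_strides.insert(x_strides.end() - 2, split_D);
  x_strides[x_ndim - 1] = split_D; // indexes by the pre-split ndim

  for (size_t i = 0; i < x_shape.size(); ++i) {
    std::printf("dim %zu: size %d stride %lld\n",
                i, x_shape[i], (long long)x_strides[i]);
  }
  return 0;
}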

View File

@@ -708,7 +708,10 @@ array scaled_dot_product_attention(
} }
if (mask.dtype() == bool_) { if (mask.dtype() == bool_) {
scores = where( scores = where(
mask, scores, array(finfo(scores.dtype()).min, scores.dtype())); mask,
scores,
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
s);
} else { } else {
scores = add(scores, mask, s); scores = add(scores, mask, s);
} }
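
A short way to see the motivation, assuming the usual max-subtracting softmax: if every score in a row is replaced by the finite fill value c = finfo(float32).min, then

    softmax(c, ..., c)_i = exp(c - c) / sum_j exp(c - c) = 1 / L,

so a fully masked row silently attends uniformly over its L positions. With -infinity as the fill value the row maximum is -infinity, the normalizer collapses, and the row comes out NaN, which is the consistent behavior the commit title refers to.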

View File

@@ -1271,19 +1271,6 @@ std::vector<array> Convolution::vjp(
has_neg_padding |= (pd < 0); has_neg_padding |= (pd < 0);
} }
auto padding_lo_ = std::vector<int>(padding_lo);
auto padding_hi_ = std::vector<int>(padding_hi);
// Use negative padding on the gradient output
if (has_neg_padding) {
for (auto& p : padding_lo_) {
p = std::max(0, p);
}
for (auto& p : padding_hi_) {
p = std::max(0, p);
}
}
auto wt_trans = group_transpose(wt, 0, 1, -1); auto wt_trans = group_transpose(wt, 0, 1, -1);
auto grad = conv_general( auto grad = conv_general(
/* const array& input = */ cotan, /* const array& input = */ cotan,
@@ -1305,12 +1292,9 @@ std::vector<array> Convolution::vjp(
for (int i = 0; i < grad.ndim() - 2; i++) { for (int i = 0; i < grad.ndim() - 2; i++) {
if (padding_lo[i] < 0) { if (padding_lo[i] < 0) {
starts[i + 1] -= padding_lo[i]; starts[i + 1] -= padding_lo[i];
padding_lo[i] = 0;
} }
if (padding_hi[i] < 0) { if (padding_hi[i] < 0) {
stops[i + 1] += padding_hi[i]; stops[i + 1] += padding_hi[i];
padding_hi[i] = 0;
} }
} }
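
The dropped clamping block is no longer needed because the gradient is computed with the original (possibly negative) padding and the start/stop adjustments above absorb it. As a worked example on one spatial axis: with padding_lo[0] = -2 and padding_hi[0] = -1, the loop sets starts[1] = 0 - (-2) = 2 and shrinks stops[1] by 1, so the subsequent slice of the gradient trims two leading and one trailing position instead of zero-clamping the padding up front.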

View File

@@ -72,7 +72,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
   // Stream events for synchronization after eval
   std::unordered_map<uint32_t, Event> events;
-  events.emplace(stream.index, Event{stream});
+  {
+    auto e = Event{stream};
+    e.set_value(1);
+    synchronizer.attach_event(e);
+    events.emplace(stream.index, std::move(e));
+  }
   {
     // Record the degree of each input
@@ -184,21 +189,26 @@ array eval_impl(std::vector<array> outputs, bool async) {
     }
   }
+  std::unordered_set<int> open_streams;
   while (!tape.empty()) {
     auto arr = std::move(tape.back());
     tape.pop_back();
     auto stream = arr.primitive().stream();
+    open_streams.insert(stream.index);
-    // Lookup corresponding event
-    auto e = events.find(stream.index);
-    if (e == events.end()) {
-      e = events.emplace(stream.index, Event{stream}).first;
-    }
-    e->second.set_value(1);
-    arr.attach_event(e->second);
-    for (auto& s : arr.siblings()) {
-      s.attach_event(e->second);
+    if (async) {
+      // Lookup corresponding event
+      auto e = events.find(stream.index);
+      if (e == events.end()) {
+        e = events.emplace(stream.index, Event{stream}).first;
+      }
+      e->second.set_value(1);
+      arr.attach_event(e->second);
+      for (auto& s : arr.siblings()) {
+        s.attach_event(e->second);
+      }
     }
     for (auto& in : arr.inputs()) {
@@ -227,9 +237,10 @@ array eval_impl(std::vector<array> outputs, bool async) {
         (get_active_memory() > get_memory_limit() &&
          scheduler::n_active_tasks() > 0)) {
       // Commit any open streams
-      for (auto& [_, e] : events) {
-        if (e.stream().device == Device::gpu) {
-          gpu::finalize(e.stream());
+      for (auto i : open_streams) {
+        auto s = get_stream(i);
+        if (s.device == Device::gpu) {
+          gpu::finalize(s);
        }
      }
      scheduler::wait_for_one();
@@ -263,9 +274,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
  }
  // Signal the event in its stream
-  for (auto& [_, e] : events) {
-    auto s = e.stream();
-    e.signal(s);
+  for (auto i : open_streams) {
+    auto s = get_stream(i);
+    if (auto e = events.find(i); e != events.end()) {
+      e->second.signal(s);
+    }
    if (s.device == Device::gpu) {
      gpu::finalize(s);
    }
@@ -302,7 +315,7 @@ void eval(std::vector<array> outputs) {
    return;
  }
-  eval_impl(std::move(outputs), false).event().wait();
+  eval_impl(std::move(outputs), false).wait();
}

std::pair<std::vector<array>, std::vector<array>> vjp(
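The reworked eval_impl only attaches per-array events when the evaluation is asynchronous and tracks which streams were actually used, so a synchronous eval no longer pays for event bookkeeping. From the Python API this is the machinery behind mx.eval and mx.async_eval; a minimal usage sketch (shapes and workload are placeholders):

    import mlx.core as mx

    a = mx.random.normal(shape=(1024, 1024))
    b = mx.random.normal(shape=(1024, 1024))
    c = a @ b
    mx.async_eval(c)   # schedule the work without blocking
    # ... other host-side work ...
    mx.eval(c)         # block until c is materialized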


@@ -477,7 +477,7 @@ class Adam(Optimizer):
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
-        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

     Args:
         learning_rate (float or callable): The learning rate :math:`\lambda`.
@@ -546,7 +546,7 @@ class AdamW(Adam):
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
-        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
+        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

     Args:
         learning_rate (float or callable): The learning rate :math:`\alpha`.
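The docstring fix moves epsilon outside the square root, which is what the update actually computes. A scalar sketch of the corrected Adam step, without bias correction as in the docstring above (hyperparameter values are illustrative, not necessarily MLX defaults):

    import math

    def adam_step(w, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        w = w - lr * m / (math.sqrt(v) + eps)  # sqrt(v) + eps, not sqrt(v + eps)
        return w, m, v

    w, m, v = adam_step(1.0, g=0.5, m=0.0, v=0.0)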


@@ -1,9 +1,10 @@
 #!/bin/bash

 auditwheel repair dist/* \
-  --plat manylinux_2_39_x86_64 \
+  --plat manylinux_2_35_x86_64 \
   --exclude libcublas* \
   --exclude libnvrtc* \
+  --exclude libcuda* \
   -w wheel_tmp
@@ -15,7 +16,7 @@ rm "${repaired_wheel}"
 mlx_so="mlx/lib/libmlx.so"
 rpath=$(patchelf --print-rpath "${mlx_so}")
 base="\$ORIGIN/../../nvidia"
-rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
+rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
 patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
 python ../python/scripts/repair_record.py ${mlx_so}
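After the repair step, the shared library's RPATH should cover the pip-installed nvidia/cublas, nvidia/cuda_nvrtc, and the newly added nvidia/cudnn lib directories. A quick hedged check (assumes patchelf is on PATH and the repaired wheel layout shown above):

    import subprocess

    so = "mlx/lib/libmlx.so"
    rpath = subprocess.run(
        ["patchelf", "--print-rpath", so],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    print(rpath)
    assert "nvidia/cudnn/lib" in rpath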


@@ -2,6 +2,7 @@

 auditwheel repair dist/* \
   --plat manylinux_2_35_x86_64 \
+  --only-plat \
   --exclude libmlx* \
   -w wheel_tmp


@@ -4022,8 +4022,9 @@ void init_ops(nb::module_& m) {
         Args:
             file (file, str): File in which the array is saved.
             arrays (dict(str, array)): The dictionary of names to arrays to
-              be saved. metadata (dict(str, str), optional): The dictionary of
-              metadata to be saved.
+              be saved.
+            metadata (dict(str, str), optional): The dictionary of
+              metadata to be saved.
       )pbdoc");
   m.def(
       "save_gguf",
@@ -4258,7 +4259,7 @@ void init_ops(nb::module_& m) {
         .. math::

-          w_i = s \hat{w_i} - \beta
+          w_i = s \hat{w_i} + \beta

         Args:
           w (array): Matrix to be quantized
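The corrected comment states the dequantization as w_i = s * w_hat_i + beta. A small round-trip sketch with mx.quantize / mx.dequantize (shapes and parameters are arbitrary):

    import mlx.core as mx

    group_size, bits = 64, 4
    w = mx.random.normal(shape=(128, 256))
    w_q, scales, biases = mx.quantize(w, group_size, bits)
    w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
    print(mx.abs(w - w_hat).max())  # small quantization error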


@@ -15,19 +15,12 @@ cuda_skip = {
     "TestOps.test_hadamard_grad_vmap",
     # Convolutions NYI
     "TestConv.test_1d_conv_with_2d",
-    "TestConv.test_asymmetric_padding",
-    "TestConv.test_basic_grad_shapes",
-    "TestConv.test_conv2d_unaligned_channels",
     "TestConv.test_conv_1d_groups_flipped",
     "TestConv.test_conv_general_flip_grad",
     "TestConv.test_conv_groups_grad",
-    "TestConv.test_numpy_conv",
-    "TestConv.test_repeated_conv",
-    "TestConv.test_torch_conv_1D",
     "TestConv.test_torch_conv_1D_grad",
     "TestConv.test_torch_conv_2D",
     "TestConv.test_torch_conv_2D_grad",
-    "TestConv.test_torch_conv_3D",
     "TestConv.test_torch_conv_3D_grad",
     "TestConv.test_torch_conv_depthwise",
     "TestConv.test_torch_conv_general",
@@ -40,10 +33,6 @@ cuda_skip = {
     "TestConvTranspose.test_torch_conv_transpose_3D",
     "TestConvTranspose.test_torch_conv_transpose_3D_grad",
     "TestConvTranspose.test_torch_conv_transpose_3d_output_padding",
-    "TestExportImport.test_export_conv",
-    "TestLayers.test_conv1d",
-    "TestLayers.test_conv2d",
-    "TestVmap.test_vmap_conv",
     # FFTs NYI
     "TestFFT.test_fft",
     "TestFFT.test_fft_big_powers_of_two",


@@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

+    def test_fully_masked(self):
+        Lkv = 8
+        mask = mx.array(False)
+        for D in [4, 128]:
+            for Lq in [1, 8]:
+                q = mx.random.normal(shape=(1, 4, Lq, D))
+                k = mx.random.normal(shape=(1, 4, Lkv, D))
+                v = mx.random.normal(shape=(1, 4, Lkv, D))
+                out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
+                self.assertTrue(mx.all(mx.isnan(out)))
+
     def test_fast_sdpa_few_query(self):
         D = 64
         L = 43
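The new test pins down the behavior that a fully masked query row yields NaN on every code path. Callers who prefer finite outputs can post-process the result; a caller-side sketch (not MLX behavior, values are placeholders):

    import mlx.core as mx

    q = mx.random.normal(shape=(1, 4, 8, 64))
    k = mx.random.normal(shape=(1, 4, 8, 64))
    v = mx.random.normal(shape=(1, 4, 8, 64))
    mask = mx.array(False)  # every key masked out
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
    out = mx.where(mx.isnan(out), mx.zeros_like(out), out)  # zero fully-masked rows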


@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
             self.assertEqual(y_q.shape, y_hat.shape)
             self.assertLess((y_q - y_hat).abs().max(), 2e-3)

+        # Test with 1D vector
+        group_size = 32
+        bits = 8
+        N = 2048
+        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
+        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))


@@ -9,7 +9,7 @@ from functools import partial
 from pathlib import Path
 from subprocess import run

-from setuptools import Command, Extension, setup
+from setuptools import Command, Extension, find_namespace_packages, setup
 from setuptools.command.bdist_wheel import bdist_wheel
 from setuptools.command.build_ext import build_ext
@@ -166,6 +166,10 @@ class GenerateStubs(Command):
         # Run again without recursive to specify output file name
         subprocess.run(["rm", f"{out_path}/mlx.pyi"])
         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
+        # mx.bool_ gets filtered by nanobind because of the trailing
+        # underscore, add it manually:
+        with open(f"{out_path}/__init__.pyi", "a") as fid:
+            fid.write("\nbool_: Dtype = ...")

 class MLXBdistWheel(bdist_wheel):
@@ -184,19 +188,23 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
 if __name__ == "__main__":
     package_dir = {"": "python"}
-    packages = [
-        "mlx",
-        "mlx.nn",
-        "mlx.nn.layers",
-        "mlx.optimizers",
-    ]
+    packages = find_namespace_packages(
+        where="python",
+        exclude=[
+            "src",
+            "tests",
+            "scripts",
+            "mlx.lib",
+            "mlx.include",
+            "mlx.share",
+            "mlx.share.**",
+            "mlx.include.**",
+        ],
+    )

     build_macos = platform.system() == "Darwin"
     build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
-    install_requires = []
-    if build_cuda:
-        install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]

     version = get_version()
     _setup = partial(
@@ -221,7 +229,7 @@ if __name__ == "__main__":
         },
     )

-    package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
+    package_data = {"mlx.core": ["*.pyi"]}

     extras = {
         "dev": [
@@ -239,6 +247,7 @@ if __name__ == "__main__":
             "mlx.distributed_config = mlx.distributed_run:distributed_config",
         ]
     }
+    install_requires = []

     # Release builds for PyPi are in two stages.
     # Each stage should be run from a clean build:
@@ -258,11 +267,11 @@ if __name__ == "__main__":
     #   - Package name is back-end specific, e.g mlx-metal
     if build_stage != 2:
         if build_stage == 1:
-            if build_macos:
-                install_requires += [f"mlx-metal=={version}"]
-            else:
-                extras["cuda"] = [f"mlx-cuda=={version}"]
-                extras["cpu"] = [f"mlx-cpu=={version}"]
+            install_requires.append(
+                f'mlx-metal=={version}; platform_system == "Darwin"'
+            )
+            extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
+            extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']

         _setup(
             name="mlx",
@@ -277,9 +286,15 @@ if __name__ == "__main__":
             name = "mlx-metal"
         elif build_cuda:
             name = "mlx-cuda"
+            install_requires += [
+                "nvidia-cublas-cu12==12.9.*",
+                "nvidia-cuda-nvrtc-cu12==12.9.*",
+                "nvidia-cudnn-cu12==12.9.*",
+            ]
         else:
             name = "mlx-cpu"

         _setup(
             name=name,
             packages=["mlx"],
+            install_requires=install_requires,
         )
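The dependency metadata now relies on PEP 508 environment markers instead of branching on the build host, so a single mlx wheel can carry platform-conditional requirements. A sketch of how such a marker evaluates, using the third-party packaging library (assumed installed; the version shown is a placeholder):

    from packaging.requirements import Requirement

    req = Requirement('mlx-metal==1.0.0; platform_system == "Darwin"')
    print(req.name, req.specifier)
    print(req.marker.evaluate())  # True on macOS, False elsewhere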