Compare commits


59 Commits

Author SHA1 Message Date
Angelos Katharopoulos
4987e7615a Improve the cutlass gemm 2025-08-25 18:18:19 -07:00
Angelos Katharopoulos
e1303f6160 Reset cutlass gemm to working state again 2025-08-21 01:29:43 -07:00
Angelos Katharopoulos
cf5eef095d tmp 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
395d582719 Add a cutlass gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
05583bcd10 More pipelining for the sm_80 gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
6fce01593a Improve gemm 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
97afe40b7b Remove duplicate register tile 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
f70c62d69c Simple gemm example 2025-08-20 23:51:25 -07:00
Angelos Katharopoulos
0c5fc63a36 Fix docs omission (#2524) 2025-08-20 17:56:06 -07:00
Angelos Katharopoulos
e397177f6e Custom cuda kernel (#2517) 2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) 2025-08-21 08:55:26 +09:00
Angelos Katharopoulos
25c1e03205 Fix overflow in large filter small channels (#2520) 2025-08-20 08:03:29 -07:00
russellizadi
512281781c Remove state return from function example in compile documentation (#2518) 2025-08-20 00:45:05 -07:00
Cheng
ac85ddfdb7 [CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv

* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Cheng
65d0d40232 Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class

* Implement forward rms_norm with cuDNN

* Revert back to old rms norm kernel
2025-08-20 09:29:28 +09:00
Awni Hannun
cea9369610 fix lapack svd (#2515) 2025-08-18 15:07:59 -07:00
Awni Hannun
e7c6e1db82 no segfault with uninitialized array.at (#2514) 2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b fix custom kernel test (#2510) 2025-08-18 06:45:59 -07:00
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
Awni Hannun
6441c21a94 Faster general unary op (#2472)
* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
2025-08-15 15:04:12 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-08-12 12:29:02 -07:00
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
Awni Hannun
fa89f0b150 faster gather qmm sorted test (#2463) 2025-08-05 06:27:40 -07:00
Awni Hannun
ca973d1e83 fix install tags (#2464) 2025-08-04 20:01:23 -07:00
Cheng
828c5f1137 Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
2025-08-05 09:41:03 +09:00
Gaétan Lepage
7d86a5c108 Feat: add USE_SYSTEM_FMT CMake option (#2219) 2025-08-04 16:36:11 -07:00
Awni Hannun
0b807893a7 fix wraps compile (#2461) 2025-08-04 16:14:18 -07:00
Awni Hannun
6ad0889c8a default install cuda on linux (#2462) 2025-08-04 15:33:05 -07:00
Zamderax
737dd6d1ac Add missing <algorithm> header to jit_compiler.cpp (#2460)
Fixes compilation error on Linux where std::find_if is used on line 121
but the <algorithm> header was not included. While this might work on
some platforms due to transitive includes, it's not guaranteed by the
C++ standard.

Resolves issue #2459
2025-08-04 14:00:46 -07:00
Cheng
aaf78f4c6b Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph

* Remove unused destructor
2025-08-02 21:28:57 +09:00
Angelos Katharopoulos
8831064493 Fix arctan2 grads (#2453) 2025-08-01 21:06:04 -07:00
Angelos Katharopoulos
be9bc96da4 [CUDA] Matmul utils initial commit (#2441) 2025-08-01 14:22:25 -07:00
Angelos Katharopoulos
86258f292f [CUDA] Vectorize generated kernels (#2444) 2025-07-31 18:18:57 -07:00
Cheng
b26d88591c [CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
2025-08-01 10:16:06 +09:00
Cheng
86c6a15571 [CUDA] Backward convolution (#2431) 2025-08-01 09:54:05 +09:00
junpeiz
8b25ce62d5 Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export.

* Skip quantization related test for CUDA backend.
2025-07-31 11:06:26 -07:00
Awni Hannun
da5912e4f2 fix custom metal extension (#2446) 2025-07-31 06:25:36 -07:00
Cheng
daafee676f Fix wrong graph key when using concurrent context (#2447) 2025-07-31 06:01:05 -07:00
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
Cheng
2204182bba Make CI faster (#2440) 2025-07-30 02:26:36 -07:00
Cheng
3628e5d497 Use load_vector in arg_reduce (#2439) 2025-07-30 17:40:26 +09:00
198 changed files with 8239 additions and 1876 deletions

View File

@@ -81,23 +81,24 @@ jobs:
             export DEBIAN_FRONTEND=noninteractive
             export NEEDRESTART_MODE=a
             sudo apt-get update
-            sudo apt-get upgrade -y
-            pip install --upgrade cmake
             sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
             sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            curl -LsSf https://astral.sh/uv/install.sh | sh
       - run:
           name: Install Python package
           command: |
-            pip install -e ".[dev]"
+            uv venv
+            uv pip install cmake
+            uv pip install -e ".[dev]" -v
       - run:
           name: Generate package stubs
           command: |
-            echo "stubs"
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
+            source .venv/bin/activate
             python -m unittest discover python/tests -v
             mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
             mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
@@ -105,6 +106,7 @@ jobs:
       - run:
           name: Build CPP only
           command: |
+            source .venv/bin/activate
             mkdir -p build && cd build
             cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
             make -j `nproc`
@@ -130,33 +132,30 @@ jobs:
       - run:
           name: Install dependencies
           command: |
-            brew install python@3.9
-            brew install openmpi
-            python3.9 -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
-            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install numpy
-            pip install torch
-            pip install tensorflow
-            pip install unittest-xml-reporting
+            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
+              brew install openmpi uv
       - run:
           name: Install Python package
           command: |
-            source env/bin/activate
+            uv venv --python 3.9
+            uv pip install \
+              nanobind==2.4.0 \
+              cmake \
+              numpy \
+              torch \
+              tensorflow \
+              unittest-xml-reporting
             DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-            pip install -e . -v
+              uv pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
-            source env/bin/activate
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
       - run:
           name: Run Python tests
           command: |
-            source env/bin/activate
+            source .venv/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
             mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
@@ -165,16 +164,17 @@ jobs:
       - run:
           name: Build example extension
           command: |
-            source env/bin/activate
+            source .venv/bin/activate
             cd examples/extensions
-            pip install -r requirements.txt
-            python setup.py build_ext -j8
+            uv pip install -r requirements.txt
+            uv run --no-project setup.py build_ext --inplace
+            uv run --no-project python test.py
       - store_test_results:
           path: test-results
       - run:
           name: Build CPP only
           command: |
-            source env/bin/activate
+            source .venv/bin/activate
             mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
       - run:
           name: Run CPP tests
@@ -183,7 +183,7 @@ jobs:
       - run:
           name: Build small binary
           command: |
-            source env/bin/activate
+            source .venv/bin/activate
             cd build/
             cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
               -DBUILD_SHARED_LIBS=ON \
@@ -195,12 +195,13 @@ jobs:
       - run:
           name: Run Python tests with JIT
           command: |
-            source env/bin/activate
             CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-            pip install -e . -v
+              uv pip install -e .
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
               METAL_DEBUG_ERROR_MODE=0 \
-              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+              uv run --no-project python -m xmlrunner discover \
+                -v python/tests \
+                -o test-results/gpu_jit

   cuda_build_and_test:
     parameters:
@@ -224,17 +225,17 @@ jobs:
             curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
             sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
             rm -rf ccache-4.11.3-linux-x86_64
+            curl -LsSf https://astral.sh/uv/install.sh | sh
       - run:
           name: Install Python package
           command: |
-            python3 -m venv env
-            source env/bin/activate
+            uv venv
             CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
-            pip install -e ".[dev]" -v
+              uv pip install -e ".[dev]" -v
       - run:
           name: Run Python tests
           command: |
-            source env/bin/activate
+            source .venv/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
             LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
       - run:
@@ -343,14 +344,10 @@ jobs:
             export DEBIAN_FRONTEND=noninteractive
             export NEEDRESTART_MODE=a
             sudo apt-get update
-            sudo apt-get upgrade -y
             TZ=Etc/UTC sudo apt-get -y install tzdata
-            sudo apt-get install -y apt-utils
-            sudo apt-get install -y software-properties-common
             sudo add-apt-repository -y ppa:deadsnakes/ppa
             sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
             sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
-            sudo apt-get install -y build-essential git
             $PYTHON -m venv env
             source env/bin/activate
             pip install --upgrade pip

View File

@@ -43,6 +43,7 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

 # --------------------- Processor tests -------------------------
 message(
@@ -242,12 +243,16 @@ target_include_directories(
 # Do not add mlx_EXPORTS define for shared library.
 set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")

-FetchContent_Declare(
-  fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-  GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(fmt)
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG 10.2.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

 if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -78,13 +78,13 @@ pip install mlx
 To install the CUDA backend on Linux, run:

 ```bash
-pip install "mlx[cuda]"
+pip install mlx[cuda]
 ```

 To install a CPU-only Linux package, run:

 ```bash
-pip install "mlx[cpu]"
+pip install mlx[cpu]
 ```

 Checkout the

View File

@@ -1,4 +1,5 @@
 sphinx
 breathe
 sphinx-book-theme
+sphinx-copybutton
 mlx

View File

@@ -18,6 +18,7 @@ release = version
 # -- General configuration ---------------------------------------------------

 extensions = [
+    "sphinx_copybutton",
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",

View File

@@ -128,6 +128,7 @@ relying on a copy from ``ensure_row_contiguous``:
        input_names=["inp"],
        output_names=["out"],
        source=source
+       ensure_row_contiguous=False,
    )

    def exp_elementwise(a: mx.array):
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
-           ensure_row_contiguous=False,
        )
        return outputs[0]
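
For reference, a short sketch of how the relocated flag reads end to end, assembled from the fragments above. The kernel name, the Metal source body, and the assumption of a contiguous input are illustrative only and not part of this diff; the point shown is that ``ensure_row_contiguous`` now goes on the ``mx.fast.metal_kernel`` constructor rather than on the call.

import mlx.core as mx

# Illustrative element-wise exp; the body indexes linearly, so it assumes a
# contiguous input even though ensure_row_contiguous=False is set below.
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    ensure_row_contiguous=False,  # constructor argument, per the diff above
)

def exp_elementwise(a: mx.array):
    outputs = kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return outputs[0]

print(mx.allclose(exp_elementwise(mx.ones((4, 16))), mx.exp(mx.ones((4, 16)))))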

View File

@@ -394,14 +394,14 @@ below.
       out.set_data(allocator::malloc(out.nbytes()));

       // Resolve name of kernel
-      std::ostringstream kname;
-      kname << "axpby_" << "general_" << type_to_name(out);
+      std::string kname;
+      kname = "axpby_general_" + type_to_name(out);

       // Load the metal library
-      auto lib = d.get_library("mlx_ext");
+      auto lib = d.get_library("mlx_ext", current_binary_dir());

       // Make a kernel from this metal library
-      auto kernel = d.get_kernel(kname.str(), lib);
+      auto kernel = d.get_kernel(kname, lib);

       // Prepare to encode kernel
       auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
+   python/cuda
    python/memory_management
    python/nn
    python/optimizers

View File

@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
 .. code-block:: shell

-   pip install "mlx[cuda]"
+   pip install mlx[cuda]

 To install the CUDA package from PyPi your system must meet the following
 requirements:
@@ -49,7 +49,7 @@ For a CPU-only version of MLX that runs on Linux use:
 .. code-block:: shell

-   pip install "mlx[cpu]"
+   pip install mlx[cpu]

 To install the CPU-only package from PyPi your system must meet the following
 requirements:

docs/src/python/cuda.rst (new file, 9 lines)
View File

@@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available
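
A small usage sketch for the new module documented above. It assumes the module is exposed as ``mx.cuda`` (mirroring the existing ``mx.metal``), with ``is_available`` being the only symbol listed in the autosummary.

import mlx.core as mx

# Fall back to the CPU when no CUDA device is present.
if mx.cuda.is_available():
    mx.set_default_device(mx.gpu)
else:
    mx.set_default_device(mx.cpu)
print(mx.default_device())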

View File

@@ -13,3 +13,4 @@ Fast
    rope
    scaled_dot_product_attention
    metal_kernel
+   cuda_kernel

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
     optimizer.update(model, grads)

     # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
+    state = tree_flatten(optimizer.state, destination={})
+    mx.save_safetensors("optimizer.safetensors", state)

     # Later on, for example when loading from a checkpoint,
     # recreate the optimizer and load the state
     optimizer = optim.Adam(learning_rate=1e-2)
-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+    state = tree_unflatten(mx.load("optimizer.safetensors"))
     optimizer.state = state

 Note, not every optimizer configuration parameter is saved in the state. For
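
Pulling those fragments together, a minimal end-to-end sketch. The ``nn.Linear`` layer, dummy loss, and file name are illustrative; the ``tree_flatten``/``tree_unflatten`` calls with a dict follow the updated snippet above.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten

# Hypothetical model and a single training step, just to populate the state.
model = nn.Linear(4, 4)
optimizer = optim.Adam(learning_rate=1e-2)
loss_and_grad = nn.value_and_grad(model, lambda m, x: m(x).sum())
_, grads = loss_and_grad(model, mx.zeros((1, 4)))
optimizer.update(model, grads)

# With destination={}, tree_flatten returns a flat dict directly, so the
# dict(...) conversion from the old snippet is no longer needed.
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)

# mx.load returns a dict, which tree_unflatten now accepts as-is.
optimizer = optim.Adam(learning_rate=1e-2)
optimizer.state = tree_unflatten(mx.load("optimizer.safetensors"))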

View File

@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
     def fun(x, y):
         z = x + y
         state.append(z)
-        return mx.exp(z), state
+        return mx.exp(z)

     fun(mx.array(1.0), mx.array(2.0))
     # Prints [array(3, dtype=float32)]
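
For context, the surrounding docs capture the list through compile itself rather than through the return value; a sketch of that pattern follows. The ``functools.partial`` decoration with the compile ``outputs`` argument is assumed from the rest of that page; only the removal of the ``, state`` return comes from the diff above.

from functools import partial
import mlx.core as mx

state = []

# The list is registered as a compile output, so fun no longer returns it.
@partial(mx.compile, outputs=state)
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)

fun(mx.array(1.0), mx.array(2.0))
print(state)  # Prints [array(3, dtype=float32)]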

View File

@@ -7,17 +7,17 @@ Exporting Functions
 MLX has an API to export and import functions to and from a file. This lets you
 run computations written in one MLX front-end (e.g. Python) in another MLX
 front-end (e.g. C++).

 This guide walks through the basics of the MLX export API with some examples.
 To see the full list of functions check-out the :ref:`API documentation
 <export>`.

 Basics of Exporting
 -------------------

 Let's start with a simple example:

 .. code-block:: python

    def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
    x = mx.array(1.0)
    y = mx.array(1.0)

    # Both arguments to fun are positional
    mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
    For enclosed arrays inside an exported function, be extra careful to ensure
    they are evaluated. The computation graph that gets exported will include
    the computation that produces enclosed inputs.

    If the above example was missing ``mx.eval(model.parameters())``, the
    exported function would include the random initialization of the
    :obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
        # Set the model's parameters to the input parameters
        model.update(tree_unflatten(list(params.items())))
        return model(x)

-   params = dict(tree_flatten(model.parameters()))
+   params = tree_flatten(model.parameters(), destination={})
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
    # Ok
    out, = imported_abs(mx.array(-1.0))

    # Also ok
    out, = imported_abs(mx.array([-1.0, -2.0]))

 With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
    def fun(x, y=None):
        constant = mx.array(3.0)
        if y is not None:
            x += y
        return x + constant

    with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
    print(out)

 In the above example the function constant data, (i.e. ``constant``), is only
 saved once.

 Transformations with Imported Functions
 ---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
    # Prints: array(1, dtype=float32)
    print(dfdx(x))

    # Compile the imported function
    compiled_fun = mx.compile(imported_fun)

    # Prints: array(0, dtype=float32)
    print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
    // Prints: array(2, dtype=float32)
    std::cout << outputs[0] << std::endl;

 Imported functions can be transformed in C++ just like in Python. Use
 ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
 mx::array>`` for keyword arguments when calling imported functions in C++.
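
A compact export/import round trip assembled from the snippets above; the ``add.mlxfn`` file name, the ``fun`` definition, and the inputs are taken from this page, and imported functions return a tuple of outputs.

import mlx.core as mx

def fun(x, y):
    return x + y

x = mx.array(1.0)
y = mx.array(1.0)

mx.export_function("add.mlxfn", fun, x, y)   # both arguments are positional
imported_fun = mx.import_function("add.mlxfn")
(out,) = imported_fun(x, y)
print(out)  # array(2, dtype=float32)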

View File

@@ -1,5 +1,6 @@
 // Copyright © 2023-2025 Apple Inc.

+#include <dlfcn.h>
 #include <iostream>
 #include <sstream>
@@ -16,6 +17,19 @@
 namespace my_ext {

+// A helper function to find the location of the current binary on disk.
+// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
 ///////////////////////////////////////////////////////////////////////////////
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
   }

   // Resolve name of kernel (corresponds to axpby.metal)
-  std::ostringstream kname;
-  kname << "axpby_";
-  kname << (contiguous_kernel ? "contiguous_" : "general_");
-  kname << type_to_name(out);
+  std::string kname = "axpby_";
+  kname += (contiguous_kernel ? "contiguous_" : "general_");
+  kname += type_to_name(out);

   // Load the metal library
-  auto lib = d.get_library("mlx_ext");
+  auto lib = d.get_library("mlx_ext", current_binary_dir());

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), lib);
+  auto kernel = d.get_kernel(kname, lib);

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.2.0
+nanobind==2.4.0

View File

@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
 a = mx.ones((3, 4))
 b = mx.ones((3, 4))
-c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)

-print(f"c shape: {c.shape}")
-print(f"c dtype: {c.dtype}")
-print(f"c correct: {mx.all(c == 6.0).item()}")
+print(f"c shape: {c_cpu.shape}")
+print(f"c dtype: {c_cpu.dtype}")
+print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
+print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

View File

@@ -10,6 +10,7 @@
 #include "mlx/allocator.h"
 #include "mlx/dtype.h"
 #include "mlx/event.h"
+#include "mlx/small_vector.h"

 namespace mlx::core {
@@ -18,8 +19,8 @@ class Primitive;
 using Deleter = std::function<void(allocator::Buffer)>;
 using ShapeElem = int32_t;
-using Shape = std::vector<ShapeElem>;
-using Strides = std::vector<int64_t>;
+using Shape = SmallVector<ShapeElem>;
+using Strides = SmallVector<int64_t>;

 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc

View File

@@ -197,7 +197,7 @@ void shared_buffer_reshape(
     array& out);

 template <typename T>
-inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
+inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
   vec.erase(std::next(vec.begin(), index));
   return vec;
 }

View File

@@ -157,10 +157,12 @@ inline void build_kernel(
 #endif

   // Start the kernel
-  os << "void " << kernel_name << "(void** args) {" << std::endl;
+  os << "void " << kernel_name
+     << "(int* shape, int64_t** strides, void** args) {" << std::endl;

   // Add the input arguments
   int cnt = 0;
+  int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
     if (is_constant(i)) {
@@ -175,8 +177,8 @@ inline void build_kernel(
        << "];" << std::endl;

     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
-      os << "  const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
-         << "];" << std::endl;
+      os << "  const int64_t* " << xname << "_strides = strides["
+         << strides_index++ << "];" << std::endl;
     }
   }
@@ -186,10 +188,8 @@ inline void build_kernel(
     os << "  " << tstr << "* " << namer.get_name(x) << " = (" << tstr
        << "*)args[" << cnt++ << "];" << std::endl;
   }

-  // Add output strides and shape to extract the indices.
-  if (!contiguous) {
-    os << "  const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
-  } else {
+  // Add output size
+  if (contiguous) {
     os << "  const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
   }
@@ -290,7 +290,6 @@ void Compiled::eval_cpu(
   // Collect function input arguments.
   std::vector<void*> args;
-  int strides_index = 1;
   for (size_t i = 0; i < inputs.size(); ++i) {
     if (is_constant_(i)) {
       continue;
@@ -298,9 +297,6 @@ void Compiled::eval_cpu(
     const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-    if (!contiguous && !is_scalar(x)) {
-      args.push_back(strides[strides_index++].data());
-    }
   }

   // Get the kernel name from the lib
@@ -335,16 +331,20 @@ void Compiled::eval_cpu(
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  if (!contiguous) {
-    args.push_back((void*)shape.data());
-  } else {
+  if (contiguous) {
     args.push_back((void*)outputs[0].data_size());
   }
-  auto fun = (void (*)(void**))fn_ptr;
+  auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
   encoder.dispatch([fun,
                     args = std::move(args),
                     strides = std::move(strides),
-                    shape = std::move(shape)]() mutable { fun(args.data()); });
+                    shape = std::move(shape)]() mutable {
+    SmallVector<int64_t*> strides_ptrs;
+    for (auto& s : strides) {
+      strides_ptrs.push_back(s.data());
+    }
+    fun(shape.data(), strides_ptrs.data(), args.data());
+  });
 }

 } // namespace mlx::core

View File

@@ -2,6 +2,7 @@
 #include "mlx/backend/cpu/jit_compiler.h"

+#include <algorithm>
 #include <sstream>
 #include <vector>

View File

@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
 INSTANTIATE_LAPACK_REAL(syevd)
 INSTANTIATE_LAPACK_REAL(geev)
 INSTANTIATE_LAPACK_REAL(potrf)
-INSTANTIATE_LAPACK_REAL(gesvdx)
+INSTANTIATE_LAPACK_REAL(gesdd)
 INSTANTIATE_LAPACK_REAL(getrf)
 INSTANTIATE_LAPACK_REAL(getri)
 INSTANTIATE_LAPACK_REAL(trtri)
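
The Python-facing entry point for this code path is mx.linalg.svd on a CPU stream; a quick sketch of checking the factorization it returns. The shapes assume the full U and Vᵀ that the SVD docs describe, and the tolerance and random input are illustrative.

import mlx.core as mx

a = mx.random.normal((6, 4))
U, S, Vt = mx.linalg.svd(a, stream=mx.cpu)   # U: (6, 6), S: (4,), Vt: (4, 4)
reconstructed = (U[:, :4] * S) @ Vt
print(mx.allclose(reconstructed, a, atol=1e-5))  # expected: True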

View File

@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       switch (in.dtype()) {
         case bool_:
         case uint8:
+          reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
+          break;
         case int8:
           reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
           break;
         case int16:
-        case uint16:
           reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
           break;
         case int32:
-        case uint32:
           reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
           break;
         case int64:
-        case uint64:
           reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
           break;
         case float16:

View File

@@ -8,7 +8,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
+#include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"

 namespace mlx::core {
@@ -333,47 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

   auto& in = inputs[0];

+  int axis = axis_;
+  if (axis < 0) {
+    axis += in.ndim();
+  }
+
   // Copy input to output
-  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+  CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
       ? CopyType::Vector
       : CopyType::General;
   copy_cpu(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);
-  encoder.dispatch(
-      [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
-        switch (out.dtype()) {
-          case bool_:
-            return sort<bool>(out, axis_);
-          case uint8:
-            return sort<uint8_t>(out, axis_);
-          case uint16:
-            return sort<uint16_t>(out, axis_);
-          case uint32:
-            return sort<uint32_t>(out, axis_);
-          case uint64:
-            return sort<uint64_t>(out, axis_);
-          case int8:
-            return sort<int8_t>(out, axis_);
-          case int16:
-            return sort<int16_t>(out, axis_);
-          case int32:
-            return sort<int32_t>(out, axis_);
-          case int64:
-            return sort<int64_t>(out, axis_);
-          case float32:
-            return sort<float>(out, axis_);
-          case float64:
-            return sort<double>(out, axis_);
-          case float16:
-            return sort<float16_t>(out, axis_);
-          case bfloat16:
-            return sort<bfloat16_t>(out, axis_);
-          case complex64:
-            return sort<complex64_t>(out, axis_);
-        }
-      });
+  encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
+    dispatch_all_types(out.dtype(), [&](auto type_tag) {
+      sort<MLX_GET_TYPE(type_tag)>(out, axis);
+    });
+  });
 }

 void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
View File

@@ -81,9 +81,7 @@ void svd_impl(
   // Vᵀ of shape N x N. (M x M in lapack).
   const int ldvt = M;

-  auto job_u = (u_ptr) ? "V" : "N";
-  auto job_vt = (u_ptr) ? "V" : "N";
-  static constexpr auto range = "A";
+  auto jobz = (u_ptr) ? "A" : "N";

   // Will contain the number of singular values after the call has returned.
   int ns = 0;
@@ -91,30 +89,20 @@ void svd_impl(
   // Will contain the indices of eigenvectors that failed to converge (not
   // used here but required by lapack).
-  auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
+  auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};

   static const int lwork_query = -1;
-  static const int ignored_int = 0;
-  static const T ignored_float = 0;

   int info;

   // Compute workspace size.
-  gesvdx<T>(
-      /* jobu = */ job_u,
-      /* jobvt = */ job_vt,
-      /* range = */ range,
+  gesdd<T>(
+      /* jobz = */ jobz,
       // M and N are swapped since lapack expects column-major.
       /* m = */ &N,
       /* n = */ &M,
       /* a = */ nullptr,
       /* lda = */ &lda,
-      /* vl = */ &ignored_float,
-      /* vu = */ &ignored_float,
-      /* il = */ &ignored_int,
-      /* iu = */ &ignored_int,
-      /* ns = */ &ns,
       /* s = */ nullptr,
       /* u = */ nullptr,
       /* ldu = */ &ldu,
@@ -136,20 +124,13 @@ void svd_impl(
   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {
-    gesvdx<T>(
-        /* jobu = */ job_u,
-        /* jobvt = */ job_vt,
-        /* range = */ range,
+    gesdd<T>(
+        /* jobz = */ jobz,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ in_ptr + M * N * i,
        /* lda = */ &lda,
-        /* vl = */ &ignored_float,
-        /* vu = */ &ignored_float,
-        /* il = */ &ignored_int,
-        /* iu = */ &ignored_int,
-        /* ns = */ &ns,
        /* s = */ s_ptr + K * i,
        // According to the identity above, lapack will write Vᵀᵀ as U.
        /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -167,13 +148,6 @@ void svd_impl(
       ss << "svd_impl: sgesvdx_ failed with code " << info;
       throw std::runtime_error(ss.str());
     }
-
-    if (ns != K) {
-      std::stringstream ss;
-      ss << "svd_impl: expected " << K << " singular values, but " << ns
-         << " were computed.";
-      throw std::runtime_error(ss.str());
-    }
   }
 });

 encoder.add_temporary(in);

View File

@@ -8,7 +8,6 @@ target_sources(
   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -17,12 +16,18 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cutlass_gemm.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simple_gemm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -39,22 +44,26 @@ target_sources(
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
+
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
 else()
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
 endif()

 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -81,6 +90,9 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

+# Keep ptx around for inspection
+target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--keep>")
+
 # Enable calling host constexpr functions from device. This is needed because
 # the constexpr version of isnan is host only.
 target_compile_options(
@@ -146,7 +158,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 FetchContent_Declare(
   cudnn
   GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
-  GIT_TAG v1.12.1
+  GIT_TAG v1.14.0
   GIT_SHALLOW TRUE
   EXCLUDE_FROM_ALL)
 set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -166,3 +178,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
 # Install CCCL headers for JIT.
 install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
+
+# Fetch and make available cutlass
+FetchContent_Declare(
+  cutlass
+  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
+  GIT_TAG v4.1.0)
+FetchContent_Populate(cutlass)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

View File

@@ -44,8 +44,11 @@ struct ArgMin {
   }

   template <int N>
-  __device__ IndexValPair<T>
-  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+  __device__ IndexValPair<T> reduce_many(
+      IndexValPair<T> best,
+      const AlignedVector<T, N>& vals,
+      uint32_t offset) {
+#pragma unroll
     for (int i = 0; i < N; i++) {
       if (vals[i] < best.val) {
         best.val = vals[i];
@@ -74,8 +77,11 @@ struct ArgMax {
   }

   template <int N>
-  __device__ IndexValPair<T>
-  reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
+  __device__ IndexValPair<T> reduce_many(
+      IndexValPair<T> best,
+      const AlignedVector<T, N>& vals,
+      uint32_t offset) {
+#pragma unroll
     for (int i = 0; i < N; i++) {
       if (vals[i] > best.val) {
         best.val = vals[i];
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
   int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
   int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
+  in += in_idx;

   Op op;
   T init = op.init();
   IndexValPair<T> best{0, init};

   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
-    T vals[N_READS];
     auto tid = r * BLOCK_DIM + block.thread_index().x;
-    cub::LoadDirectBlocked(
-        tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
+    auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
     best = op.reduce_many(best, vals, tid * N_READS);
   }
@@ -166,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       num_blocks,
       block_dim(),
+      0,
       in.data<T>(),
       out.data<uint32_t>(),
       out.size(),

View File

@@ -0,0 +1,21 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -99,39 +99,89 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
   }
 }

-template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
+template <
+    typename Op,
+    typename In,
+    typename Out,
+    typename IdxT,
+    int NDIM,
+    int N_READS>
 __global__ void binary_g_nd(
     const In* a,
     const In* b,
     Out* out,
-    IdxT size,
+    IdxT size_rest,
     const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
     const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
-        index, shape.data(), a_strides.data(), b_strides.data());
-    out[index] = Op{}(a[a_idx], b[b_idx]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+  auto shape_x = shape[NDIM - 1];
+  auto a_stride_x = a_strides[NDIM - 1];
+  auto b_stride_x = b_strides[NDIM - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
+      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
+  auto a_vec =
+      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
+  auto b_vec =
+      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
 }

-template <typename Op, typename In, typename Out, typename IdxT>
+template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void binary_g(
     const In* a,
     const In* b,
     Out* out,
-    IdxT size,
+    IdxT size_rest,
     const __grid_constant__ Shape shape,
     const __grid_constant__ Strides a_strides,
     const __grid_constant__ Strides b_strides,
     int ndim) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    auto [a_idx, b_idx] = elem_to_loc(
-        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
-    out[index] = Op{}(a[a_idx], b[b_idx]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
   }
+  auto shape_x = shape[ndim - 1];
+  auto a_stride_x = a_strides[ndim - 1];
+  auto b_stride_x = b_strides[ndim - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto [a_idx, b_idx] = elem_to_loc(
+      index_rest * shape_x,
+      shape.data(),
+      a_strides.data(),
+      b_strides.data(),
+      ndim);
+  auto a_vec =
+      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
+  auto b_vec =
+      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
 }

 template <typename Op, typename In, typename Out>
@@ -209,37 +259,61 @@ void binary_op_gpu_inplace(
     auto& a_strides = strides[0];
     auto& b_strides = strides[1];
     int ndim = shape.size();
+    int work_per_thread = 1;
+    auto dim0 = ndim > 0 ? shape.back() : 1;
+    auto rest = out.size() / dim0;
+    if (dim0 >= 4) {
+      work_per_thread = 4;
+    }
+    dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+    auto block_dims = get_block_dims(dim0, rest, 1);
+    uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
+    uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto [num_blocks, block_dims] =
-            get_launch_args(out, large());
+        auto kernel = cu::binary_g_nd<
+            Op,
+            InType,
+            OutType,
+            IdxT,
+            dims_constant(),
+            1>;
+        if (work_per_thread == 4) {
+          kernel = cu::binary_g_nd<
+              Op,
+              InType,
+              OutType,
+              IdxT,
+              dims_constant(),
+              4>;
+        }
         encoder.add_kernel_node(
-            cu::binary_g_nd<
-                Op,
-                InType,
-                OutType,
-                IdxT,
-                dims_constant()>,
-            num_blocks,
+            kernel,
+            {num_blocks_x, num_blocks_y},
            block_dims,
+            0,
            a.data<InType>(),
            b.data<InType>(),
            out.data<OutType>(),
-            out.size(),
+            rest,
            const_param<dims_constant()>(shape),
            const_param<dims_constant()>(a_strides),
            const_param<dims_constant()>(b_strides));
       });
     } else {
-      auto [num_blocks, block_dims] = get_launch_args(out, large());
+      auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
+      if (work_per_thread == 4) {
+        kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
+      }
       encoder.add_kernel_node(
-          cu::binary_g<Op, InType, OutType, IdxT>,
-          num_blocks,
+          kernel,
+          {num_blocks_x, num_blocks_y},
          block_dims,
+          0,
          a.data<InType>(),
          b.data<InType>(),
          out.data<OutType>(),
-          out.size(),
+          rest,
          const_param(shape),
          const_param(a_strides),
          const_param(b_strides),
@@ -264,6 +338,7 @@ void binary_op_gpu_inplace(
         kernel,
         num_blocks,
         block_dims,
+        0,
         a.data<InType>(),
         b.data<InType>(),
         out.data<OutType>(),
@@ -301,54 +376,4 @@ void binary_op_gpu(
     binary_op_gpu<cu::func>(inputs, out, name(), s); \
   }

-BINARY_GPU(Add)
-BINARY_GPU(ArcTan2)
-BINARY_GPU(Divide)
-BINARY_GPU(Remainder)
-BINARY_GPU(Greater)
-BINARY_GPU(GreaterEqual)
-BINARY_GPU(Less)
-BINARY_GPU(LessEqual)
-BINARY_GPU(LogicalAnd)
-BINARY_GPU(LogicalOr)
-BINARY_GPU(LogAddExp)
-BINARY_GPU(Maximum)
-BINARY_GPU(Minimum)
-BINARY_GPU(Multiply)
-BINARY_GPU(NotEqual)
-BINARY_GPU(Power)
-BINARY_GPU(Subtract)
-
-void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Equal::eval_gpu");
-  auto& s = out.primitive().stream();
-  if (equal_nan_) {
-    binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
-  } else {
-    binary_op_gpu<cu::Equal>(inputs, out, name(), s);
-  }
-}
-
-void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
-  auto& s = out.primitive().stream();
-  switch (op_) {
-    case BitwiseBinary::And:
-      binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
-      break;
-    case BitwiseBinary::Or:
-      binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
-      break;
-    case BitwiseBinary::Xor:
-      binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
-      break;
-    case BitwiseBinary::LeftShift:
-      binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
-      break;
-    case BitwiseBinary::RightShift:
-      binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
-      break;
-  }
-}
-
 } // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core
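Each binary primitive now lives in its own translation unit like this one, and the whole body is generated by the BINARY_GPU macro. The macro itself is defined in binary.cuh and is not shown in this diff; based on the tail visible in the old binary.cu above and the analogous Equal::eval_gpu, it presumably expands to something close to the following hypothetical sketch.

// Hypothetical expansion of BINARY_GPU; the real definition in
// mlx/backend/cuda/binary/binary.cuh may differ in detail. Only the
// binary_op_gpu call and the closing brace are confirmed by the diff.
#define BINARY_GPU(func)                                               \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    nvtx3::scoped_range r(#func "::eval_gpu");                         \
    auto& s = out.primitive().stream();                                \
    binary_op_gpu<cu::func>(inputs, out, name(), s);                   \
  }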

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Power)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Remainder)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Subtract)
} // namespace mlx::core

View File

@@ -127,45 +127,99 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  }
}

template <
    typename Op,
    typename In,
    typename Out,
    typename IdxT,
    int NDIM,
    int N_READS>
__global__ void binary_two_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }
  auto shape_x = shape[NDIM - 1];
  auto a_stride_x = a_strides[NDIM - 1];
  auto b_stride_x = b_strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }
  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_two_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }
  auto shape_x = shape[ndim - 1];
  auto a_stride_x = a_strides[ndim - 1];
  auto b_stride_x = b_strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      a_strides.data(),
      b_strides.data(),
      ndim);
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }
  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

template <typename Op, typename In, typename Out>
@@ -225,40 +279,64 @@ void binary_two_op_gpu_inplace(
  auto& a_strides = strides[0];
  auto& b_strides = strides[1];
  int ndim = shape.size();
  int work_per_thread = 1;
  auto dim0 = ndim > 0 ? shape.back() : 1;
  auto rest = out_a.size() / dim0;
  if (dim0 >= 4) {
    work_per_thread = 4;
  }
  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
  auto block_dims = get_block_dims(dim0, rest, 1);
  uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
  uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
  if (ndim <= 3) {
    dispatch_1_2_3(ndim, [&](auto dims_constant) {
      auto kernel = cu::binary_two_g_nd<
          Op,
          InType,
          OutType,
          IdxT,
          dims_constant(),
          1>;
      if (work_per_thread == 4) {
        kernel = cu::binary_two_g_nd<
            Op,
            InType,
            OutType,
            IdxT,
            dims_constant(),
            4>;
      }
      encoder.add_kernel_node(
          kernel,
          {num_blocks_x, num_blocks_y},
          block_dims,
          0,
          a.data<InType>(),
          b.data<InType>(),
          out_a.data<OutType>(),
          out_b.data<OutType>(),
          rest,
          const_param<dims_constant()>(shape),
          const_param<dims_constant()>(a_strides),
          const_param<dims_constant()>(b_strides));
    });
  } else {
    auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
    if (work_per_thread == 4) {
      kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
    }
    encoder.add_kernel_node(
        kernel,
        {num_blocks_x, num_blocks_y},
        block_dims,
        0,
        a.data<InType>(),
        b.data<InType>(),
        out_a.data<OutType>(),
        out_b.data<OutType>(),
        rest,
        const_param(shape),
        const_param(a_strides),
        const_param(b_strides),
@@ -287,6 +365,7 @@ void binary_two_op_gpu_inplace(
        kernel,
        num_blocks,
        block_dims,
        0,
        a.data<InType>(),
        b.data<InType>(),
        out_a.data<OutType>(),
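The binary_two kernels differ from the plain binary ones only in that the operator returns two values per element pair (this is the path used by ops such as divmod), which are scattered into two output arrays. Below is a standalone sketch of that contract with a toy Pair/DivMod standing in for the MLX op types; it shows the scalar form of the loop, not the vectorized kernels above.

#include <cuda_runtime.h>
#include <cstdio>

// Minimal two-value result type; the real MLX ops return something indexable
// in the same way (out[0], out[1]).
template <typename T>
struct Pair {
  T vals[2];
  __device__ T operator[](int i) const { return vals[i]; }
};

// Toy two-output operator: quotient and remainder.
struct DivMod {
  __device__ Pair<int> operator()(int a, int b) const {
    return {a / b, a % b};
  }
};

template <typename Op, typename T>
__global__ void binary_two(
    const T* a, const T* b, T* out_a, T* out_b, int size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    auto out = Op{}(a[i], b[i]);
    out_a[i] = out[0];
    out_b[i] = out[1];
  }
}

int main() {
  const int n = 8;
  int *a, *b, *q, *r;
  cudaMallocManaged(&a, n * sizeof(int));
  cudaMallocManaged(&b, n * sizeof(int));
  cudaMallocManaged(&q, n * sizeof(int));
  cudaMallocManaged(&r, n * sizeof(int));
  for (int i = 0; i < n; ++i) {
    a[i] = 10 + i;
    b[i] = 3;
  }
  binary_two<DivMod, int><<<1, 32>>>(a, b, q, r, n);
  cudaDeviceSynchronize();
  printf("%d / %d -> q=%d r=%d\n", a[1], b[1], q[1], r[1]); // 11 / 3 -> 3, 2
  cudaFree(a);
  cudaFree(b);
  cudaFree(q);
  cudaFree(r);
  return 0;
}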

View File

@@ -104,10 +104,41 @@ struct FusedKernelBuilder {
        " }\n";
  }

  // Vectorized read loop
  if (contiguous) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      if (is_scalar(x) || is_constant(i)) {
        continue;
      }
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      os += fmt::format(
          " auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
          xname,
          type);
    }
  }

  // Create some space for the outputs
  for (const auto& x : outputs) {
    const std::string& xname = namer.get_name(x);
    std::string type = dtype_to_cuda_type(x.dtype());
    os += fmt::format(
        " AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
  }

  // Work loop
  if (!contiguous) {
    os +=
        "\n"
        " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
  } else {
    os +=
        "\n"
        " #pragma unroll\n"
        " for (int i = 0; i < work_per_thread; i++) {\n";
  }

  // Read inputs.
  for (size_t i = 0; i < inputs.size(); ++i) {
@@ -122,7 +153,7 @@ struct FusedKernelBuilder {
    } else if (is_scalar(x)) {
      value = fmt::format("{}[0]", xname);
    } else if (contiguous) {
      value = fmt::format("vec_{}[i]", xname);
    } else {
      value = fmt::format("{}[{}_idx]", xname, xname);
    }
@@ -150,25 +181,30 @@ struct FusedKernelBuilder {
  // Write output.
  for (const auto& x : outputs) {
    os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
  }

  // End of work loop
  if (!contiguous) {
    os += "\n";
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      if (is_scalar(x) || is_constant(i)) {
        continue;
      }
      os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
    }
  }
  os += " }\n";

  // Store the output to global memory
  for (const auto& x : outputs) {
    os += fmt::format(
        " store_vector({0} + index, 0, vec_{0}, size - index);\n",
        namer.get_name(x));
  }

  os += "}\n";
  }
};
@@ -192,6 +228,15 @@ void Compiled::eval_gpu(
  nvtx3::scoped_range r("Compiled::eval_gpu");
  auto& s = stream();

  // Determine the work per thread for the vectorized reads/writes. We take it
  // as 16 over the max itemsize for the outputs. Another heuristic could be
  // over the max itemsize of all arrays.
  int max_size = 1;
  for (const auto& x : outputs) {
    max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
  }
  int work_per_thread = 16 / max_size;

  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
    // Build source code.
    cu::FusedKernelBuilder builder{
@@ -205,29 +250,25 @@ void Compiled::eval_gpu(
    builder.os += "\n} // namespace mlx::core::cu\n";

    // Build kernel names.
    std::vector<std::string> kernel_names;
    kernel_names.push_back(fmt::format(
        "mlx::core::cu::{}_contiguous<uint32_t, {}>",
        lib_name(),
        work_per_thread));
    kernel_names.push_back(fmt::format(
        "mlx::core::cu::{}_contiguous<int64_t, {}>",
        lib_name(),
        work_per_thread));
    for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
      for (int i = 1; i <= MAX_NDIM; ++i) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
      }
    }
    return std::make_tuple(
        false, std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
@@ -269,7 +310,6 @@ void Compiled::eval_gpu(
  }

  // Choose work per thread
  if (!contiguous && shape.back() % work_per_thread != 0) {
    work_per_thread = 1;
  }
@@ -295,7 +335,7 @@ void Compiled::eval_gpu(
  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] =
      get_launch_args(outputs[0], large, work_per_thread);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

} // namespace mlx::core
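The fused-kernel path now derives the vector width from the output item size (16 bytes of output per thread) instead of hard-coding 4. A minimal sketch of that calculation, with the function name chosen here for illustration:

#include <algorithm>
#include <cstdio>
#include <vector>

// 16 bytes of output per thread: float32 -> 4 elements, float16 -> 8,
// 8-byte items -> 2, and so on. The widest output wins.
int work_per_thread_for(const std::vector<int>& output_itemsizes) {
  int max_size = 1;
  for (int itemsize : output_itemsizes) {
    max_size = std::max(max_size, itemsize);
  }
  return 16 / max_size;
}

int main() {
  printf("%d\n", work_per_thread_for({4}));    // float32 outputs -> 4
  printf("%d\n", work_per_thread_for({2}));    // float16 outputs -> 8
  printf("%d\n", work_per_thread_for({2, 8})); // mixed outputs -> 2
  return 0;
}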

View File

@@ -1,257 +1,214 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>

#include <cassert>

namespace mlx::core {

namespace {

// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)

struct ConvCacheKey {
  int device_id;
  cudnnDataType_t cudnn_dtype;
  std::array<int, MAX_NDIM> input_shape;
  std::array<int, MAX_NDIM> weight_shape;
  std::array<int, MAX_NDIM> stride;
  std::array<int, MAX_NDIM> padding_lo;
  std::array<int, MAX_NDIM> padding_hi;
  std::array<int, MAX_NDIM> dilation;
  int groups;
  bool flip;
  uint8_t input_alignment;
  uint8_t weight_alignment;
  uint8_t output_alignment;
};

auto& conv_cache() {
  static LRUBytesKeyCache<
      ConvCacheKey,
      std::pair<
          cudnnBackendDescriptorType_t,
          std::optional<cudnn_frontend::ExecutionPlan>>>
      cache(/* capacity */ 128);
  return cache;
}

auto get_conv_op_settings(
    cudnnBackendDescriptorType_t backend_type,
    array& x,
    array& w,
    array& y,
    const std::vector<int>& kernel_strides,
    const std::vector<int>& padding_lo_,
    const std::vector<int>& padding_hi_,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  auto padding_lo = convert_vector<int64_t>(padding_lo_);
  auto padding_hi = convert_vector<int64_t>(padding_hi_);

  if (backend_type == CONV_BACKWARD_INPUT) {
    for (int i = 0; i < padding_lo.size(); ++i) {
      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
      padding_lo[i] = wt_size - padding_lo[i] - 1;
      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
      padding_hi[i] = out_size - in_size + padding_hi[i];
    }
    return std::make_tuple(
        convert_vector<int64_t>(input_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_dilation));
  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    padding_hi = padding_lo;
    return std::make_tuple(
        convert_vector<int64_t>(kernel_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_strides));
  } else {
    return std::make_tuple(
        convert_vector<int64_t>(kernel_strides),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_dilation));
  }
}

std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    Dtype dtype,
    array& x,
    array& w,
    array& y,
    const SmallVector<int64_t>& stride,
    const SmallVector<int64_t>& padding_lo,
    const SmallVector<int64_t>& padding_hi,
    const SmallVector<int64_t>& dilation) {
  try {
    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
        ? CUDNN_DATA_FLOAT
        : dtype_to_cudnn_type(dtype);
    auto conv_desc = cudnn_frontend::ConvDescBuilder()
                         .setDataType(compute_dtype)
                         .setMathMode(CUDNN_CROSS_CORRELATION)
                         .setNDims(stride.size())
                         .setStrides(stride.size(), stride.data())
                         .setPrePadding(padding_lo.size(), padding_lo.data())
                         .setPostPadding(padding_hi.size(), padding_hi.data())
                         .setDilation(dilation.size(), dilation.data())
                         .build();

    auto op = cudnn_frontend::OperationBuilder(backend_type)
                  .setxDesc(build_cudnn_tensor_nchw('x', x))
                  .setwDesc(build_cudnn_tensor_nchw('w', w))
                  .setyDesc(build_cudnn_tensor_nchw('y', y))
                  .setcDesc(conv_desc)
                  .build();

    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
    return cudnn_frontend::OperationGraphBuilder()
        .setHandle(encoder.device().cudnn_handle())
        .setOperationGraph(ops.size(), ops.data())
        .build();
  } catch (cudnn_frontend::cudnnException& error) {
    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
      throw;
    }
    return std::nullopt;
  }
}

// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
    const array& x,
    int groups,
    int group_dim,
    int axis1,
    int axis2,
    Stream s) {
  if (groups == 1) {
    return swapaxes_in_eval(x, axis1, axis2);
  }
  int ndim = x.ndim();
  if (group_dim < 0) {
    group_dim += ndim;
  }
  if (axis1 < 0) {
    axis1 += ndim;
  }
  if (axis2 < 0) {
    axis2 += ndim;
  }
  if (group_dim <= axis1) {
    axis1 += 1;
  }
  if (group_dim <= axis2) {
    axis2 += 1;
  }
  auto shape = x.shape();
  shape.insert(shape.begin() + group_dim, groups);
  shape[group_dim + 1] = shape[group_dim + 1] / groups;
  array x_trans = reshape_in_eval(x, std::move(shape), s);
  x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
  x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
  return x_trans;
}

// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    array in,
    array wt,
    array out,
    int groups,
    Stream s) {
  // Transpose the args depending on the backend type.
  // TODO: Handle groups.
  if (backend_type == CONV_BACKWARD_INPUT) {
    wt = group_transpose(wt, groups, 0, 0, -1, s);
  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    in = group_transpose(in, groups, -1, 0, -1, s);
    wt = swapaxes_in_eval(wt, 0, -1);
    // Create a contiguous array that shares the data with |out|, but with dim
    // C_in and C_out swapped.
    Shape shape(out.shape());
    std::swap(shape.front(), shape.back());
    Strides strides(shape.size(), 1);
    for (int i = shape.size() - 2; i >= 0; --i) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
    array intermediate(std::move(shape), out.dtype(), nullptr, {});
    intermediate.copy_shared_buffer(
        out, std::move(strides), {true, true, false}, out.data_size());
    out = intermediate;
  }

  // cuDNN requires contiguous input.
  // TODO: Handle NCHW format specially.
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
  }
@@ -261,80 +218,201 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
    encoder.add_temporary(wt);
  }

  return {std::move(in), std::move(wt), std::move(out)};
}

// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
    cudnnBackendDescriptorType_t backend_type,
    array& in,
    array& wt,
    array& out) {
  switch (backend_type) {
    case CONV_BACKWARD_INPUT:
      return {out, wt, in};
    case CONV_BACKWARD_WEIGHT:
      return {in, out, wt};
    default:
      return {in, wt, out};
  }
}

// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    array& in,
    array& wt,
    array& intermediate_out,
    array& final_out) {
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(final_out);

  if (backend_type == CONV_BACKWARD_WEIGHT) {
    // Turn |out| into a strided array, which will have C_in and C_out swapped
    // in vjp and the final |grad_weight| will then be contiguous.
    Strides strides = intermediate_out.strides();
    std::swap(strides.front(), strides.back());
    final_out.copy_shared_buffer(
        intermediate_out,
        std::move(strides),
        {false, false, false},
        intermediate_out.data_size());
  }
}

} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
  nvtx3::scoped_range r("Convolution::eval_gpu");
  if (out_.size() == 0) {
    return;
  }

  assert(inputs.size() == 2);
  array in = inputs[0];
  array wt = inputs[1];
  array out = out_;
  out.set_data(allocator::malloc(out.nbytes()));
  Dtype dtype = out.dtype();

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Search cache.
  ConvCacheKey cache_key{
      encoder.device().cuda_device(),
      dtype_to_cudnn_type(dtype),
      vector_key(in.shape()),
      vector_key(wt.shape()),
      vector_key(kernel_strides_),
      vector_key(padding_lo_),
      vector_key(padding_hi_),
      vector_key(kernel_dilation_),
      groups_,
      flip_,
      get_alignment(in),
      get_alignment(wt),
      get_alignment(out)};
  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
    auto& [backend_type, plan] = it->second;
    if (plan) {
      // Run cached plan.
      std::tie(in, wt, out) =
          prepare_args(encoder, backend_type, in, wt, out, groups_, s);
      register_args(encoder, backend_type, in, wt, out, out_);
      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
      if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
        throw std::runtime_error("[conv] Cached plan failed to execute.");
      }
    } else {
      // Run fallback kernel.
      gemm_conv(
          encoder,
          in,
          wt,
          out,
          kernel_strides_,
          padding_lo_,
          kernel_dilation_,
          input_dilation_,
          groups_,
          flip_,
          s);
    }
    return;
  }

  // There is no reliable way to deduce the proper cuDNN backend for the
  // convolution, so we make a best guess and then try.
  SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
  if (flip_) {
    // When weight is flipped, we assume it is backward input convolution.
    try_backends.push_back(CONV_BACKWARD_INPUT);
  } else {
    // Otherwise it could be backward weight convolution or forward convolution,
    // mathematically there is no difference so we have to use heuristics.
    // Empirically backward convolutions have large kernel dimensions, and
    // usually have |in| and |wt| transposed.
    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
        wt.shape(2) > out.shape(2)) {
      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
    } else {
      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
    }
  }

  // Try to build op graph.
  cudnnBackendDescriptorType_t backend_type;
  std::optional<cudnn_frontend::OperationGraph> op_graph;
  for (auto try_backend : try_backends) {
    auto [in_copy, wt_copy, out_copy] =
        prepare_args(encoder, try_backend, in, wt, out, groups_, s);
    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
        try_backend,
        x,
        w,
        y,
        kernel_strides_,
        padding_lo_,
        padding_hi_,
        kernel_dilation_,
        input_dilation_);
    op_graph = build_conv_op_graph(
        encoder,
        try_backend,
        dtype,
        x,
        w,
        y,
        stride,
        padding_lo,
        padding_hi,
        dilation);
    if (op_graph) {
      backend_type = try_backend;
      in = std::move(in_copy);
      wt = std::move(wt_copy);
      out = std::move(out_copy);
      break;
    }
  }

  if (op_graph) {
    // Setup inputs and outputs.
    register_args(encoder, backend_type, in, wt, out, out_);

    // Find a plan for the graph and execute it.
    auto plan = find_cudnn_plan_from_op_graph(
        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
    if (!plan) {
      throw std::runtime_error("[conv] Unable to find an execution plan.");
    }
    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
    if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
      conv_cache().emplace(
          cache_key, std::make_pair(backend_type, std::move(*plan)));
      return;
    }
  }

  // Use fallback kernel for settings not supported by cuDNN.
  gemm_conv(
      encoder,
      in,
      wt,
      out,
      kernel_strides_,
      padding_lo_,
      kernel_dilation_,
      input_dilation_,
      groups_,
      flip_,
      s);
  conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
}

} // namespace mlx::core
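Since forward, backward-input, and backward-weight convolutions all arrive at Convolution::eval_gpu through the same primitive, the backend type has to be guessed before a cuDNN graph is built. A compressed standalone sketch of the ordering logic used above; the enum and function names are chosen here for illustration and merely mirror the checks on flip, contiguity, and kernel versus output size.

#include <cstdio>
#include <vector>

enum class ConvBackend { Forward, BackwardInput, BackwardWeight };

// A flipped weight implies a backward-input conv; otherwise transposed inputs
// with a kernel larger than the output suggest backward-weight, tried before
// the forward backend.
std::vector<ConvBackend> backends_to_try(
    bool flip,
    bool in_row_contiguous,
    bool wt_row_contiguous,
    int wt_spatial,
    int out_spatial) {
  if (flip) {
    return {ConvBackend::BackwardInput};
  }
  if (!in_row_contiguous && !wt_row_contiguous && wt_spatial > out_spatial) {
    return {ConvBackend::BackwardWeight, ConvBackend::Forward};
  }
  return {ConvBackend::Forward, ConvBackend::BackwardWeight};
}

int main() {
  auto order = backends_to_try(false, false, false, 5, 3);
  printf("first guess: %d\n", static_cast<int>(order.front())); // BackwardWeight
  return 0;
}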

View File

@@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
template <int NDIM>
struct ConvParams {
int N; // Batch size
int C; // In channels
int O; // Out channels
int strides[NDIM];
int padding[NDIM];
int kernel_dilation[NDIM];
int input_dilation[NDIM];
int groups;
bool flip;
int in_spatial_dims[NDIM];
int wt_spatial_dims[NDIM];
int out_spatial_dims[NDIM];
int64_t in_strides[NDIM + 2];
ConvParams(
const array& in,
const array& wt,
const array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip)
: N(in.shape(0)),
C(in.shape(-1)),
O(wt.shape(0)),
groups(groups),
flip(flip) {
std::copy_n(strides.begin(), NDIM, this->strides);
std::copy_n(padding.begin(), NDIM, this->padding);
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
}
};
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s);
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s);
inline void gemm_conv(
cu::CommandEncoder& encoder,
array in,
array wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
if (groups == 1) {
gemm_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
flip,
s);
} else {
gemm_grouped_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip,
s);
}
}
} // namespace mlx::core
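ConvParams above only records the shapes; the relation between them is the usual convolution geometry. For reference, a small helper (not part of the header, added here for illustration) computing the output spatial size implied by stride, padding, and the two dilations:

#include <cstdio>

// Output extent of one convolution axis. `in` and `wt` are the input and
// kernel sizes along that axis; padding is applied on both sides.
int conv_out_size(
    int in, int wt, int stride, int padding, int kernel_dilation, int input_dilation) {
  int dilated_in = 1 + input_dilation * (in - 1);
  int dilated_wt = 1 + kernel_dilation * (wt - 1);
  return (dilated_in + 2 * padding - dilated_wt) / stride + 1;
}

int main() {
  // 3x3 kernel, stride 1, "same" padding of 1 on a 32-pixel axis -> 32.
  printf("%d\n", conv_out_size(32, 3, 1, 1, 1, 1));
  // Stride 2 roughly halves the axis -> 16.
  printf("%d\n", conv_out_size(32, 3, 2, 1, 1, 1));
  return 0;
}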

View File

@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_unfold_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array unfold_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = mat_K / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_unfold_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
int mat_N = params.O; // O
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded =
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
wt_reshaped.copy_shared_buffer(
wt,
{1, mat_K},
{false, false, /* col_contiguous */ true},
wt.data_size());
// Single batch.
Shape batch_shape{1};
Strides a_batch_strides{0};
Strides b_batch_strides{0};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
1, // groups
flip);
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core
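The fallback reduces convolution to a single GEMM: the input is unfolded into an (N·H_out·W_out) x (C·H_wt·W_wt) matrix and multiplied against the weights viewed as a (C·H_wt·W_wt) x O matrix. A small sketch of the shape bookkeeping, reusing the mat_M/mat_K/mat_N naming from the file above (the helper itself is illustrative, not part of the source):

#include <cstdio>

struct GemmShape {
  long long M, K, N;
};

// Shapes for the im2col + GEMM formulation of a 2D convolution with NHWC
// input (N, H, W, C), weights (O, H_wt, W_wt, C) and output (N, H_out, W_out, O).
GemmShape gemm_conv_shapes(
    int n, int h_out, int w_out, int c, int h_wt, int w_wt, int o) {
  GemmShape s;
  s.M = 1LL * n * h_out * w_out; // one row per output pixel
  s.K = 1LL * c * h_wt * w_wt;   // one column per weight element
  s.N = o;                       // one output column per out-channel
  return s;
}

int main() {
  auto s = gemm_conv_shapes(8, 32, 32, 64, 3, 3, 128);
  // The unfolded input holds M*K elements, which is the main memory cost of
  // this fallback compared with a direct convolution kernel.
  printf("M=%lld K=%lld N=%lld unfolded elems=%lld\n", s.M, s.K, s.N, s.M * s.K);
  return 0;
}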

View File

@@ -0,0 +1,231 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_grouped_unfold_transpose_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + tid.y * (filter_size / params.C);
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
int wt_stride = 1;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
out += index_wt * wt_stride;
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
wt_stride *= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array grouped_unfold_transpose_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = (mat_K * params.groups) / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_grouped_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int C_per_group = params.C / params.groups;
int O_per_group = params.O / params.groups;
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
int mat_N = O_per_group; // O_per_group
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
array wt_view(
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
// Batch with size of groups.
Shape batch_shape{params.groups};
Strides a_batch_strides{mat_K};
Strides b_batch_strides{mat_N * mat_K};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K * params.groups, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.set_out(
out.dtype(),
false, // out_transposed
mat_M, // out_rows
mat_N, // out_cols
mat_N * params.groups, // out_ld
params.groups, // batch_count
mat_N); // batch_stride
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip);
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core
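Grouped convolution is handled as a batched GEMM with `groups` batches: each group multiplies its own K-wide slice of the unfolded input row against its own K x N block of the weights and writes an N-wide slice of the output row. A sketch of the per-group offsets implied by the batch strides passed to CublasGemm above (the struct and function names are illustrative):

#include <cstdio>

// Per-group offsets for the batched GEMM used by the grouped fallback.
// A row of the unfolded input holds groups*K values, a row of the output
// holds groups*N values, and each group owns a K x N block of the weights.
struct GroupOffsets {
  long long a_offset; // into the unfolded input row
  long long b_offset; // into the reshaped weights
  long long c_offset; // into the output row
};

GroupOffsets offsets_for_group(int g, long long K, long long N) {
  return {g * K, g * K * N, g * N};
}

int main() {
  const long long K = 3 * 3 * 16, N = 32; // C_per_group*H_wt*W_wt, O_per_group
  for (int g = 0; g < 2; ++g) {
    auto o = offsets_for_group(g, K, N);
    printf("group %d: a+%lld, b+%lld, c+%lld\n", g, o.a_offset, o.b_offset, o.c_offset);
  }
  return 0;
}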

View File

@@ -76,6 +76,7 @@ void copy_contiguous(
      kernel,
      num_blocks,
      block_dims,
      0,
      in.data<InType>() + in_offset,
      out.data<OutType>() + out_offset,
      out.data_size());
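Every add_kernel_node call in this change gains a `0` argument between the block dimensions and the kernel parameters. Judging from the custom CUDA kernel work referenced in this comparison, this is presumably the dynamic shared-memory size in bytes (zero for these kernels); treat that reading as an assumption, since the encoder's signature is not part of this diff. A minimal standalone analogue of such a wrapper:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void fill(float* out, float value, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = value;
  }
}

// Hypothetical analogue of an encoder method taking (grid, block, smem_bytes,
// args...): the extra integer is forwarded as the dynamic shared memory size.
template <typename... Args>
void launch(
    void (*kernel)(Args...), dim3 grid, dim3 block, size_t smem_bytes, Args... args) {
  kernel<<<grid, block, smem_bytes>>>(args...);
}

int main() {
  const int n = 256;
  float* out;
  cudaMallocManaged(&out, n * sizeof(float));
  launch(fill, dim3(1), dim3(256), 0, out, 1.0f, n); // 0 bytes of dynamic smem
  cudaDeviceSynchronize();
  printf("%f\n", out[0]);
  cudaFree(out);
  return 0;
}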

View File

@@ -10,37 +10,80 @@ namespace cu {
namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_gg_nd(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }
  auto shape_x = shape[NDIM - 1];
  auto in_stride_x = strides_in[NDIM - 1];
  auto out_stride_x = strides_out[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
      index_rest * shape_x,
      shape.data(),
      strides_in.data(),
      strides_out.data());
  auto in_vec =
      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
  AlignedVector<Out, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
  }
  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}

template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_gg(
    const In* in,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }
  auto shape_x = shape[ndim - 1];
  auto in_stride_x = strides_in[ndim - 1];
  auto out_stride_x = strides_out[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [idx_in, idx_out] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      strides_in.data(),
      strides_out.data(),
      ndim);
  auto in_vec =
      load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
  AlignedVector<Out, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
  }
  store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}

} // namespace cu

@@ -69,31 +112,52 @@ void copy_general(
  size_t data_size = 1;
  for (auto& s : shape)
    data_size *= s;

  int work_per_thread = 1;
  auto dim0 = ndim > 0 ? shape.back() : 1;
  auto rest = data_size / dim0;
  if (dim0 >= 4) {
    work_per_thread = 4;
  }
  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
  auto block_dims = get_block_dims(dim0, rest, 1);
  uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
  uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);

  if (ndim <= 3) {
    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
      auto kernel =
          cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
      if (work_per_thread == 4) {
        kernel =
            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
      }
      encoder.add_kernel_node(
          kernel,
          {num_blocks_x, num_blocks_y},
          block_dims,
          0,
          in_ptr,
          out_ptr,
          rest,
          const_param<ndim_constant()>(shape),
          const_param<ndim_constant()>(strides_in),
          const_param<ndim_constant()>(strides_out));
    });
  } else { // ndim >= 4
    auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
    if (work_per_thread == 4) {
      kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
    }
    encoder.add_kernel_node(
        kernel,
        {num_blocks_x, num_blocks_y},
        block_dims,
        0,
        in_ptr,
        out_ptr,
        rest,
        const_param(shape),
        const_param(strides_in),
        const_param(strides_out),

View File

@@ -83,6 +83,7 @@ void copy_general_dynamic(
            dims_constant()>,
        num_blocks,
        block_dims,
        0,
        in_ptr,
        out_ptr,
        out.size(),
@@ -98,6 +99,7 @@ void copy_general_dynamic(
        cu::copy_gg_dynamic<InType, OutType, IdxT>,
        num_blocks,
        block_dims,
        0,
        in_ptr,
        out_ptr,
        out.size(),

View File

@@ -10,33 +10,67 @@ namespace cu {
namespace cg = cooperative_groups;

-template <typename In, typename Out, typename IdxT, int NDIM>
+template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_g_nd(
    const In* in,
    Out* out,
-    IdxT size,
+    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
-    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
-    out[index] = CastOp<In, Out>{}(in[idx_in]);
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
  }
+  auto shape_x = shape[NDIM - 1];
+  auto stride_x = strides[NDIM - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto idx =
+      elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
+  auto in_vec =
+      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

-template <typename In, typename Out, typename IdxT>
+template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_g(
    const In* in,
    Out* out,
-    IdxT size,
+    IdxT size_rest,
    const __grid_constant__ Shape shape,
-    const __grid_constant__ Strides strides_in,
+    const __grid_constant__ Strides strides,
    int ndim) {
-  IdxT index = cg::this_grid().thread_rank();
-  if (index < size) {
-    IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
-    out[index] = CastOp<In, Out>{}(in[idx_in]);
+  auto block = cg::this_thread_block();
+  auto grid = cg::this_grid();
+  IdxT index_rest =
+      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
+  if (index_rest >= size_rest) {
+    return;
  }
+  auto shape_x = shape[ndim - 1];
+  auto stride_x = strides[ndim - 1];
+  IdxT index_x =
+      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
+  auto idx =
+      elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
+  auto in_vec =
+      load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
+  AlignedVector<Out, N_READS> out_vec;
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
+  }
+  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

} // namespace cu
@@ -61,28 +95,49 @@ void copy_general_input(
  const InType* in_ptr = in.data<InType>() + offset_in;
  OutType* out_ptr = out.data<OutType>() + offset_out;
  int ndim = shape.size();
+  int work_per_thread = 1;
+  auto dim0 = ndim > 0 ? shape.back() : 1;
+  auto rest = out.size() / dim0;
+  if (dim0 >= 4) {
+    work_per_thread = 4;
+  }
+  dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+  auto block_dims = get_block_dims(dim0, rest, 1);
+  uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
+  uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
  if (ndim <= 3) {
    dispatch_1_2_3(ndim, [&](auto dims_constant) {
-      auto [num_blocks, block_dims] = get_launch_args(out, large());
+      auto kernel =
+          cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
+      if (work_per_thread == 4) {
+        kernel =
+            cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
+      }
      encoder.add_kernel_node(
-          cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
-          num_blocks,
+          kernel,
+          {num_blocks_x, num_blocks_y},
          block_dims,
+          0,
          in_ptr,
          out_ptr,
-          out.size(),
+          rest,
          const_param<dims_constant()>(shape),
          const_param<dims_constant()>(strides_in));
    });
  } else { // ndim >= 4
-    auto [num_blocks, block_dims] = get_launch_args(out, large());
+    auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
+    if (work_per_thread == 4) {
+      kernel = cu::copy_g<InType, OutType, IdxT, 4>;
+    }
    encoder.add_kernel_node(
-        cu::copy_g<InType, OutType, IdxT>,
-        num_blocks,
+        kernel,
+        {num_blocks_x, num_blocks_y},
        block_dims,
+        0,
        in_ptr,
        out_ptr,
-        out.size(),
+        rest,
        const_param(shape),
        const_param(strides_in),
        ndim);
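As a side note, here is a minimal standalone sketch of the launch arithmetic used above (not MLX code): the innermost dimension is split over the x axis with up to four elements per thread, and the product of the remaining dimensions over the y axis. The get_block_dims stand-in below is an assumption, a simple 256-thread splitter, not the real MLX helper.

#include <algorithm>
#include <cstdint>
#include <cstdio>

struct Dim3 {
  uint32_t x, y, z;
};

// Hypothetical stand-in for MLX's get_block_dims: up to 32 threads along the
// innermost axis, the rest of a 256-thread block along y.
static Dim3 get_block_dims(uint32_t dim0, uint32_t rest, uint32_t /*dim2*/) {
  uint32_t bx = std::min<uint32_t>(32, std::max<uint32_t>(dim0, 1));
  uint32_t by = std::min<uint32_t>(256 / bx, std::max<uint32_t>(rest, 1));
  return {bx, by, 1};
}

int main() {
  uint32_t shape_back = 1000, size = 64 * 1000; // e.g. copying a (64, 1000) array
  int work_per_thread = shape_back >= 4 ? 4 : 1; // 4-wide vectorized accesses
  uint32_t dim0 = (shape_back + work_per_thread - 1) / work_per_thread;
  uint32_t rest = size / shape_back;
  Dim3 block = get_block_dims(dim0, rest, 1);
  uint32_t num_blocks_x = (dim0 + block.x - 1) / block.x;
  uint32_t num_blocks_y = (rest + block.y - 1) / block.y;
  std::printf("grid = (%u, %u), block = (%u, %u)\n",
              num_blocks_x, num_blocks_y, block.x, block.y);
}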


@@ -0,0 +1,272 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// strides[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plan
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
encoder.add_temporary(workspace);
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core
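A small self-contained illustration of the singleton-stride normalization above, using plain std::vector instead of MLX's array and SmallVector types; the shape and the junk stride are made-up values. For a row-contiguous (1, 8, 16) tensor, MLX may carry an arbitrary stride on the leading singleton dim, while cuDNN expects strides[i] == shape[i + 1] * strides[i + 1].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {1, 8, 16};
  std::vector<int64_t> strides = {999, 16, 1}; // singleton dim carries a junk stride
  // Same loop as normalized_strides: rewrite strides of size-1 dims so the
  // tensor looks contiguous to cuDNN.
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  std::printf("normalized strides: %lld %lld %lld\n",
              static_cast<long long>(strides[0]),
              static_cast<long long>(strides[1]),
              static_cast<long long>(strides[2])); // prints: 128 16 1
}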


@@ -0,0 +1,164 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
//
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
} // namespace mlx::core
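A standalone restatement of the get_alignment loop from this header, evaluated for a few hypothetical addresses; it reports the largest power of two, capped at 32 bytes, that divides the pointer value, which is what the cuDNN tensor descriptors above are built with.

#include <cstdint>
#include <cstdio>

static uint8_t alignment_of(uintptr_t address) {
  uint8_t alignment = 1;
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

int main() {
  std::printf("%d %d %d\n",
              alignment_of(0x1000),  // 32 (capped)
              alignment_of(0x1008),  // 8
              alignment_of(0x1001)); // 1
}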


@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast
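A hedged usage sketch of the fast::cuda_kernel entry point defined above. The call signatures follow this file; the kernel body, the shapes, and the explicit trailing arguments (template args, init value, verbose flag, stream) are illustrative assumptions, and CustomKernelFunction is assumed to accept the same parameter list as the lambda it wraps.

#include <optional>
#include <string>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Element-wise add; the generated source is placed in namespace
  // mlx::core::cu with cooperative_groups (cg) already available.
  std::string source = R"(
    auto idx = cg::this_grid().thread_rank();
    out[idx] = a[idx] + b[idx];
  )";
  auto kernel = mx::fast::cuda_kernel(
      "my_add",
      /* input_names = */ {"a", "b"},
      /* output_names = */ {"out"},
      source,
      /* header = */ "",
      /* ensure_row_contiguous = */ true,
      /* shared_memory = */ 0);

  auto a = mx::ones({1024});
  auto b = mx::ones({1024});
  auto outs = kernel(
      {a, b},
      /* output_shapes = */ {{1024}},
      /* output_dtypes = */ {mx::float32},
      /* grid = */ {1024, 1, 1},
      /* threadgroup = */ {256, 1, 1},
      /* template_args = */ {},
      /* init_value = */ std::nullopt,
      /* verbose = */ false,
      /* stream = */ {});
  mx::eval(outs);
}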


@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
}

CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
  if (discard) {
    return;
  }
@@ -184,20 +182,11 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
  }
}

-CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
-
-void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
-  for (auto& [_, graph_exec] : graphs) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-  }
-  graphs.clear();
-}
-
-CommandEncoder::~CommandEncoder() {
-  clear_graphs(graph_cache_);
-}
+CommandEncoder::CommandEncoder(Device& d)
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}

void CommandEncoder::add_completed_handler(std::function<void()> task) {
  worker_.add_task(std::move(task));
@@ -224,12 +213,14 @@ void CommandEncoder::add_kernel_node(
    void* func,
    dim3 grid_dim,
    dim3 block_dim,
+    uint32_t smem_bytes,
    void** params) {
  cudaKernelNodeParams kernel_params = {0};
  kernel_params.func = func;
  kernel_params.gridDim = grid_dim;
  kernel_params.blockDim = block_dim;
  kernel_params.kernelParams = params;
+  kernel_params.sharedMemBytes = smem_bytes;
  add_kernel_node(kernel_params);
}
@@ -237,6 +228,7 @@ void CommandEncoder::add_kernel_node(
    CUfunction func,
    dim3 grid_dim,
    dim3 block_dim,
+    uint32_t smem_bytes,
    void** params) {
  CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
  kernel_params.func = func;
@@ -247,6 +239,7 @@ void CommandEncoder::add_kernel_node(
  kernel_params.blockDimY = block_dim.y;
  kernel_params.blockDimZ = block_dim.z;
  kernel_params.kernelParams = params;
+  kernel_params.sharedMemBytes = smem_bytes;
  add_kernel_node(kernel_params);
}
@@ -269,6 +262,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
}

void CommandEncoder::commit() {
+  nvtx3::scoped_range r("CommandEncoder::commit");
  if (!temporaries_.empty()) {
    add_completed_handler([temporaries = std::move(temporaries_)]() {});
  }
@@ -285,7 +279,7 @@ void CommandEncoder::commit() {
  graph_key_ += ".";
  graph_key_ += std::to_string(empty_node_count_);

-  cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+  CudaGraphExec& graph_exec = graph_cache_[graph_key_];

  if (graph_exec != nullptr) {
    cudaGraphExecUpdateResult update_result;
@@ -299,31 +293,24 @@ void CommandEncoder::commit() {
#endif // CUDART_VERSION >= 12000
    if (update_result != cudaGraphExecUpdateSuccess) {
      cudaGetLastError(); // reset error
-      CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-      graph_exec = nullptr;
+      graph_exec.reset();
    }
  }
  if (graph_exec == nullptr) {
-    CHECK_CUDA_ERROR(
-        cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
+    graph_exec.instantiate(graph_);
  }
  device_.make_current();
  CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));

-  // TODO smarter cache policy
-  if (graph_cache_.size() > cuda_graph_cache_size()) {
-    clear_graphs(graph_cache_);
-  }
-
  // Reset state
  node_count_ = 0;
  graph_node_count_ = 0;
+  empty_node_count_ = 0;
  from_nodes_.clear();
  to_nodes_.clear();
  graph_key_.clear();
  node_map_.clear();
-  CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  graph_ = CudaGraph(device_);
}

// Put completion handlers in a batch.

@@ -3,6 +3,7 @@
#pragma once

#include "mlx/array.h"
+#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
@@ -20,7 +21,7 @@ class CommandEncoder {
  struct CaptureContext {
    CaptureContext(CommandEncoder& enc);
    ~CaptureContext();

-    cudaGraph_t graph;
+    CudaGraph graph;
    CommandEncoder& enc;
    bool discard{false};
  };
@@ -31,7 +32,6 @@ class CommandEncoder {
  };

  explicit CommandEncoder(Device& d);
-  ~CommandEncoder();

  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -47,25 +47,34 @@ class CommandEncoder {
  void set_output_array(const array& arr);

  template <typename F, typename... Params>
-  void
-  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
+  void add_kernel_node(
+      F* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      Params&&... params) {
    constexpr size_t num = sizeof...(Params);
    void* ptrs[num];
    size_t i = 0;
    ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
         std::forward<Params>(params)),
     ...);
-    add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
+    add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
  }

  void add_kernel_node(
      CUfunction func,
      dim3 grid_dim,
      dim3 block_dim,
+      uint32_t smem_bytes,
      void** params);

-  void
-  add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
+  void add_kernel_node(
+      void* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      void** params);

  // Low-level graph helpers.
  void add_kernel_node(const cudaKernelNodeParams& params);
@@ -106,7 +115,7 @@ class CommandEncoder {
  Device& device_;
  CudaStream stream_;
-  cudaGraph_t graph_;
+  CudaGraph graph_;
  Worker worker_;
  char node_count_{0};
  char graph_node_count_{0};
@@ -117,7 +126,7 @@ class CommandEncoder {
  std::string graph_key_;
  std::vector<GraphNode> concurrent_nodes_;
  std::vector<std::shared_ptr<array::Data>> temporaries_;
-  std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
+  LRUCache<std::string, CudaGraphExec> graph_cache_;
  std::vector<std::uintptr_t> active_deps_;
  std::vector<std::uintptr_t> active_outputs_;
  std::unordered_map<std::uintptr_t, GraphNode> node_map_;
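For reference, a standalone sketch of the parameter-packing fold used by the variadic add_kernel_node above: each kernel argument's address is collected, in order, into a void* array, which is the layout CUDA graph kernel nodes expect in kernelParams. The printout is only for illustration.

#include <cstddef>
#include <cstdio>
#include <utility>

template <typename... Params>
void pack_and_print(Params&&... params) {
  constexpr size_t num = sizeof...(Params);
  void* ptrs[num];
  size_t i = 0;
  // Same fold expression as in CommandEncoder::add_kernel_node.
  ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
       std::forward<Params>(params)),
   ...);
  for (size_t j = 0; j < num; ++j) {
    std::printf("arg %zu at %p\n", j, ptrs[j]);
  }
}

int main() {
  const float* data = nullptr;
  int n = 4;
  float scale = 0.5f;
  pack_and_print(data, n, scale);
}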


@@ -43,10 +43,18 @@ struct alignas(sizeof(T) * N) AlignedVector {
};

template <int N, typename T>
-inline __device__ bool is_aligned(T* x) {
+inline __host__ __device__ bool is_aligned(T* x) {
  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}

+template <int N, typename T>
+inline __device__ AlignedVector<T, N> unsafe_load_vector(
+    const T* ptr,
+    uint32_t offset) {
+  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
+  return from[offset];
+}
+
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
    const T* ptr,
@@ -101,6 +109,13 @@ inline __device__ AlignedVector<T, N> load_vector(
  }
}

+template <int N, typename T>
+inline __device__ void
+unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
+  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+  to[offset] = vec;
+}
+
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
@@ -131,19 +146,22 @@ inline __device__ void store_vector(
  }
}

-// Helper for accessing strided data.
-template <typename T>
-struct StridedIterator {
-  T it;
-  int64_t stride;
-
-  __host__ __device__ StridedIterator(T it, int64_t stride)
-      : it(it), stride(stride) {}
-
-  __host__ __device__ auto operator[](int i) const {
-    return it[i * stride];
-  }
-};
+template <int N, typename T, typename SizeT>
+inline __device__ void store_vector(
+    T* ptr,
+    uint32_t offset,
+    const AlignedVector<T, N>& vec,
+    SizeT size,
+    int64_t stride) {
+  if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
+    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
+    to[offset] = vec;
+  } else {
+    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
+      ptr[stride * (offset * N + i)] = vec[i];
+    }
+  }
+}

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
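A hedged CUDA sketch of the access pattern these helpers enable, with a stripped-down AlignedVector. The real MLX load_vector/store_vector overloads additionally handle misaligned pointers, tails, and non-unit strides; this assumes aligned data whose length is a multiple of N.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
  __device__ T& operator[](int i) {
    return val[i];
  }
};

template <int N, typename T>
__device__ AlignedVector<T, N> unsafe_load_vector(const T* ptr, uint32_t offset) {
  return reinterpret_cast<const AlignedVector<T, N>*>(ptr)[offset];
}

template <int N, typename T>
__device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
  reinterpret_cast<AlignedVector<T, N>*>(ptr)[offset] = vec;
}

// Each thread loads, scales, and stores N contiguous elements.
__global__ void scale(const float* in, float* out, float alpha, int size) {
  constexpr int N = 4;
  int offset = blockIdx.x * blockDim.x + threadIdx.x;
  if (offset * N >= size) {
    return;
  }
  auto v = unsafe_load_vector<N>(in, offset);
#pragma unroll
  for (int i = 0; i < N; ++i) {
    v[i] *= alpha;
  }
  unsafe_store_vector<N>(out, offset, v);
}

int main() {
  const int size = 1 << 20;
  float *in, *out;
  cudaMalloc(&in, size * sizeof(float));
  cudaMalloc(&out, size * sizeof(float));
  scale<<<size / (4 * 256), 256>>>(in, out, 2.0f, size);
  cudaDeviceSynchronize();
  std::printf("%s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(out);
}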


@@ -36,18 +36,15 @@ void eval(array& arr) {
  auto& encoder = cu::get_command_encoder(arr.primitive().stream());

  // Keep used buffers alive until kernel finishes running.
-  std::unordered_set<std::shared_ptr<array::Data>> buffers;
  for (auto& in : arr.inputs()) {
-    buffers.insert(in.data_shared_ptr());
+    // Except for the donated one.
+    if (in.data_shared_ptr() != arr.data_shared_ptr()) {
+      encoder.add_temporary(in);
+    }
  }
  for (auto& s : arr.siblings()) {
-    buffers.insert(s.data_shared_ptr());
+    encoder.add_temporary(s);
  }
-  // Remove the output if it was donated to by an input.
-  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
-    buffers.erase(it);
-  }
-  encoder.add_completed_handler([buffers = std::move(buffers)]() {});
  encoder.maybe_commit();
}


@@ -1,206 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu
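For reference, a standalone sketch of the elem_to_loc indexing these pointer-setup kernels rely on: a flat batch index is peeled one dimension at a time, innermost first, and dotted with each operand's batch strides, so a broadcast operand (stride 0) keeps reusing the same offset. The shapes and strides below are made up.

#include <cstdint>
#include <cstdio>
#include <vector>

static int64_t elem_to_loc(
    int64_t index,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  std::vector<int32_t> batch_shape = {2, 3};
  std::vector<int64_t> a_strides = {0, 64}; // `a` is broadcast over the first batch dim
  std::vector<int64_t> b_strides = {192, 64};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf("batch %lld: a_offset=%lld b_offset=%lld\n",
                static_cast<long long>(i),
                static_cast<long long>(elem_to_loc(i, batch_shape, a_strides)),
                static_cast<long long>(elem_to_loc(i, batch_shape, b_strides)));
  }
}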


@@ -7,10 +7,12 @@
#include <fmt/format.h>

-namespace mlx::core::cu {
+namespace mlx::core {
+
+namespace {

struct CublasPreference {
-  CublasPreference(Device& device) {
+  CublasPreference(cu::Device& device) {
    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
    // for Hopper+:
    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
  cublasLtMatmulPreference_t pref_{nullptr};
};

-cublasLtMatmulPreference_t cublas_preference(Device& device) {
+cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
  static CublasPreference pref(device);
  return pref.pref_;
}
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
      return CUBLAS_COMPUTE_64F;
    default:
      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
  }
}
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
      return CUDA_C_32F;
    default:
      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
  }
}
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
  return desc;
}

-Matmul::Matmul(
-    Device& device,
+} // namespace
+
+CublasGemm::CublasGemm(
+    cu::Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}

-Matmul::Matmul(
-    Device& device,
+CublasGemm::CublasGemm(
+    cu::Device& device,
    Dtype dtype,
    bool a_transposed,
    uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
    int64_t a_batch_stride,
    int64_t b_batch_stride,
    int64_t c_batch_stride)
-    : Matmul(
+    : CublasGemm(
          device,
          dtype,
          a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}

-Matmul::~Matmul() {
+CublasGemm::~CublasGemm() {
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,92 @@ Matmul::~Matmul() {
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}

-void Matmul::run_impl(
+void CublasGemm::set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
    cu::CommandEncoder& encoder,
    void* out,
    const void* a,
@@ -213,7 +302,7 @@ void Matmul::run_impl(
      matmul_desc_,
      a_desc_,
      b_desc_,
-      out_desc_, // TODO should that be c_desc is it's set?
+      c ? c_desc_ : out_desc_,
      out_desc_,
      pref_,
      1,
@@ -226,8 +315,10 @@ void Matmul::run_impl(
  void* workspace_ptr = nullptr;
  if (heuristic_.workspaceSize > 0) {
+    // Ensure workspace is 256-byte aligned
+    int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
    array workspace(
-        allocator::malloc(heuristic_.workspaceSize),
+        allocator::malloc(nbytes),
        {static_cast<int>(heuristic_.workspaceSize)},
        int8);
    encoder.add_temporary(workspace);
@@ -254,29 +345,4 @@ void Matmul::run_impl(
          encoder.stream()));
}

-void Matmul::run(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const std::optional<array>& c /* = std::nullopt */,
-    float alpha /* = 1 */,
-    float beta /* = 0 */) {
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  if (c) {
-    encoder.set_input_array(*c);
-  }
-  encoder.set_output_array(out);
-  run_impl(
-      encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
-      c ? c->data<void>() : nullptr,
-      alpha,
-      beta);
-}
-} // namespace mlx::core::cu
+} // namespace mlx::core


@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"

#include <cublasLt.h>
-#include <optional>

-namespace mlx::core::cu {
+namespace mlx::core {

-class Matmul {
+class CublasGemm {
 public:
-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
      int64_t a_batch_stride,
      int64_t b_batch_stride);

-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
@@ -42,25 +42,50 @@ class Matmul {
      int64_t b_batch_stride,
      int64_t c_batch_stride);

-  ~Matmul();
+  ~CublasGemm();
+
+  // The output's descriptor is inferred from inputs by default, use this method
+  // for unusual output.
+  void set_out(
+      Dtype dtype,
+      bool transposed,
+      uint64_t rows,
+      uint64_t cols,
+      int64_t ld,
+      int32_t batch_count,
+      int64_t batch_stride);

  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
-      const std::optional<array>& c = std::nullopt,
-      float alpha = 1,
-      float beta = 0);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);
+
+  void run(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const array& c,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
+      float alpha,
+      float beta);
+
+ private:
  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);

  void run_batched(
      cu::CommandEncoder& encoder,
@@ -68,15 +93,14 @@ class Matmul {
      const array& a,
      const array& b,
      const array& c,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides,
-      const mlx::core::Strides& c_batch_strides,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
      float alpha,
      float beta);

- private:
-  void run_impl(
+  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
@@ -97,4 +121,4 @@ class Matmul {
  cublasLtMatmulHeuristicResult_t heuristic_;
};

-} // namespace mlx::core::cu
+} // namespace mlx::core


@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"

-namespace mlx::core::cu {
+namespace mlx::core {

-void Matmul::run_batched(
+void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
  }
}

-void Matmul::run_batched(
+void CublasGemm::run_batched(
    cu::CommandEncoder& encoder,
    array& out,
    const array& a,
    const array& b,
    const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
    float alpha,
    float beta) {
  encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
  auto concurrent = encoder.concurrent_context();
  for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
        encoder,
        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
        a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
  }
}

-} // namespace mlx::core::cu
+} // namespace mlx::core


@@ -0,0 +1,327 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
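// Batched matmul path: switch every matrix layout to pointer-array batching,
// fill the device pointer arrays with a single launch, then issue one
// cublasLt call covering all batch entries.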
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core

View File

@@ -0,0 +1,396 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cute/tensor.hpp>
#include <cutlass/arch/arch.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <iostream>
namespace mlx::core::cu {
namespace {
using namespace cute;
using bf16 = cute::bfloat16_t;
template <typename Kernel>
void configure_matmul(Kernel kernel, int smem_size) {
static bool initialized = false;
if (!initialized) {
initialized = true;
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
}
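// The block tile used below needs more dynamic shared memory than the 48 KB
// default, so the kernel has to opt in explicitly (done once per process).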
template <bool transpose, typename Tiler>
constexpr int get_feature_size(Tiler smem) {
int feature_size = (transpose) ? size<0>(smem) : size<1>(smem);
return (feature_size >= 64) ? 64 : feature_size;
}
constexpr int constexpr_log2(int x) {
return (x > 0) ? 1 + constexpr_log2(x >> 1) : -1;
}
template <int feature_size, int itemsize, int copy_bits>
constexpr int get_swizzle_bits() {
constexpr int swizzle_bits =
constexpr_log2(feature_size * itemsize / copy_bits);
return (swizzle_bits > 3) ? 3 : swizzle_bits;
}
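// Example: a 64-element feature dim of 16-bit bf16 copied 128 bits at a time
// gives log2(64 * 16 / 128) = 3 swizzle bits, which is also the cap.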
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled =
make_composed_layout(Swizzle<swizzle_bits, 3, 3>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <int itemsize, bool transpose, int copy_bits, typename Tiler>
constexpr auto make_result_smem_layout(Tiler smem) {
constexpr int feature_size = get_feature_size<transpose>(smem);
constexpr int swizzle_bits =
get_swizzle_bits<feature_size, itemsize, copy_bits>();
using F = Int<feature_size>;
using BaseLayout = std::conditional_t<
transpose,
Layout<cute::Shape<F, _8>, cute::Stride<_1, F>>,
Layout<cute::Shape<_8, F>, cute::Stride<F, _1>>>;
auto swizzled = make_composed_layout(
Swizzle<transpose ? 0 : swizzle_bits, 3, 4>{}, 0, BaseLayout{});
return tile_to_shape(swizzled, smem);
}
template <
int num_threads,
int itemsize,
bool transpose,
int copy_bits,
typename Copier,
typename Tiler>
constexpr auto make_tiled_copy(Copier copy_op, Tiler smem) {
constexpr int num_elements = copy_bits / itemsize;
constexpr int feature_size = transpose ? size<0>(smem) : size<1>(smem);
constexpr int copies_per_feature = feature_size / num_elements;
using E = Int<num_elements>;
using C = Int<copies_per_feature>;
using R = Int<num_threads / copies_per_feature>;
using ThreadLayout = std::conditional_t<
transpose,
Layout<cute::Shape<C, R>, cute::Stride<_1, C>>,
Layout<cute::Shape<R, C>, cute::Stride<C, _1>>>;
using ValueLayout = std::conditional_t<
transpose,
Layout<cute::Shape<E, _1>>,
Layout<cute::Shape<_1, E>>>;
return make_tiled_copy(copy_op, ThreadLayout{}, ValueLayout{});
}
template <int rasterization_factor>
__device__ inline int2 raster_tile(int x, int y) {
return {
x / rasterization_factor,
(x % rasterization_factor) + y * rasterization_factor};
}
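// Remaps the linear block index so groups of `rasterization_factor`
// consecutive blocks share the same row of output tiles, which should
// improve L2 reuse of the A tiles they load.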
template <
typename T,
typename SLayoutA,
typename SLayoutB,
typename SLayoutC,
typename CopyA,
typename CopyB,
typename CopyC,
typename MMA,
int rasterization_factor>
__global__ static __launch_bounds__(decltype(size(MMA{}))::value) void matmul_kernel(
const T* __restrict__ A,
const T* __restrict__ B,
T* __restrict__ C,
SLayoutA SA,
SLayoutB SB,
SLayoutC SC,
CopyA copy_a,
CopyB copy_b,
CopyC copy_c,
MMA mma,
int M,
int N,
int K) {
constexpr auto BM = size<0>(SA);
constexpr auto BN = size<0>(SB);
constexpr auto BK = size<1>(SA);
constexpr auto PIPE = size<2>(SA);
const int2 tile = raster_tile<rasterization_factor>(blockIdx.x, blockIdx.y);
const int blocks_m = ceil_div(M, BM);
const int blocks_n = ceil_div(N, BN);
// Exit early if the tile is OOB
if (tile.x >= blocks_m || tile.y >= blocks_n) {
return;
}
// Make the full tensors
Tensor full_A =
make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, _1{}));
Tensor full_B =
make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(K, _1{}));
Tensor full_C =
make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, _1{}));
// Partition the tensors into tiles and select the ones for this threadblock
Tensor local_A =
local_tile(full_A, make_shape(BM, BK), make_coord(tile.x, _));
Tensor local_B =
local_tile(full_B, make_shape(BN, BK), make_coord(tile.y, _));
Tensor local_C =
local_tile(full_C, make_shape(BM, BN), make_coord(tile.x, tile.y));
// Make shared memory tensors
extern __shared__ char shared_memory[];
T* shared_A_ptr = reinterpret_cast<T*>(shared_memory);
T* shared_B_ptr =
reinterpret_cast<T*>(shared_memory + cosize(SA) * sizeof(T));
T* shared_C_ptr = reinterpret_cast<T*>(shared_memory);
Tensor shared_A = make_tensor(make_smem_ptr(shared_A_ptr), SA);
Tensor shared_B = make_tensor(make_smem_ptr(shared_B_ptr), SB);
Tensor shared_C = make_tensor(make_smem_ptr(shared_C_ptr), SC);
// Get the copies that correspond to this thread
auto thread_copy_a = copy_a.get_slice(threadIdx.x);
Tensor local_A_src = thread_copy_a.partition_S(local_A);
Tensor local_A_dst = thread_copy_a.partition_D(shared_A);
auto thread_copy_b = copy_b.get_slice(threadIdx.x);
Tensor local_B_src = thread_copy_b.partition_S(local_B);
Tensor local_B_dst = thread_copy_b.partition_D(shared_B);
auto thread_copy_c = copy_c.get_slice(threadIdx.x);
Tensor local_C_src = thread_copy_c.partition_S(shared_C);
Tensor local_C_dst = thread_copy_c.partition_D(local_C);
// Start fetches
int k_tile_count = size<2>(local_A);
int k_tile_next = 0;
CUTE_UNROLL
for (int k = 0; k < PIPE - 1; k++) {
copy(copy_a, local_A_src(_, _, _, k_tile_next), local_A_dst(_, _, _, k));
copy(copy_b, local_B_src(_, _, _, k_tile_next), local_B_dst(_, _, _, k));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
}
// Get the MMA that corresponds to this thread and allocate registers
auto thread_mma = mma.get_slice(threadIdx.x);
Tensor mma_shared_A = thread_mma.partition_A(shared_A);
Tensor mma_shared_B = thread_mma.partition_B(shared_B);
Tensor mma_shared_C = thread_mma.partition_C(shared_C);
Tensor mma_global_C = thread_mma.partition_C(local_C);
Tensor mma_frag_A = mma.make_fragment_A(mma_shared_A(_, _, _, 0));
Tensor mma_frag_B = mma.make_fragment_B(mma_shared_B(_, _, _, 0));
Tensor mma_frag_C = mma.make_fragment_C(mma_global_C);
clear(mma_frag_C);
// Make shared to register copies
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_a;
Copy_Atom<SM75_U32x4_LDSM_N, bf16> s2r_atom_b;
auto s2r_copy_a = make_tiled_copy_A(s2r_atom_a, mma);
auto s2r_thread_copy_a = s2r_copy_a.get_slice(threadIdx.x);
auto s2r_copy_b = make_tiled_copy_B(s2r_atom_b, mma);
auto s2r_thread_copy_b = s2r_copy_b.get_slice(threadIdx.x);
Tensor mma_A_src = s2r_thread_copy_a.partition_S(shared_A);
Tensor mma_A_dst = s2r_thread_copy_a.retile_D(mma_frag_A);
Tensor mma_B_src = s2r_thread_copy_b.partition_S(shared_B);
Tensor mma_B_dst = s2r_thread_copy_b.retile_D(mma_frag_B);
constexpr auto RPIPE = size<2>(mma_shared_A);
int smem_read = 0;
int smem_write = PIPE - 1;
Tensor mma_A_src_p = mma_A_src(_, _, _, smem_read);
Tensor mma_B_src_p = mma_B_src(_, _, _, smem_read);
// Start the register pipeline
if constexpr (RPIPE > 1) {
cp_async_wait<PIPE - 2>();
__syncthreads();
copy(s2r_copy_a, mma_A_src_p(_, _, Int<0>{}), mma_A_dst(_, _, Int<0>{}));
copy(s2r_copy_b, mma_B_src_p(_, _, Int<0>{}), mma_B_dst(_, _, Int<0>{}));
}
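// Main loop: a PIPE-deep gmem->smem cp.async pipeline feeds an RPIPE-deep
// smem->register pipeline, so the next K tile is being fetched while the
// current register fragments are consumed by the MMA.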
CUTE_NO_UNROLL
while (k_tile_count > -(PIPE - 1)) {
CUTE_UNROLL
for (int k_block = 0; k_block < RPIPE; k_block++) {
if (k_block == RPIPE - 1) {
mma_A_src_p = mma_A_src(_, _, _, smem_read);
mma_B_src_p = mma_B_src(_, _, _, smem_read);
cp_async_wait<PIPE - 2>();
__syncthreads();
}
// Load the next register tile
auto k_block_next = (k_block + 1) % RPIPE;
copy(
s2r_copy_a,
mma_A_src_p(_, _, k_block_next),
mma_A_dst(_, _, k_block_next));
copy(
s2r_copy_b,
mma_B_src_p(_, _, k_block_next),
mma_B_dst(_, _, k_block_next));
if (k_block == 0) {
copy(
copy_a,
local_A_src(_, _, _, k_tile_next),
local_A_dst(_, _, _, smem_write));
copy(
copy_b,
local_B_src(_, _, _, k_tile_next),
local_B_dst(_, _, _, smem_write));
cp_async_fence();
k_tile_count--;
k_tile_next += (k_tile_count > 0);
smem_write = smem_read;
smem_read = (smem_read == PIPE - 1) ? 0 : (smem_read + 1);
}
gemm(
mma,
mma_frag_A(_, _, k_block),
mma_frag_B(_, _, k_block),
mma_frag_C);
}
}
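// Epilogue: stage the accumulator through shared memory so the final write
// to global memory can use wide vectorized copies.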
copy(mma_frag_C, mma_shared_C);
__syncthreads();
copy(copy_c, local_C_src, local_C_dst);
// if (threadIdx.x == 0) {
// print("fC: "); print(mma_frag_C); print("\n");
// print("sC: "); print(mma_shared_C); print("\n");
// print("dC: "); print(local_C_dst); print("\n");
//
// print(s2r_atom_a); print("\n");
// }
}
} // namespace
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (std::is_same_v<DataType, __nv_bfloat16>) {
using namespace cute;
// Tile definitions
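// BM/BN/BK is the threadblock tile, BP the shared-memory pipeline depth and
// GM the tile rasterization factor used by raster_tile above.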
auto BM = Int<128>{};
auto BN = Int<128>{};
auto BK = Int<64>{};
auto BP = Int<3>{};
auto GM = Int<8>{};
// Thread definitions
using TM = Int<2>;
using TN = Int<2>;
using TK = Int<1>;
constexpr int num_threads = TM::value * TN::value * 32;
auto SA = make_smem_layout<16, false, 128>(make_shape(BM, BK, BP));
auto SB = make_smem_layout<16, false, 128>(make_shape(BN, BK, BP));
auto SC = make_result_smem_layout<16, false, 128>(make_shape(BM, BN));
constexpr auto smem_size = (cosize(SA) + cosize(SB)) * sizeof(bf16);
auto async_copy_op =
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bf16>{};
auto tiled_copy_a = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BM, BK));
auto tiled_copy_b = make_tiled_copy<num_threads, 16, false, 128>(
async_copy_op, make_shape(BN, BK));
auto sync_copy_op = Copy_Atom<UniversalCopy<uint128_t>, bf16>{};
auto tiled_copy_c = make_tiled_copy<num_threads, 16, false, 128>(
sync_copy_op, make_shape(BM, BN));
auto mma_op = SM80_16x8x16_F32BF16BF16F32_TN{};
auto tiled_mma = make_tiled_mma(
mma_op, Layout<cute::Shape<TM, TN, TK>>{}, Tile<_32, _32, _16>{});
auto kernel = matmul_kernel<
bf16,
decltype(SA),
decltype(SB),
decltype(SC),
decltype(tiled_copy_a),
decltype(tiled_copy_b),
decltype(tiled_copy_c),
decltype(tiled_mma),
GM.value>;
configure_matmul(kernel, smem_size);
dim3 block(size(tiled_mma));
dim3 grid(
size(ceil_div(M, BM) * GM), size(ceil_div(ceil_div(N, BN), GM)));
enc.add_kernel_node(
kernel,
grid,
block,
smem_size,
a.data<bf16>(),
b.data<bf16>(),
out.data<bf16>(),
SA,
SB,
SC,
tiled_copy_a,
tiled_copy_b,
tiled_copy_c,
tiled_mma,
M,
N,
K);
} else {
throw std::runtime_error("Only bfloat16 supported");
}
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void cutlass_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -27,8 +27,9 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f; float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols; for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) { col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0); auto local_mat =
auto local_vec = load_vector<n_per_thread>(vec + col, 0); unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll #pragma unroll
for (int j = 0; j < n_per_thread; ++j) { for (int j = 0; j < n_per_thread; ++j) {
sum += sum +=
@@ -127,9 +128,13 @@ void gemv(
rows = M; rows = M;
} }
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t = 4; int n_per_t;
while (K % (n_per_t * WARP_SIZE) != 0) { if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t >>= 1; n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
} }
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) { if (batch_count == 1) {
@@ -138,6 +143,7 @@ void gemv(
kernel, kernel,
num_blocks_x, num_blocks_x,
block_dims, block_dims,
0,
mat, mat,
vec, vec,
out.data<DataType>(), out.data<DataType>(),
@@ -149,6 +155,7 @@ void gemv(
kernel, kernel,
dim3{num_blocks_x, batch_count}, dim3{num_blocks_x, batch_count},
block_dims, block_dims,
0,
mat, mat,
vec, vec,
out.data<DataType>(), out.data<DataType>(),

View File

@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/steel/gemm.cuh"
#include "mlx/dtype_utils.h"
#include <iostream>
namespace mlx::core::cu {
namespace {
template <typename Kernel>
static void configure_smem(Kernel kernel, int SM) {
static bool done = false;
if (done) {
return;
}
std::cout << "configuring" << std::endl;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SM);
cudaFuncSetAttribute(
kernel,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared);
done = true;
}
} // namespace
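// Test GEMM for row-major A times row-major B^T with block-tile-aligned
// shapes; the caller only dispatches here for aligned sizes (see the
// MLX_ENABLE_TEST_GEMM checks in Matmul::eval_gpu).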
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc) {
enc.set_input_array(a);
enc.set_input_array(b);
enc.set_output_array(out);
dispatch_float_types(a.dtype(), "simple_gemm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int PIPE = 3;
constexpr int SM = PIPE * sizeof(DataType) * (BM * BK + BN * BK);
constexpr int WM = 2;
constexpr int WN = 4;
auto kernel = ab_t_aligned<DataType, BM, BN, BK, WM, WN, PIPE>;
configure_smem(kernel, SM);
dim3 grid(N / BN, M / BM);
enc.add_kernel_node(
kernel,
grid,
WM * WN * WARP_SIZE,
SM,
a.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
N,
K);
});
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,18 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
void simple_gemm(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
cu::CommandEncoder& enc);
}

View File

@@ -29,12 +29,12 @@ void append_indices_arg(
const std::vector<array>& inputs, const std::vector<array>& inputs,
int nidx, int nidx,
int idx_ndim) { int idx_ndim) {
std::vector<const void*> indices(nidx); SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>(); indices[i] = inputs[i + 1].data<void>();
} }
args.append(std::move(indices)); args.append(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim); SmallVector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy_n( std::copy_n(
inputs[i + 1].shape().begin(), inputs[i + 1].shape().begin(),
@@ -42,7 +42,7 @@ void append_indices_arg(
indices_shape.data() + i * idx_ndim); indices_shape.data() + i * idx_ndim);
} }
args.append(std::move(indices_shape)); args.append(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim); SmallVector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy_n( std::copy_n(
inputs[i + 1].strides().begin(), inputs[i + 1].strides().begin(),
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_pair(jit_source_gather, std::move(kernel_names)); return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim()); args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_); args.append_ndim(slice_sizes_);
args.append(slice_size); args.append(slice_size);
args.append(axes_); args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(out, large); auto [num_blocks, block_dims] = get_launch_args(out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t")); large ? "int64_t" : "int32_t"));
} }
} }
return std::make_pair(jit_source_scatter, std::move(kernel_names)); return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
}); });
cu::KernelArgs args; cu::KernelArgs args;
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape()); args.append_ndim(out.shape());
args.append_ndim(out.strides()); args.append_ndim(out.strides());
args.append<int32_t>(out.ndim()); args.append<int32_t>(out.ndim());
args.append(axes_); args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(upd, large); auto [num_blocks, block_dims] = get_launch_args(upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); return std::make_tuple(
false, jit_source_gather_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;
@@ -318,7 +319,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large); auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) { void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
} }
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); return std::make_tuple(
false, jit_source_scatter_axis, std::move(kernel_names));
}); });
size_t idx_size_pre = 1; size_t idx_size_pre = 1;
@@ -422,7 +424,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name); auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large); auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -101,8 +101,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx( bool read_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
std::vector<char>* ptx, std::string& ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) { std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
return false; return false;
} }
@@ -117,15 +117,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) { if (!ptx_file.good()) {
return false; return false;
} }
ptx->resize(ptx_size); ptx.resize(ptx_size);
ptx_file.read(ptx->data(), ptx_size); ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line; std::string line;
while (std::getline(txt_file, line)) { while (std::getline(txt_file, line)) {
auto tab = line.find('\t'); auto tab = line.find('\t');
if (tab != std::string::npos) { if (tab != std::string::npos) {
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1)); ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
} }
} }
return true; return true;
@@ -135,7 +135,7 @@ bool read_cached_ptx(
void write_cached_ptx( void write_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels, const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) { const std::string& source_code) {
if (cache_dir.empty()) { if (cache_dir.empty()) {
@@ -217,85 +217,85 @@ constexpr const char* g_headers[] = {
jit_source_utils, jit_source_utils,
}; };
} // namespace void compile(
JitModule::JitModule(
Device& device, Device& device,
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder) { const std::string& source,
// Check cache. const std::vector<std::string>& kernel_names,
std::vector<char> ptx; std::string& ptx,
std::vector<std::pair<std::string, std::string>> ptx_kernels; std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) { // Create the program
// Create program. nvrtcProgram prog;
auto [source_code, kernel_names] = builder(); CHECK_NVRTC_ERROR(nvrtcCreateProgram(
nvrtcProgram prog; &prog,
CHECK_NVRTC_ERROR(nvrtcCreateProgram( source.c_str(),
&prog, (module_name + ".cu").c_str(),
source_code.c_str(), std::size(g_headers),
(module_name + ".cu").c_str(), g_headers,
std::size(g_headers), g_include_names));
g_headers, std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
g_include_names)); &prog,
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer( [](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
&prog, for (const auto& name : kernel_names) {
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); }); CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size, 0);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
} }
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
}
void load_module(
const std::string& module_name,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module. // Load module.
char jit_log[4089] = {}; char jit_log[4089] = {};
CUjit_option options[] = { CUjit_option options[] = {
@@ -312,21 +312,69 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel; kernels[name] = std::make_pair(kernel, false);
} }
} }
} // namespace
JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
} else {
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
}
// If requested save them in the file cache for the next launch
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
}
// Load the module
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}
JitModule::~JitModule() { JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
} }
CUfunction JitModule::get_kernel(const std::string& kernel_name) { CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name); auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) { if (it == kernels_.end()) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name)); fmt::format("There is no kernel named {}.", kernel_name));
} }
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
if (configure_kernel) {
configure_kernel(it->second.first);
}
it->second.second = true;
}
return it->second.first;
} }
std::unordered_map<std::string, JitModule>& get_jit_module_cache() { std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
@@ -337,11 +385,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
const KernelBuilder& builder) { const KernelBuilder& builder,
bool cache) {
auto& map = get_jit_module_cache(); auto& map = get_jit_module_cache();
auto it = map.find(name); auto it = map.find(name);
if (it == map.end()) { if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first; it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
} }
return it->second; return it->second;
} }

View File

@@ -19,7 +19,8 @@ namespace mlx::core::cu {
class Device; class Device;
using KernelBuilderResult = std::pair< using KernelBuilderResult = std::tuple<
/* precompiled */ bool,
/* source code */ std::string, /* source code */ std::string,
/* kernel names */ std::vector<std::string>>; /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>; using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -40,19 +41,14 @@ struct KernelArgs {
} }
template <typename T> template <typename T>
void append(std::vector<T> vec) { void append(SmallVector<T> vec) {
if (vec.empty()) { storage_.emplace_back(std::move(vec));
// The nullptr can not be used as arg, pass something not null. append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
append(std::monostate{});
} else {
append_ptr(vec.data());
storage_.emplace_back(std::move(vec));
}
} }
// Make sure the arg is copied to an array with size of NDIM. // Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T> template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(std::vector<T> vec) { void append_ndim(SmallVector<T> vec) {
if (vec.size() > NDIM) { if (vec.size() > NDIM) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM)); fmt::format("ndim can not be larger than {}.", NDIM));
@@ -68,17 +64,19 @@ struct KernelArgs {
private: private:
std::vector<void*> args_; std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store // The cuGraphAddKernelNode API requires passing pointers to arguments so
// temporary values untill kernel is launched. // store temporary values until the node is created.
using Arg = std::variant< using Arg = std::variant<
std::monostate, std::monostate,
CUdeviceptr, CUdeviceptr,
bool,
int32_t, int32_t,
uint32_t, uint32_t,
int64_t, int64_t,
std::vector<const void*>, float,
std::vector<int32_t>, SmallVector<const void*>,
std::vector<int64_t>>; SmallVector<int32_t>,
SmallVector<int64_t>>;
std::deque<Arg> storage_; std::deque<Arg> storage_;
}; };
@@ -87,16 +85,19 @@ class JitModule {
JitModule( JitModule(
Device& device, Device& device,
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder); const KernelBuilder& builder,
bool cache);
~JitModule(); ~JitModule();
JitModule(const JitModule&) = delete; JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete; JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name); CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private: private:
CUmodule module_{nullptr}; CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_; std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
}; };
std::unordered_map<std::string, JitModule>& get_jit_module_cache(); std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -104,6 +105,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module( JitModule& get_jit_module(
const mlx::core::Device& device, const mlx::core::Device& device,
const std::string& name, const std::string& name,
const KernelBuilder& builder); const KernelBuilder& builder,
bool use_disk_cache = true);
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from vector to array in host. // Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t> template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) { inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
if (vec.size() > NDIM) { if (vec.size() > NDIM) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM)); fmt::format("ndim can not be larger than {}.", NDIM));

View File

@@ -279,6 +279,7 @@ void LayerNorm::eval_gpu(
kernel, kernel,
n_rows, n_rows,
block_dim(), block_dim(),
0,
x.data<DataType>(), x.data<DataType>(),
w.data<DataType>(), w.data<DataType>(),
b.data<DataType>(), b.data<DataType>(),
@@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu(
kernel, kernel,
n_rows, n_rows,
block_dim(), block_dim(),
0,
x.data<DataType>(), x.data<DataType>(),
w.data<DataType>(), w.data<DataType>(),
g.data<DataType>(), g.data<DataType>(),

View File

@@ -150,6 +150,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel, kernel,
n_rows, n_rows,
block_dim(), block_dim(),
0,
in.data<DataType>(), in.data<DataType>(),
out.data<DataType>(), out.data<DataType>(),
axis_size); axis_size);

View File

@@ -2,6 +2,7 @@
#pragma once #pragma once
#include <cstring>
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
@@ -20,7 +21,11 @@ class LRUCache {
using const_iterator = typename list_type::const_iterator; using const_iterator = typename list_type::const_iterator;
using map_type = M<K, iterator>; using map_type = M<K, iterator>;
explicit LRUCache(size_t capacity) : capacity_(capacity) {} explicit LRUCache(size_t capacity) : capacity_(capacity) {
if (capacity == 0) {
throw std::runtime_error("LRUCache requires capacity > 0.");
}
}
size_t size() const { size_t size() const {
return map_.size(); return map_.size();
@@ -84,6 +89,14 @@ class LRUCache {
return vlist_.erase(pos); return vlist_.erase(pos);
} }
V& operator[](const K& key) {
auto it = find(key);
if (it == end()) {
it = emplace(key, V{}).first;
}
return it->second;
}
private: private:
void trim() { void trim() {
while (map_.size() > capacity_) { while (map_.size() > capacity_) {

View File

@@ -3,7 +3,9 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/cutlass_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/gemms/simple_gemm.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -11,8 +13,14 @@
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
int get_test_gemm() {
static int t = env::get_var("MLX_ENABLE_TEST_GEMM", 0);
return t;
}
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) { check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
@@ -95,9 +103,21 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 1) {
cu::simple_gemm(a, b, out, M, N, K, encoder);
return;
}
if (M % 512 == 0 && N % 512 == 0 && K % 512 == 0 && !a_transposed &&
b_transposed && batch_count == 1 && get_test_gemm() == 2) {
cu::cutlass_gemm(a, b, out, M, N, K, encoder);
return;
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -111,14 +131,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(), batch_shape.back(),
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +199,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt
cu::Matmul matmul( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),
a.dtype(), a.dtype(),
a_transposed, a_transposed,
@@ -202,12 +215,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back(), b_batch_strides.back(),
c_batch_strides.back()); c_batch_strides.back());
gemm.run(
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
encoder, encoder,
out, out,
a, a,

View File

@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h" #include "mlx/backend/cuda/cuda.h"
#include "mlx/fast.h"
namespace mlx::core::cu { namespace mlx::core {
namespace cu {
bool is_available() { bool is_available() {
return false; return false;
} }
} // namespace mlx::core::cu } // namespace cu
namespace fast {
CustomKernelFunction cuda_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool,
int) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
std::vector<array> precompiled_cuda_kernel(
const std::string&,
const std::string&,
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
const std::vector<ScalarArg>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
} // namespace fast
} // namespace mlx::core

View File

@@ -6,17 +6,6 @@
namespace mlx::core { namespace mlx::core {
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \ #define NO_GPU_MULTI(func) \
void func::eval_gpu( \ void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \ const std::vector<array>& inputs, std::vector<array>& outputs) { \
@@ -52,11 +41,6 @@ NO_GPU(Cholesky)
NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed { namespace distributed {
NO_GPU_MULTI(AllReduce) NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather) NO_GPU_MULTI(AllGather)

View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
__global__ void __global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) { affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,140 +227,102 @@ __global__ void affine_dequantize(
} }
} // namespace cu } // namespace cu
namespace {
inline array ensure_row_contiguous( void affine_quantize(
const array& x, const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc, cu::CommandEncoder& enc,
const Stream& s) { const Stream& s) {
if (!x.flags().row_contiguous) { // Calculate the number of elements per thread
array x_copy = contiguous_copy_gpu(x, s); int per_thread = group_size_ / WARP_SIZE;
enc.add_temporary(x_copy); size_t size = w.size() / per_thread;
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX; bool large = size > UINT_MAX;
auto grid_shape = w.shape(); auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
if (dequantize_) { enc.set_input_array(w);
grid_shape.back() *= uint8_per_uint32; enc.set_output_array(wq);
} else { enc.set_output_array(scales);
grid_shape.back() /= per_thread; enc.set_output_array(biases);
} dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) { dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) { dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) { auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large); get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node( enc.add_kernel_node(
cu::affine_dequantize<DataType, group_size.value, bits.value>, kernel,
num_blocks, num_blocks,
block_dims, block_dims,
w.data<uint8_t>(), 0,
inputs[1].data<DataType>(), w.data<T>(),
inputs[2].data<DataType>(), wq.data<uint8_t>(),
out.data<DataType>(), scales.data<T>(),
out.size()); biases.data<T>(),
} else { w.size());
auto [num_blocks, block_dims] = });
get_launch_args(size, grid_shape, w.strides(), large); });
enc.add_kernel_node( });
cu::affine_quantize<DataType, group_size.value, bits.value>, }
num_blocks,
block_dims, void affine_dequantize(
w.data<DataType>(), const array& wq,
out.data<uint8_t>(), const array& scales,
outputs[1].data<DataType>(), const array& biases,
outputs[2].data<DataType>(), array& w,
w.size()); int group_size_,
} int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
}); });
}); });
}); });

View File

@@ -0,0 +1,80 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
} // namespace
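// AffineQuantize covers both directions: when dequantize_ is set the inputs
// are (wq, scales, biases) and the output is w; otherwise the input w
// produces (wq, scales, biases).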
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core
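ensure_row_contiguous_matrix above only demands that the trailing matrix be packed: the last dimension must have stride 1 and the second-to-last a stride equal to the row length, while outer batch strides may be arbitrary. A small sketch of that predicate on plain shape/stride vectors, assuming nothing beyond the check itself (the helper name is hypothetical):

#include <cstdint>
#include <cstdio>
#include <vector>

// True when the trailing 2-D matrix is row-contiguous, mirroring the check
// in ensure_row_contiguous_matrix.
bool matrix_is_row_contiguous(const std::vector<int>& shape,
                              const std::vector<int64_t>& strides) {
  size_t n = shape.size();
  if (n < 2) return n == 0 || strides[0] == 1;
  return strides[n - 1] == 1 && strides[n - 2] == shape[n - 1];
}

int main() {
  // Row-major (B, M, K): strides {M * K, K, 1} -> no copy needed.
  std::printf("%d\n", matrix_is_row_contiguous({4, 8, 16}, {128, 16, 1}));
  // Strided batch (e.g. a slice): still fine, only the matrix must be packed.
  std::printf("%d\n", matrix_is_row_contiguous({4, 8, 16}, {256, 16, 1}));
  // Transposed matrix: a contiguous copy is made first.
  std::printf("%d\n", matrix_is_row_contiguous({4, 16, 8}, {128, 1, 16}));
}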


@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core


@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core
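dispatch_groups and dispatch_bits bridge runtime parameters to compile-time template arguments: each case wraps the value in a std::integral_constant, so the generic lambda can read it back as a constexpr via .value (exactly how the launchers use group_size.value and bits.value). A stripped-down sketch of the same pattern, with a hypothetical process<bits>() standing in for a kernel template:

#include <cstdio>
#include <type_traits>

// Hypothetical stand-in for a kernel template parameterized on the bit width.
template <int bits>
void process() {
  std::printf("instantiated for %d bits\n", bits);
}

// Same shape as dispatch_bits above: map a runtime int onto a compile-time
// constant and hand it to the callback.
template <typename F>
void dispatch_bits_sketch(int bits, F&& f) {
  switch (bits) {
    case 4:
      f(std::integral_constant<int, 4>{});
      break;
    case 8:
      f(std::integral_constant<int, 8>{});
      break;
  }
}

int main() {
  int bits = 4; // runtime value, e.g. read from a primitive's attributes
  dispatch_bits_sketch(bits, [](auto b) { process<b.value>(); });
}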


@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbitsc,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbits,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,


@@ -120,6 +120,7 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
intermediate.data<U>(),
block_step,
@@ -146,6 +147,7 @@ void all_reduce(
kernel,
blocks,
threads,
0,
static_cast<T*>(indata),
out.data<U>(),
block_step,


@@ -230,7 +230,7 @@ void col_reduce_looped(
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
- kernel, grid, blocks, indata, out.data<U>(), args);
kernel, grid, blocks, 0, indata, out.data<U>(), args);
});
});
});


@@ -41,7 +41,8 @@ void init_reduce(
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
grid.x = (grid.x + 1023) / 1024;
- encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
encoder.add_kernel_node(
kernel, grid, block, 0, out.data<U>(), out.size());
});
});
}


@@ -269,7 +269,7 @@ void row_reduce_simple(
int size = plan.shape.back();
encoder.add_kernel_node(
- kernel, grid, block, indata, out.data<U>(), out.size(), size);
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
});
});
}
@@ -322,7 +322,7 @@ void row_reduce_looped(
});
encoder.add_kernel_node(
- kernel, grid, block, indata, out.data<U>(), out.size(), args);
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
});
});
}


@@ -222,6 +222,7 @@ void RMSNorm::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
out.data<DataType>(),
@@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),


@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
kernel,
grid,
block,
0,
(donated ? out : in).data<DataType>(),
out.data<DataType>(),
offset.data<int32_t>(),


@@ -0,0 +1,781 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
#define PRAGMA_LOOP_UNROLL _Pragma("unroll")
struct AttnParams {
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
};
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
const T* Q,
const T* K,
const T* V,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = BN * int(params.K_strides[2]);
const int inner_v_stride = BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -INFINITY;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = max_scores[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
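// The loop above is the streaming-softmax (online softmax) recurrence: with a
// running maximum m, running denominator l, and running output o, each new
// score s is folded in as
//   m' = max(m, s),  l' = l * 2^(m - m') + 2^(s - m'),
//   o' = o * 2^(m - m') + 2^(s - m') * v,
// so o / l always equals the softmax-weighted sum of the values seen so far.
// Scores are kept in base 2: q is pre-scaled by scale * log2(e), which is why
// exp2f replaces expf (exp(scale * x) == exp2(scale * x * log2(e))). The
// per-warp (m, l, o) triples are then merged across the 32 warps through the
// shared-memory buffers before the normalized output is written.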
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
float* partials,
float* sums,
float* maxs,
__grid_constant__ const AttnParams params) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = block_idx * BN + warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s + // Sequence
block_idx; // Block
partials += p_offset * D;
sums += p_offset;
maxs += p_offset;
// Read the query and 0 the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -1e9;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
// Write the sum and new max
if (warp_idx == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();
if (warp_idx == 0) {
U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
ot += outputs[j][lane_idx];
warp.sync();
}
o[i] = ot;
}
block.sync();
}
if (warp_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
partials[v_per_thread * lane_idx + i] = o[i];
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
const float* partials,
const float* sums,
const float* maxs,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
typedef float U;
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int q_seq_idx = blockIdx.y;
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s; // Sequence
partials += p_offset * D + warp_idx * D;
sums += p_offset;
maxs += p_offset;
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = partials[v_per_thread * lane_idx + i];
}
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
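// Pass 2 merges the 32 per-block partial results written by pass 1: given the
// per-block maxima m_b, denominators l_b, and unnormalized outputs o_b, the
// global maximum is m = max_b m_b and the final output is
//   O = (sum_b 2^(m_b - m) * o_b) / (sum_b 2^(m_b - m) * l_b),
// which is what the reductions over maxs, sums, and partials above compute,
// one (batch, head, query) row per thread block.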
} // namespace cu
namespace {
template <typename F>
void dispatch_headdim(int n, F&& f) {
switch (n) {
case 64:
f(std::integral_constant<int, 64>{});
break;
case 96:
f(std::integral_constant<int, 96>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
void sdpa_vector_1pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel =
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
params);
});
});
});
}
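// Launch geometry for the one-pass fallback: one thread block per entry of the
// (H, qL, B) grid, i.e. per (head, query position, batch). The 1024-thread
// block is 32 warps of 32 lanes; each warp walks the keys with a stride of 32
// (warp w reads keys w, w + 32, w + 64, ...) and each lane holds D / 32
// consecutive elements of the query, key, value, and output vectors.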
void sdpa_vector_2pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
// Allocate the intermediates
int blocks = 32;
Shape intermediate_shape;
intermediate_shape.reserve(o.ndim() + 1);
intermediate_shape.insert(
intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(o.shape().back());
array intermediate(intermediate_shape, float32, nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
{
auto kernel = cu::
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
dim3 grid_dim(params.H, params.qL, params.B * 32);
dim3 block_dim(8 * 32, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
params);
}
{
auto kernel = cu::
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(intermediate);
encoder.set_input_array(sums);
encoder.set_input_array(maxs);
encoder.set_output_array(o);
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
params);
}
});
});
});
}
void sdpa_vector_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
}
}
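// Dispatch heuristic: for short key sequences a single 32-warp block per
// (batch, head, query) provides enough parallelism, but once kL exceeds 1024
// the two-pass variant splits the keys across 32 blocks and merges the partial
// softmax statistics in a second kernel, keeping more SMs busy for long
// contexts.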
} // namespace
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
return has_arr_mask || !supported_config;
}
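// In summary: the CUDA vector kernels are used only when there is no array
// mask, the query and value head dimensions match and are 64, 96, or 128, and
// the query sequence length is at most 3 (decode-style shapes). Everything
// else, as well as gradient tracing and CPU streams, takes the fallback path.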
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {
return arr;
}
};
// We are in vector mode ie single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
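// For example, a row-contiguous (B, H, kL, D) array passes unchanged, and so
// does a KV-cache style layout whose sequence dimension is strided over a
// larger preallocated buffer, as long as the last dimension is packed and the
// batch/head dimensions stay contiguous with respect to each other; anything
// with a strided last dimension is copied first.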
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
else {
throw std::runtime_error("Doesn't support matrix yet.");
}
}
} // namespace fast
} // namespace mlx::core


@@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
in.data_size() / axis_size,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size);
@@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim,
0,
in.data<T>(),
out.data<U>(),
axis_size,


@@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);


@@ -1,6 +1,5 @@
// Copyright © 2025 Apple Inc.
- #include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
@@ -13,7 +12,6 @@
#include <cub/device/device_segmented_sort.cuh>
#include <cassert>
- #include <numeric>
namespace mlx::core {
@@ -27,29 +25,6 @@ struct ModOp {
}
};
- // We can not use any op in eval, make an utility.
- array swapaxes_in_eval(const array& in, int axis1, int axis2) {
- std::vector<int> axes(in.ndim());
- std::iota(axes.begin(), axes.end(), 0);
- std::swap(axes[axis1], axes[axis2]);
- // TODO: Share the code with Transpose::eval.
- Shape shape(axes.size());
- Strides strides(in.ndim());
- for (size_t ax = 0; ax < axes.size(); ++ax) {
- shape[ax] = in.shape()[axes[ax]];
- strides[ax] = in.strides()[axes[ax]];
- }
- auto flags = in.flags();
- if (flags.contiguous) {
- auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
- flags.row_contiguous = row_contiguous;
- flags.col_contiguous = col_contiguous;
- }
- array out(shape, in.dtype(), nullptr, {});
- out.copy_shared_buffer(in, strides, flags, in.data_size());
- return out;
- }
struct OffsetTransform {
int nsort;

Some files were not shown because too many files have changed in this diff.