Compare commits


6 Commits

Author                  SHA1        Message                                    Date
Angelos Katharopoulos   8269c9d02d  Support unaligned M                        2025-07-23 00:40:27 -07:00
Angelos Katharopoulos   903b40627c  Add dynamic shared memory and improve qmm  2025-07-22 23:36:53 -07:00
Angelos Katharopoulos   700f7dcf01  Refactor the matmul a bit                  2025-07-21 23:38:21 -07:00
Angelos Katharopoulos   6c60bd1cbf  Fixed mma and working dequant              2025-07-21 04:47:42 -07:00
Angelos Katharopoulos   a64cc02a0c  Somewhat working matmul primitives         2025-07-21 04:47:42 -07:00
Angelos Katharopoulos   346ae5fdb5  Refactor quantized                         2025-07-21 04:47:41 -07:00
142 changed files with 1973 additions and 5798 deletions

View File

@@ -7,9 +7,6 @@ parameters:
nightly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -81,24 +78,23 @@ jobs:
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get upgrade -y
pip install --upgrade cmake
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
uv pip install -e ".[dev]" -v
pip install -e ".[dev]"
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
@@ -106,7 +102,6 @@ jobs:
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
@@ -132,30 +127,33 @@ jobs:
- run:
name: Install dependencies
command: |
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
source env/bin/activate
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
@@ -164,17 +162,16 @@ jobs:
- run:
name: Build example extension
command: |
source .venv/bin/activate
source env/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -183,7 +180,7 @@ jobs:
- run:
name: Build small binary
command: |
source .venv/bin/activate
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -195,60 +192,34 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e .
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
image: linux-cuda-12:2023.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python3 -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
@@ -344,10 +315,14 @@ jobs:
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get upgrade -y
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo apt-get install -y apt-utils
sudo apt-get install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -391,27 +366,22 @@ jobs:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
image: linux-cuda-12:2024.11.1
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
@@ -422,6 +392,7 @@ jobs:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
@@ -434,24 +405,19 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- cuda_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
@@ -572,9 +538,6 @@ workflows:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -638,87 +601,3 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -22,7 +22,7 @@ project(
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -41,9 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -70,15 +68,6 @@ else()
set(MLX_BUILD_METAL OFF)
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -243,16 +232,12 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
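
A minimal sketch of the NumPy-style API and function transformations mentioned above, using only `mlx.core` operations:

```python
import mlx.core as mx

# NumPy-style array operations
a = mx.arange(6).reshape(2, 3)
b = mx.ones((3, 2))
print(a @ b)

# Composable function transformations, e.g. automatic differentiation
grad_fn = mx.grad(lambda x: (x ** 2).sum())
print(grad_fn(mx.array(3.0)))  # array(6, dtype=float32)
```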
@@ -68,23 +68,18 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
```bash
**With `pip`**:
```
pip install mlx
```
To install the CUDA backend on Linux, run:
**With `conda`**:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
conda install -c conda-forge mlx
```
Checkout the

View File

@@ -1,54 +0,0 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::string kname;
kname = "axpby_general_" + type_to_name(out);
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx
To install from PyPI your system must meet the following requirements:
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
@@ -26,21 +26,12 @@ To install from PyPI your system must meet the following requirements:
CUDA
^^^^
MLX has a CUDA backend which you can install with:
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
pip install "mlx[cuda]"
CPU-only (Linux)
^^^^^^^^^^^^^^^^
@@ -49,14 +40,7 @@ For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
pip install "mlx[cpu]"
Troubleshooting
^^^^^^^^^^^^^^^
@@ -271,7 +255,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
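
For context, a minimal end-to-end sketch of the save/restore flow, assuming a toy ``mlx.nn.Linear`` model, a dummy loss, and the ``dict(tree_flatten(...))`` form shown above:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten, tree_unflatten

    model = nn.Linear(4, 4)
    optimizer = optim.Adam(learning_rate=1e-2)

    # One dummy update so the optimizer state is populated
    def loss_fn(model, x):
        return (model(x) ** 2).sum()

    loss, grads = nn.value_and_grad(model, loss_fn)(model, mx.zeros((1, 4)))
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

    # Save the flattened state, then restore it into a fresh optimizer
    mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))

    optimizer = optim.Adam(learning_rate=1e-2)
    optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))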
Note, not every optimizer configuration parameter is saved in the state. For

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
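
For completeness, a short sketch of importing and calling ``add.mlxfn``, assuming ``fun`` returns ``x + y`` as in the earlier example:

.. code-block:: python

    imported_add = mx.import_function("add.mlxfn")

    # Positional arguments, just like the original ``fun``
    (out,) = imported_add(mx.array(1.0), mx.array(1.0))
    print(out)  # array(2, dtype=float32)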
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
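
A minimal sketch of the evaluate-before-export pattern described above, assuming a small ``mlx.nn.Linear`` model (the names here are illustrative):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(4, 4)
    # Materialize the parameters so the exported graph captures their values,
    # not the random initialization that produces them.
    mx.eval(model.parameters())

    def call(x):
        return model(x)

    mx.export_function("model.mlxfn", call, mx.zeros((1, 4)))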
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream>
#include <sstream>
@@ -17,19 +16,6 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
}
// Resolve name of kernel (corresponds to axpby.metal)
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.4.0
nanobind==2.2.0

View File

@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -10,7 +10,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -19,8 +18,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc

View File

@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -196,11 +196,8 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}

View File

@@ -288,14 +288,6 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;

View File

@@ -377,10 +377,4 @@ void copy_cpu_inplace(
});
}
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -30,7 +30,4 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core

View File

@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
return {contiguous_copy_cpu(arr, stream), true};
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
}
return in;
} else {
array arr_copy = contiguous_copy_cpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_cpu(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>

View File

@@ -87,7 +87,8 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -136,8 +136,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
}
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true);
}
};

View File

@@ -712,7 +712,9 @@ void fast::AffineQuantize::eval_cpu(
if (arr.flags().row_contiguous) {
return std::make_pair(arr, false);
} else {
return std::make_pair(contiguous_copy_cpu(arr, s), true);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};

View File

@@ -250,8 +250,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
in = contiguous_copy_cpu(in, stream());
encoder.add_temporary(in);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_cpu(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc(out.nbytes()));

View File

@@ -131,7 +131,8 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -8,7 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -333,24 +333,47 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
dispatch_all_types(out.dtype(), [&](auto type_tag) {
sort<MLX_GET_TYPE(type_tag)>(out, axis);
});
});
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {

View File

@@ -6,7 +6,6 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -16,22 +15,18 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
@@ -40,7 +35,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
@@ -49,17 +43,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation.
@@ -102,18 +89,11 @@ endif()
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
@@ -145,23 +125,6 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Use the frontend APIs of cuDNN.
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
@@ -169,3 +132,12 @@ target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# Make Thunderkittens available
FetchContent_Declare(
kittens
GIT_REPOSITORY https://github.com/HazyResearch/ThunderKittens.git
GIT_TAG aaab847f430ed313ed466e64b25b9177babd1db8
GIT_SHALLOW TRUE)
FetchContent_MakeAvailable(kittens)
target_include_directories(mlx BEFORE PRIVATE "${kittens_SOURCE_DIR}/include")

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
@@ -24,58 +25,52 @@ constexpr int small_block_size = 8;
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
end_ = reinterpret_cast<void*>(
reinterpret_cast<char*>(buffer_) + small_pool_size);
next_free_ = reinterpret_cast<Block*>(buffer_);
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
for (size_t i = 0; i < num_blocks - 1; ++i) {
curr->next = reinterpret_cast<Block*>(
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
CHECK_CUDA_ERROR(cudaFree(buffer_));
}
CudaBuffer* SmallSizePool::malloc() {
void* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
return &b->buf;
return static_cast<void*>(b);
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
void SmallSizePool::free(void* p) {
auto b = static_cast<Block*>(p);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
bool SmallSizePool::in_pool(void* p) {
return (p >= buffer_) && (p < end_);
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
[this](CudaBuffer* buf) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -97,26 +92,28 @@ Buffer CudaAllocator::malloc(size_t size) {
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
int64_t mem_to_free =
get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_to_free);
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
lock.unlock();
buf = new CudaBuffer{nullptr, size};
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
buf->data = scalar_pool_.malloc();
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
if (!buf->data) {
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
}
lock.lock();
}
active_memory_ += size;
@@ -126,6 +123,7 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
@@ -140,7 +138,9 @@ void CudaAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
cuda_free(buf);
lock.unlock();
cuda_free(buf->data);
delete buf;
}
}
@@ -152,13 +152,30 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size;
}
// This must be called with mutex_ acquired
void CudaAllocator::cuda_free(CudaBuffer* buf) {
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
delete buf;
cudaFree(buf);
}
}

View File

@@ -7,10 +7,13 @@
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
@@ -21,14 +24,13 @@ struct CudaBuffer {
class SmallSizePool {
private:
union Block {
struct Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
void* buffer_{nullptr};
Block* next_free_{nullptr};
void* end_{nullptr};
public:
SmallSizePool();
@@ -37,9 +39,9 @@ class SmallSizePool {
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
void* malloc();
void free(void* p);
bool in_pool(void* p);
};
class CudaAllocator : public allocator::Allocator {
@@ -48,6 +50,15 @@ class CudaAllocator : public allocator::Allocator {
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
@@ -58,11 +69,13 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
void cuda_free(CudaBuffer* buf);
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;

View File

@@ -1,55 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
} // namespace mlx::core

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -44,11 +44,8 @@ struct ArgMin {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
@@ -77,11 +74,8 @@ struct ArgMax {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
@@ -112,15 +106,16 @@ __global__ void arg_reduce_general(
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}

View File

@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b[0]);
out_vec.val[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b_vec[i]);
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b[0]);
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -92,7 +92,7 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -128,7 +128,7 @@ __global__ void binary_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
@@ -211,15 +211,12 @@ void binary_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(out, large());
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -232,9 +229,11 @@ void binary_op_gpu_inplace(
const_param<dims_constant()>(b_strides));
});
} else {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::binary_g<Op, InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,
@@ -251,7 +250,8 @@ void binary_op_gpu_inplace(
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -261,7 +261,12 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,

View File

@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b[0]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -160,7 +160,7 @@ __global__ void binary_two_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
@@ -227,15 +227,16 @@ void binary_two_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -249,10 +250,11 @@ void binary_two_op_gpu_inplace(
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
cu::binary_two_g<Op, InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,
@@ -270,7 +272,8 @@ void binary_two_op_gpu_inplace(
} else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -280,6 +283,7 @@ void binary_two_op_gpu_inplace(
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),

View File

@@ -104,41 +104,10 @@ struct FusedKernelBuilder {
" }\n";
}
// Vectorized read loop
if (contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_scalar(x) || is_constant(i)) {
continue;
}
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
xname,
type);
}
}
// Create some space for the outputs
for (const auto& x : outputs) {
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
}
// Work loop
if (!contiguous) {
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
} else {
os +=
"\n"
" #pragma unroll\n"
" for (int i = 0; i < work_per_thread; i++) {\n";
}
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -153,7 +122,7 @@ struct FusedKernelBuilder {
} else if (is_scalar(x)) {
value = fmt::format("{}[0]", xname);
} else if (contiguous) {
value = fmt::format("vec_{}[i]", xname);
value = fmt::format("{}[index]", xname);
} else {
value = fmt::format("{}[{}_idx]", xname, xname);
}
@@ -181,30 +150,25 @@ struct FusedKernelBuilder {
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
os += "\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
}
}
os += " }\n";
// Store the output to global memory
for (const auto& x : outputs) {
os += fmt::format(
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
namer.get_name(x));
}
os += "}\n";
}
};
@@ -228,15 +192,6 @@ void Compiled::eval_gpu(
nvtx3::scoped_range r("Compiled::eval_gpu");
auto& s = stream();
// Determine the work per thread for the vectorized reads/writes. We take it
// as 16 over the max itemsize for the outputs. Another heuristic could be
// over the max itemsize of all arrays.
int max_size = 1;
for (const auto& x : outputs) {
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
}
int work_per_thread = 16 / max_size;
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
// Build source code.
cu::FusedKernelBuilder builder{
@@ -250,23 +205,28 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
}
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -309,6 +269,7 @@ void Compiled::eval_gpu(
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
@@ -333,7 +294,7 @@ void Compiled::eval_gpu(
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread);
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

View File

@@ -1,546 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> padding_lo;
std::array<int, MAX_NDIM> padding_hi;
std::array<int, MAX_NDIM> dilation;
int groups;
bool flip;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
};
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
cache(/* capacity */ 128);
return cache;
}
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
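// Illustrative example (added comment, not part of the original file): for a
// data pointer whose address ends in 0x30 (48), the loop above passes 2, 4, 8
// and 16 but finds 48 % 32 != 0, so it reports an alignment of 16; addresses
// that are multiples of 32 report the maximum of 32. The value is then passed
// to cudnn_frontend's setAlignment() below.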
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
array& x,
array& w,
array& y) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
x.data<void>(),
w.data<void>(),
y.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
const ConvCacheKey& cache_key,
cudnnBackendDescriptorType_t backend_type,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag,
array& x,
array& w,
array& y) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(plan)));
return true;
}
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return false;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
array& x,
array& w,
array& y,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding_lo_,
const std::vector<int>& padding_hi_,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
if (backend_type == CONV_BACKWARD_INPUT) {
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
convert_vector<int64_t>(input_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
padding_hi = padding_lo;
return std::make_tuple(
convert_vector<int64_t>(kernel_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_strides));
} else {
return std::make_tuple(
convert_vector<int64_t>(kernel_strides),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
}
}
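// Worked example (hypothetical values, added for clarity) for the
// CONV_BACKWARD_INPUT branch above with kernel_dilation = {1},
// kernel_strides = {2}, input_dilation = {1}, w.shape(1) = 3, x.shape(1) = 5,
// y.shape(1) = 9 and padding_lo = padding_hi = {1}:
//   wt_size    = 1 + 1 * (3 - 1) = 3   ->  padding_lo = 3 - 1 - 1 = 1
//   in_size    = 1 + 2 * (5 - 1) = 9
//   out_size   = 1 + 1 * (9 - 1) = 9   ->  padding_hi = 9 - 9 + 1 = 1
// and the returned stride/dilation are input_dilation and kernel_dilation.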
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', x))
.setwDesc(build_tensor('w', w))
.setyDesc(build_tensor('y', y))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
return std::nullopt;
}
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to call multiple times in one
// eval_gpu, at the cost of possibly redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array in,
array wt,
array out,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = swapaxes_in_eval(wt, 0, -1);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = swapaxes_in_eval(in, 0, -1);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dims
// C_in and C_out swapped.
Shape shape(out.shape());
std::swap(shape.front(), shape.back());
Strides strides(shape.size(), 1);
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
array intermediate(std::move(shape), out.dtype(), nullptr, {});
intermediate.copy_shared_buffer(
out, std::move(strides), {true, true, false}, out.data_size());
out = intermediate;
}
// cuDNN requires contiguous input.
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running the conv op. This can
// only be called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& intermediate_out,
array& final_out) {
encoder.set_input_array(in);
encoder.set_input_array(wt);
encoder.set_output_array(final_out);
if (backend_type == CONV_BACKWARD_WEIGHT) {
// Turn |out| into a strided array, which has C_in and C_out swapped in the
// vjp; the final |grad_weight| will then be contiguous.
Strides strides = intermediate_out.strides();
std::swap(strides.front(), strides.back());
final_out.copy_shared_buffer(
intermediate_out,
std::move(strides),
{false, false, false},
intermediate_out.data_size());
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out_.size() == 0) {
return;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype),
fixed_vector(in.shape()),
fixed_vector(wt.shape()),
fixed_vector(kernel_strides_),
fixed_vector(padding_lo_),
fixed_vector(padding_hi_),
fixed_vector(kernel_dilation_),
groups_,
flip_,
get_alignment(in),
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
return;
}
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
} else {
// Otherwise it could be a backward weight convolution or a forward
// convolution; mathematically there is no difference, so we have to use
// heuristics. Empirically, backward convolutions have large kernel dimensions
// and usually have |in| and |wt| transposed.
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
wt.shape(2) > out.shape(2)) {
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
} else {
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
}
}
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,
x,
w,
y,
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_op_graph(
encoder,
try_backend,
dtype,
x,
w,
y,
stride,
padding_lo,
padding_hi,
dilation);
if (op_graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
break;
}
}
if (!op_graph) {
throw std::runtime_error("[conv] Can not build op graph.");
}
// Get ready to execute the graph.
register_args(encoder, backend_type, in, wt, out, out_);
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
auto tag = op_graph->getTag();
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, dtype, *op_graph);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
throw std::runtime_error("[conv] Unable to find a working engine.");
}
} // namespace mlx::core

View File

@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in[0]);
out_vec.val[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in_vec[i]);
out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -65,13 +65,19 @@ void copy_contiguous(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,

View File

@@ -37,7 +37,7 @@ __global__ void copy_gg(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
@@ -71,10 +71,12 @@ void copy_general(
data_size *= s;
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
auto kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
encoder.add_kernel_node(
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -86,10 +88,11 @@ void copy_general(
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
encoder.add_kernel_node(
cu::copy_gg<InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,

View File

@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
@@ -74,13 +74,12 @@ void copy_general_dynamic(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::
copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::copy_gg_dynamic_nd<
InType,
OutType,
IdxT,
dims_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -94,9 +93,11 @@ void copy_general_dynamic(
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::copy_gg_dynamic<InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,

View File

@@ -34,7 +34,7 @@ __global__ void copy_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
@@ -63,9 +63,12 @@ void copy_general_input(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -76,9 +79,11 @@ void copy_general_input(
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::copy_g<InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
@@ -10,23 +9,12 @@
#include <future>
#include <unordered_set>
namespace mlx::core::cu {
namespace {
namespace mlx::core {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -34,7 +22,7 @@ int cuda_graph_cache_size() {
return cache_size;
}
} // namespace
namespace cu {
Device::Device(int device) : device_(device) {
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -52,18 +40,11 @@ Device::Device(int device) : device_(device) {
}
// The cublasLt handle is used by matmul.
make_current();
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
// Initializing the jit module cache here ensures it is not
// unloaded before any evaluation is done.
get_jit_module_cache();
cublasLtCreate(&lt_);
}
Device::~Device() {
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
cublasLtDestroy(lt_);
}
void Device::make_current() {
@@ -85,19 +66,29 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
if (discard) {
return;
size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) {
cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, enc.graph_, NULL, 0, &params));
enc.insert_graph_dependencies(GraphNode{node, 'K'});
} else {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddChildGraphNode(&node, enc.graph_, NULL, 0, graph));
enc.insert_graph_dependencies(GraphNode{node, 'G'});
}
enc.add_graph_node(graph);
CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
}
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
@@ -184,11 +175,21 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
for (auto& [_, graph_exec] : graphs) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
}
graphs.clear();
}
CommandEncoder::~CommandEncoder() {
clear_graphs(graph_cache_);
}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
@@ -222,7 +223,10 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
cudaGraphNode_t node;
CHECK_CUDA_ERROR(
cudaGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_kernel_node(
@@ -241,29 +245,13 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
}
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
CHECK_CUDA_ERROR(
cuGraphAddKernelNode(&node, graph_, NULL, 0, &kernel_params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
@@ -280,7 +268,7 @@ void CommandEncoder::commit() {
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
@@ -294,19 +282,25 @@ void CommandEncoder::commit() {
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
graph_exec = nullptr;
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy
if (graph_cache_.size() > cuda_graph_cache_size()) {
clear_graphs(graph_cache_);
}
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
@@ -316,6 +310,7 @@ void CommandEncoder::commit() {
}
// Put completion handlers in a batch.
worker_.end_batch();
worker_.commit(stream_);
}
@@ -324,6 +319,7 @@ void CommandEncoder::synchronize() {
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
commit();
f.wait();
}
@@ -341,4 +337,6 @@ CommandEncoder& get_command_encoder(Stream s) {
return device(s.device).get_command_encoder(s);
}
} // namespace mlx::core::cu
} // namespace cu
} // namespace mlx::core

View File

@@ -3,13 +3,11 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <cublasLt.h>
#include <cuda.h>
#include <cudnn.h>
#include <thrust/execution_policy.h>
#include <unordered_map>
@@ -23,7 +21,6 @@ class CommandEncoder {
~CaptureContext();
cudaGraph_t graph;
CommandEncoder& enc;
bool discard{false};
};
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc);
@@ -32,6 +29,7 @@ class CommandEncoder {
};
explicit CommandEncoder(Device& d);
~CommandEncoder();
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -76,11 +74,6 @@ class CommandEncoder {
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
@@ -89,10 +82,6 @@ class CommandEncoder {
void maybe_commit();
void commit();
Device& device() {
return device_;
}
CudaStream& stream() {
return stream_;
}
@@ -126,7 +115,7 @@ class CommandEncoder {
std::string graph_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
@@ -157,16 +146,12 @@ class Device {
cublasLtHandle_t lt_handle() const {
return lt_;
}
cudnnHandle_t cudnn_handle() const {
return cudnn_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
cublasLtHandle_t lt_;
cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_;
};

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@@ -49,7 +49,11 @@ inline __device__ void atomic_add(__half* out, __half val) {
}
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

View File

@@ -32,118 +32,21 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
__device__ T& operator[](int i) {
return val[i];
}
__device__ T operator[](int i) const {
return val[i];
}
};
template <int N, typename T>
inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
if (is_aligned<N>(ptr)) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = ptr[offset * N + i];
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset,
SizeT size,
int64_t stride,
T fallback) {
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] =
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
}
return v;
}
}
template <int N, typename T>
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
if (is_aligned<N>(ptr)) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
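// Hypothetical usage sketch (not part of the original header) of the helpers
// above: a unary kernel that moves N elements per thread through the aligned
// vector type; a real kernel would also handle the ragged tail.
template <typename T, typename IdxT, int N>
__global__ void double_elements(const T* in, T* out, IdxT size) {
  IdxT index = blockIdx.x * static_cast<IdxT>(blockDim.x) + threadIdx.x;
  if ((index + 1) * N > size) {
    return;
  }
  auto v = load_vector<N>(in, static_cast<uint32_t>(index));
#pragma unroll
  for (int i = 0; i < N; ++i) {
    v.val[i] = v.val[i] + v.val[i];
  }
  store_vector<N>(out, static_cast<uint32_t>(index), v);
}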
///////////////////////////////////////////////////////////////////////////////
@@ -301,8 +204,20 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
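// Worked example (illustrative only): with shape = {2, 3, 4} and row-major
// strides = {12, 4, 1}, elem = 17 walks the loop above as
//   i = 2: 17 % 4 = 1, loc += 1 * 1  -> 1,  elem = 4
//   i = 1:  4 % 3 = 1, loc += 1 * 4  -> 5,  elem = 1
//   i = 0:  1 % 2 = 1, loc += 1 * 12 -> 17, elem = 0
// i.e. flat index 17 corresponds to coordinates (1, 1, 1); for contiguous
// strides the returned location equals the flat index.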
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
@@ -320,7 +235,7 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,

View File

@@ -1,51 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
namespace distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& input = inputs[0];
auto& output = outputs[0];
auto& encoder = cu::get_command_encoder(stream());
if (input.is_donatable()) {
output.copy_shared_buffer(input);
} else {
output.set_data(allocator::malloc(output.nbytes()));
}
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace distributed
} // namespace mlx::core

View File

@@ -19,6 +19,8 @@ void new_stream(Stream s) {
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// It is safe for the main thread to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
@@ -36,15 +38,18 @@ void eval(array& arr) {
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
// Except for the donated one.
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
encoder.add_temporary(in);
}
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
buffers.insert(s.data_shared_ptr());
}
// Remove the output if an input was donated to it.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
encoder.maybe_commit();
}

View File

@@ -110,26 +110,24 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() {
buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr);
delete ptr;
});
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
// Allocate cuda::atomic on managed memory.
Atomic* ac;
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
ptr->~Atomic();
allocator().cuda_free(ptr);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(to_atomic(buf_), value);
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
@@ -140,17 +138,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s);
encoder.commit();
wait(encoder.stream(), value);
encoder.add_completed_handler([buf = buf_]() {});
encoder.add_completed_handler([ac = ac_]() {});
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(to_atomic(buf_), value);
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
@@ -164,18 +162,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s);
encoder.commit();
signal(encoder.stream(), value);
encoder.add_completed_handler([buf = buf_]() {});
encoder.add_completed_handler([ac = ac_]() {});
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return to_atomic(buf_)->load() >= value;
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return to_atomic(buf_)->load();
return ac_->load();
}
} // namespace cu

View File

@@ -2,7 +2,6 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/stream.h"
#include <cuda_runtime.h>
@@ -56,8 +55,12 @@ class SharedEvent {
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu

View File

@@ -1,73 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu {
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
a_it.step();
b_it.step();
}
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core::cu

View File

@@ -1,208 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
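// Layout note (added for clarity, not part of the original file): for a batch
// count B, the single `pointers` allocation built below holds consecutive
// pointer arrays,
//   [0,  B)   per-batch A pointers
//   [B,  2B)  per-batch B pointers
//   [2B, 3B)  per-batch output pointers (C pointers for addmm, with the
//             outputs following in [3B, 4B))
// which is why a_pointers/b_pointers/out_pointers below are offset by
// batch_count from one another.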
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -1,284 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
#include <fmt/format.h>
namespace mlx::core::cu {
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()),
pref_(cublas_preference(device)),
M_(a_rows),
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: Matmul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void Matmul::run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha /* = 1 */,
float beta /* = 0 */) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
handle_,
matmul_desc_,
a_desc_,
b_desc_,
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
encoder.stream()));
}
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -1,100 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
public:
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha = 1,
float beta = 0);
uint64_t M_;
uint64_t N_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu

View File

@@ -1,173 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto g_idx = block.group_index();
auto t_idx = block.thread_index();
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
}
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void gemv_batched(
const T* mat,
const T* vec,
T* out,
int rows,
int cols,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides mat_batch_strides,
const __grid_constant__ Strides vec_batch_strides,
int batch_ndim) {
auto block = cg::this_thread_block();
auto batch_idx = block.group_index().y;
auto [vec_offset, mat_offset] = elem_to_loc(
batch_idx,
batch_shape.data(),
vec_batch_strides.data(),
mat_batch_strides.data(),
batch_ndim);
gemv_impl<T, rows_per_block, n_per_thread>(
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
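// Illustrative dispatch check (not part of the original file): only true
// matrix-vector shapes with K a multiple of 32 are routed to this kernel;
// other shapes presumably take the cuBLASLt path.
inline bool gemv_dispatch_examples() {
  bool mv = can_use_gemv(1, 4096, 4096, false, true);   // M == 1, B^T  -> true
  bool vm = can_use_gemv(128, 1, 96, false, false);     // N == 1       -> true
  bool mm = can_use_gemv(2, 2, 64, false, false);       // neither      -> false
  return mv && vm && !mm;
}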
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
}
}
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;
const DataType* vec;
int rows;
int cols = K;
auto mat_strides = const_param(a_batch_strides);
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
});
}
} // namespace mlx::core::cu
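For orientation, a small host-side sketch of how the launch shape above falls out of the constants; the row count and K below are hypothetical, not taken from this diff.
// Illustrative sketch (hypothetical sizes): one warp per row, 8 rows per
// block, and n_per_thread chosen from the K alignment checks above.
#include <cstdio>

int main() {
  const int rows = 4096;
  const int K = 4096;
  const int rows_per_block = 8;
  int n_per_thread = (K % 128 == 0) ? 4 : (K % 64 == 0) ? 2 : 1;   // 4 here
  int num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; // 512
  std::printf("grid.x = %d, block = (32, %d), n_per_thread = %d\n",
              num_blocks_x, rows_per_block, n_per_thread);
  return 0;
}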

View File

@@ -1,24 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -29,12 +29,12 @@ void append_indices_arg(
const std::vector<array>& inputs,
int nidx,
int idx_ndim) {
SmallVector<const void*> indices(nidx);
std::vector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
std::vector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].shape().begin(),
@@ -42,7 +42,7 @@ void append_indices_arg(
indices_shape.data() + i * idx_ndim);
}
args.append(std::move(indices_shape));
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
std::vector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].strides().begin(),
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@@ -128,7 +128,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(out, large);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@@ -229,7 +229,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(upd, large);
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
@@ -317,7 +317,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
@@ -421,7 +421,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

View File

@@ -0,0 +1,121 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterator for traversing a non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidentally running it on the host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu
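A hedged usage sketch of the iterator pair above: it gathers a strided (here, transposed) view into contiguous memory with thrust. The shapes, strides, and the device execution policy are assumptions for illustration.
// Illustrative only: copy a 4x8 column-major (transposed) view into
// row-major storage. Dereference is device-only, so the copy runs on device.
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <vector>

void gather_transposed(const float* src, float* dst) {
  std::vector<int32_t> shape = {4, 8};
  std::vector<int64_t> strides = {1, 4};  // column-major strides
  auto [first, last] = mlx::core::cu::make_general_iterators<int64_t>(
      src, int64_t(4 * 8), shape, strides);
  thrust::copy(thrust::device, first, last, dst);
}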

View File

@@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu
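A quick, hedged illustration of the adaptor: with stride 0 it broadcasts a single element (how the norm kernels later in this diff handle a missing weight vector), and with stride 1 it is a plain pointer walk.
// Host-side sketch (illustrative): sum every `stride`-th element of w.
float sum_strided(const float* w, int n, int64_t stride) {
  mlx::core::cu::strided_iterator<const float*> it(w, stride);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i, ++it) {
    acc += *it;  // reads w[i * stride]; stride == 0 keeps reading w[0]
  }
  return acc;
}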

View File

@@ -9,6 +9,7 @@
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <fmt/format.h>
#include <nvrtc.h>
@@ -329,16 +330,11 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
static std::unordered_map<std::string, JitModule> map;
return map;
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder) {
auto& map = get_jit_module_cache();
static std::unordered_map<std::string, JitModule> map;
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first;

View File

@@ -40,14 +40,19 @@ struct KernelArgs {
}
template <typename T>
void append(SmallVector<T> vec) {
storage_.emplace_back(std::move(vec));
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
void append(std::vector<T> vec) {
if (vec.empty()) {
// A nullptr cannot be used as a kernel arg, so pass something non-null instead.
append(std::monostate{});
} else {
append_ptr(vec.data());
storage_.emplace_back(std::move(vec));
}
}
// Make sure the arg is copied to an array of size NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) {
void append_ndim(std::vector<T> vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@@ -71,9 +76,9 @@ struct KernelArgs {
int32_t,
uint32_t,
int64_t,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
std::vector<const void*>,
std::vector<int32_t>,
std::vector<int64_t>>;
std::deque<Arg> storage_;
};
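A hedged sketch of how the argument packer above is typically filled before a JIT launch; the calls mirror the ones visible in the indexing diff earlier, everything else (names, shapes) is assumed.
// Hypothetical usage: pack scalars and vectors, then hand the flat pointer
// list to the encoder. An empty vector is replaced by a monostate
// placeholder so the launch never receives a null argument pointer.
void pack_args_example(KernelArgs& args, const array& out,
                       std::vector<int32_t> axes) {
  args.append<int32_t>(out.ndim());
  args.append_ndim(
      std::vector<int32_t>(out.shape().begin(), out.shape().end()));
  args.append(std::move(axes));  // may be empty
  // encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}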
@@ -94,8 +99,6 @@ class JitModule {
std::unordered_map<std::string, CUfunction> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,

View File

@@ -30,25 +30,4 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core

View File

@@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from a vector to an array on the host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@@ -120,19 +120,53 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1);
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core
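The templated overloads above replace the fixed 1024-thread block with a per-kernel occupancy query; a hedged call-site sketch follows (the kernel name is hypothetical).
// Illustrative only: the block size now comes from
// cudaOccupancyMaxPotentialBlockSize for this particular kernel, so
// register-heavy kernels automatically get smaller blocks.
template <typename T>
void launch_elementwise(const mlx::core::array& out, bool large) {
  auto kernel = my_elementwise_kernel<T>;  // hypothetical __global__ kernel
  auto [num_blocks, block_dim] =
      mlx::core::get_launch_args(kernel, out, large);
  // encoder.add_kernel_node(kernel, num_blocks, block_dim, 0, ...);
}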

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
@@ -10,6 +11,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -72,11 +75,9 @@ __global__ void layer_norm(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceT{block, temp}.Sum(sum);
@@ -87,18 +88,11 @@ __global__ void layer_norm(
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
normalizer += t * t;
}
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -107,15 +101,17 @@ __global__ void layer_norm(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
#pragma unroll
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
store_vector<N_READS>(out, index, xn, axis_size);
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
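For reference, the math the kernel above implements, restated as hedged scalar host code over one row; this is the standard layer norm formula rather than anything lifted verbatim from the diff.
// Scalar reference (illustrative): mean, variance, then the affine output.
#include <cmath>

void layer_norm_ref(const float* x, const float* w, const float* b,
                    float* out, int axis_size, float eps) {
  float mean = 0.0f;
  for (int i = 0; i < axis_size; ++i) mean += x[i];
  mean /= axis_size;
  float var = 0.0f;
  for (int i = 0; i < axis_size; ++i) {
    float t = x[i] - mean;
    var += t * t;
  }
  var /= axis_size;
  float normalizer = 1.0f / std::sqrt(var + eps);
  for (int i = 0; i < axis_size; ++i) {
    out[i] = w[i] * ((x[i] - mean) * normalizer) + b[i];
  }
}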
@@ -148,11 +144,9 @@ __global__ void layer_norm_vjp(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
@@ -162,28 +156,19 @@ __global__ void layer_norm_vjp(
// Normalizer.
float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -195,10 +180,12 @@ __global__ void layer_norm_vjp(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
@@ -208,9 +195,9 @@ __global__ void layer_norm_vjp(
wn[i] = gi * xi;
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
@@ -271,9 +258,9 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr uint32_t N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
@@ -379,10 +366,10 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr int N_READS = 4;
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm_vjp<
DataType,
has_w_constant.value,

View File

@@ -43,19 +43,20 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
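The loop above is the online-normalizer trick from the linked reference; a hedged scalar restatement:
// Scalar sketch (illustrative): when a new running maximum appears, the
// accumulated sum of exponentials is rescaled rather than recomputed.
#include <cmath>

float logsumexp_online(const float* vals, int n) {
  float maxval = -INFINITY;
  float normalizer = 0.0f;
  for (int i = 0; i < n; ++i) {
    float prevmax = maxval;
    maxval = std::fmax(maxval, vals[i]);
    normalizer = normalizer * std::exp(prevmax - maxval) +
                 std::exp(vals[i] - maxval);
  }
  return maxval + std::log(normalizer);
}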
@@ -142,9 +143,9 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr int N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,

View File

@@ -1,159 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
namespace mlx::core {
template <
typename K,
typename V,
template <typename...> typename M = std::unordered_map>
class LRUCache {
public:
using value_type = std::pair<K, V>;
using list_type = std::list<value_type>;
using iterator = typename list_type::iterator;
using const_iterator = typename list_type::const_iterator;
using map_type = M<K, iterator>;
explicit LRUCache(size_t capacity) : capacity_(capacity) {
if (capacity == 0) {
throw std::runtime_error("LRUCache requires capacity > 0.");
}
}
size_t size() const {
return map_.size();
}
size_t capacity() const {
return capacity_;
}
bool empty() const {
return vlist_.empty();
}
void resize(size_t new_capacity) {
capacity_ = new_capacity;
trim();
}
iterator begin() {
return vlist_.begin();
}
const_iterator begin() const {
return vlist_.begin();
}
iterator end() {
return vlist_.end();
}
const_iterator end() const {
return vlist_.end();
}
void clear() {
map_.clear();
vlist_.clear();
}
iterator find(const K& key) {
auto it = map_.find(key);
if (it == map_.end())
return end();
vlist_.splice(vlist_.begin(), vlist_, it->second);
return it->second;
}
template <typename U>
std::pair<iterator, bool> emplace(const K& key, U&& value) {
auto it = map_.find(key);
if (it != map_.end()) {
vlist_.splice(vlist_.begin(), vlist_, it->second);
return {it->second, false};
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
trim();
return {vlist_.begin(), true};
}
iterator erase(iterator pos) {
map_.erase(pos->first);
return vlist_.erase(pos);
}
V& operator[](const K& key) {
auto it = find(key);
if (it == end()) {
it = emplace(key, V{}).first;
}
return it->second;
}
private:
void trim() {
while (map_.size() > capacity_) {
auto last = std::prev(vlist_.end());
map_.erase(last->first);
vlist_.pop_back();
}
}
list_type vlist_;
map_type map_;
size_t capacity_;
};
// Turn a POD struct into a container key by comparing its raw bytes.
template <typename T>
struct BytesKey {
T pod;
static_assert(std::is_standard_layout_v<T>, "T is not POD");
BytesKey(T pod) : pod(std::move(pod)) {}
BytesKey(const BytesKey& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
BytesKey(BytesKey&& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
bool operator==(const BytesKey& other) const {
auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
return memcmp(ptr1, ptr2, sizeof(T)) == 0;
}
};
// Compute a hash from the raw bytes of T.
template <typename T>
struct BytesHash {
static_assert(std::is_standard_layout_v<T>, "T is not POD");
size_t operator()(const T& pod) const {
auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < sizeof(T); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return value;
}
};
template <typename K, typename V>
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
template <typename K, typename V>
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
} // namespace mlx::core
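A brief, hedged usage sketch of the cache defined above (keys and values are arbitrary):
// Illustrative: entries move to the front on access; once size exceeds
// capacity the least recently used entry is dropped by trim().
#include <string>

void lru_example() {
  mlx::core::LRUCache<int, std::string> cache(2);
  cache.emplace(1, "a");
  cache.emplace(2, "b");
  cache.find(1);          // touch key 1; key 2 becomes least recently used
  cache.emplace(3, "c");  // evicts key 2
  bool evicted = (cache.find(2) == cache.end());  // true
  (void)evicted;
}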

View File

@@ -2,15 +2,289 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
class MatMul {
public:
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()), pref_(cublas_preference(device)) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: MatMul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
~MatMul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void run(
cu::CommandEncoder& encoder,
void* out,
void* a,
void* b,
void* c = nullptr,
float alpha = 1,
float beta = 0) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
handle_,
matmul_desc_,
a_desc_,
b_desc_,
out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
encoder.stream()));
}
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order =
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace cu
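A hedged sketch of driving the wrapper above for a plain float32 GEMM; the shapes, leading dimensions, and zero batch strides are illustrative. cublasLt evaluates alpha * (A @ B) + beta * C, with C aliased to the output when no bias matrix is passed.
// Illustrative only: out[M, N] = a[M, K] @ b[K, N], row-major, unbatched.
void run_plain_gemm(cu::CommandEncoder& enc, cu::Device& dev,
                    void* out, void* a, void* b, int M, int N, int K) {
  cu::MatMul mm(
      dev, float32,
      /* a_transposed */ false, M, K, /* lda */ K,
      /* b_transposed */ false, K, N, /* ldb */ N,
      /* batch_count */ 1,
      /* a_batch_stride */ 0, /* b_batch_stride */ 0);
  mm.run(enc, out, a, b);  // alpha = 1, beta = 0
}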
namespace {
std::tuple<bool, int64_t, array>
@@ -79,25 +353,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
cu::MatMul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -112,13 +371,27 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc);
a_it.step();
b_it.step();
}
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +459,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
cu::MatMul matmul(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -203,22 +476,41 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(),
c_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
matmul.run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha_,
beta_);
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha_,
beta_);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core

View File

@@ -2,17 +2,10 @@
#pragma once
#include "mlx/backend/cuda/steel/defines.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
#include "mlx/backend/cuda/matmul/tiles.cuh"
namespace mlx::core::cu {
/**
* Fallback mma.
*
 * We should probably either a) implement a real fallback or b) turn this
 * into a compile-time error.
*/
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
@@ -23,11 +16,10 @@ mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
*
* We actually perform C += A @ B.T
*/
__device__ __forceinline__ void mma_t(
__device__ inline void mma_t(
Tile16x16<float>& C,
Tile16x16<__nv_bfloat16>& A,
Tile16x16<__nv_bfloat16>& B) {
#if defined(MLX_CUDA_SM_80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3}, "
@@ -84,14 +76,13 @@ __device__ __forceinline__ void mma_t(
"f"(C.values[2].y),
"f"(C.values[3].x),
"f"(C.values[3].y));
#endif
}
/**
* Multiply larger register tiles by delegating to mma_t.
*/
template <typename U, typename T, int M, int N, int K>
__device__ __forceinline__ void mma_t(
__device__ inline void mma_t(
RegisterTile<U, M, N>& C,
RegisterTile<T, M, K>& A,
RegisterTile<T, N, K>& B) {

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/backend/cuda/steel/utils.cuh"
#define MLX_UNROLL _Pragma("unroll")
namespace mlx::core::cu {
@@ -67,7 +67,7 @@ struct Tile16x16 {
* See
* https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
*/
__device__ __forceinline__ void load(uint32_t row_address) {
__device__ inline void load(uint32_t row_address) {
if constexpr (
std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
asm volatile(
@@ -169,7 +169,7 @@ struct RegisterTile {
}
template <typename Tile>
__device__ __forceinline__ void
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
@@ -181,22 +181,6 @@ struct RegisterTile {
}
}
template <typename Tile, typename F>
__device__ __forceinline__ void
load(Tile& tile, F f, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
f(data[i * TILES_X + j],
tile,
base_address,
row + i * 16,
col + j * 16);
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
@@ -223,57 +207,6 @@ struct RegisterTile {
}
};
/**
* A simple container of multiple Tile16x16.
*
* Provides utility functions for loading and manipulating collections of basic
* tiles.
*/
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
static constexpr int ROWS = ROWS_;
static constexpr int COLS = COLS_;
static constexpr int TILES_X = COLS / 16;
static constexpr int TILES_Y = ROWS / 16;
Tile16x16<T> data[TILES_X * TILES_Y];
__device__ inline void fill(T v) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].fill(v);
}
}
}
template <typename Tile>
__device__ inline void
load(Tile& tile, uint32_t base_address, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].load(
tile.loc(base_address, row + i * 16, col + j * 16));
}
}
}
template <typename U>
__device__ inline void store_global(U* x, int N, int row, int col) {
MLX_UNROLL
for (int i = 0; i < TILES_Y; i++) {
MLX_UNROLL
for (int j = 0; j < TILES_X; j++) {
data[i * TILES_X + j].store_global(
x + (row + i * 16) * N + col + j * 16, N);
}
}
}
};
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
static constexpr int ROWS = ROWS_;
@@ -282,8 +215,7 @@ struct SharedTile {
static constexpr int TILES_Y = ROWS / 16;
static constexpr int NUMEL = ROWS * COLS;
// Swizzle taken from ThunderKittens. Should be changed when we switch to
// cute Layouts.
// Swizzle taken from ThunderKittens.
//
// See includes/types/shared/st.cuh
//
@@ -296,10 +228,6 @@ struct SharedTile {
T data[ROWS * COLS];
__device__ inline uint32_t base_addr() const {
return __cvta_generic_to_shared(&data[0]);
}
// Return a pointer to the element at (row, col) using the swizzle.
__device__ static inline T* ptr(T* ptr, int row, int col) {
if constexpr (swizzle_bytes > 0) {
@@ -367,8 +295,6 @@ struct SharedTile {
/**
* Load the tile from global memory by loading 16 bytes at a time and storing
* them immediately.
*
* Can also be used as a fallback for architectures before sm_80.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
@@ -392,6 +318,33 @@ __device__ inline void load(Tile& tile, const T* x, int N) {
}
}
/**
 * Copy 16 bytes from the global memory address pointed to by x to the smem
* address pointed to by row_address.
*
* A simple wrapper over the PTX.
*/
template <typename T>
__device__ inline void cp_async_16(uint32_t row_address, const T* x) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
}
/**
 * Commit all previously issued async copies as a single group.
*/
__device__ inline void cp_async_commit() {
asm volatile("cp.async.commit_group;\n" ::);
}
/**
* Wait for all the async copies to finish.
*/
__device__ inline void cp_async_wait_all() {
asm volatile("cp.async.wait_all;\n" ::);
}
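A hedged device-side sketch of the sequence these three helpers are used in; a SharedTile as defined above and a global row pointer are assumed.
// Illustrative only: stage one row of a SharedTile from global memory using
// the issue / commit / wait sequence, then make it visible to the block.
template <typename T, typename Tile>
__device__ inline void stage_row(Tile& tile, uint32_t base_address,
                                 const T* gmem_row, int row) {
  cp_async_16(tile.loc(base_address, row, 0), gmem_row);  // 16-byte copy
  cp_async_commit();    // close the group of issued copies
  cp_async_wait_all();  // wait for the copies to land in shared memory
  __syncthreads();      // make the data visible to every thread in the block
}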
/**
* The asynchronous equivalent of load.
*
@@ -425,17 +378,12 @@ load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
cp_async<16>(
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
}
}
/**
* Same as load_async but checks if we can load the row.
*
* NOTE: It should be changed to use a predicated cp async instead.
*/
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
Tile& tile,
@@ -458,7 +406,7 @@ __device__ inline void load_async_safe(
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
if (row + i * STEP_ROWS < max_rows) {
cp_async<16>(
cp_async_16(
tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
x + i * STEP_ROWS * N);
} else {

View File

@@ -1,54 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -0,0 +1,103 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -256,7 +256,7 @@ void affine_quantize(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
@@ -312,7 +312,7 @@ void affine_dequantize(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,

View File

@@ -0,0 +1,228 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/matmul/mma.cuh"
#include "mlx/backend/cuda/matmul/tiles.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
namespace mlx::core {
namespace cu {
template <int NUM_WARPS, int group_size, int bits, typename T, typename Tile>
__device__ inline void load_quantized(
Tile& tile,
const uint8_t* x,
const T* scales,
const T* biases,
int N) {
constexpr int NUM_THREADS = NUM_WARPS * 32;
constexpr int ELEMENTS_PER_LOAD = sizeof(uint32_t) * get_pack_factor<bits>();
constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;
constexpr int MASK = (1 << bits) - 1;
const int row = threadIdx.x / NUM_LOADS_PER_ROW;
const int col = threadIdx.x % NUM_LOADS_PER_ROW;
const int Nx = N / get_pack_factor<bits>();
const int Ng = N / group_size;
x += row * Nx + col * (ELEMENTS_PER_LOAD / get_pack_factor<bits>());
scales += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
biases += row * Ng + col * ELEMENTS_PER_LOAD / group_size;
MLX_UNROLL
for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
T vs[ELEMENTS_PER_LOAD];
uint32_t w = *reinterpret_cast<const uint32_t*>(x + i * STEP_ROWS * Nx);
T s = scales[i * STEP_ROWS * Ng];
T b = biases[i * STEP_ROWS * Ng];
MLX_UNROLL
for (int j = 0; j < ELEMENTS_PER_LOAD; j++) {
vs[j] = static_cast<T>((w >> (j * bits)) & MASK) * s + b;
}
tile.store(vs, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
}
}
template <
typename T,
int BM,
int BN,
int BK,
int group_size,
int bits,
bool aligned_M>
__global__ void qmm_t(
const T* x,
const uint8_t* w,
const T* scales,
const T* biases,
T* y,
int M,
int N,
int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 4;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&xs)[1] = *(SharedTile<T, BM, BK>(*)[1])(&shmem[0]);
SharedTile<T, BN, BK>(&ws)[1] =
*(SharedTile<T, BN, BK>(*)[1])(&shmem[1 * sizeof(T) * BM * BK]);
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
const int max_rows = M - blockIdx.y * BM;
x += blockIdx.y * BM * K;
w += blockIdx.x * BN * K / get_pack_factor<bits>();
scales += blockIdx.x * BN * K / group_size;
biases += blockIdx.x * BN * K / group_size;
y += blockIdx.y * BM * N + blockIdx.x * BN;
C.fill(0);
int tic = 0;
uint32_t base_addr_xs[1], base_addr_ws[1];
base_addr_xs[0] = __cvta_generic_to_shared(&xs[0].data[0]);
base_addr_ws[0] = __cvta_generic_to_shared(&ws[0].data[0]);
if (aligned_M || max_rows >= BM) {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(xs[tic], base_addr_xs[tic], x + k_block, K);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global(y, N, offset_m, offset_n);
} else {
for (int k_block = 0; k_block < K; k_block += BK) {
load_async_safe<NUM_WARPS>(
xs[tic], base_addr_xs[tic], x + k_block, K, max_rows);
cp_async_commit();
load_quantized<NUM_WARPS, group_size, bits>(
ws[tic],
w + k_block / get_pack_factor<bits>(),
scales + k_block / group_size,
biases + k_block / group_size,
K);
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
xs[tic],
base_addr_xs[tic],
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
ws[tic],
base_addr_ws[tic],
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
}
C.store_global_safe(y, N, offset_m, offset_n, max_rows);
}
}
} // namespace cu
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.dtype() != bfloat16) {
throw std::invalid_argument("[qmm] Only bfloat16 is supported for now");
}
if (!transpose_) {
throw std::invalid_argument(
"[qmm] Only transposed matmul is supported for now");
}
dispatch_float_types(x.dtype(), "qmm", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
auto kernel =
cu::qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, true>;
if (M % BM != 0) {
kernel = cu::
qmm_t<DataType, BM, BN, BK, group_size.value, bits.value, false>;
}
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
enc.add_kernel_node(
kernel,
grid,
2 * 4 * 32,
1 * sizeof(DataType) * (BM * BK + BN * BK),
x.data<DataType>(),
w.data<uint8_t>(),
scales.data<DataType>(),
biases.data<DataType>(),
out.data<DataType>(),
M,
N,
K);
});
});
});
}
} // namespace mlx::core
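Two hedged back-of-the-envelope checks for the kernel above: the dequantization that load_quantized performs per packed word, and the dynamic shared memory requested by the launch (2 bytes per bfloat16, BM = BN = 128, BK = 32, so 2 * (128*32 + 128*32) = 16 KiB).
// Illustrative: unpack one uint32_t of 4-bit weights the way load_quantized
// does, i.e. out[j] = ((packed >> (j * bits)) & mask) * scale + bias.
#include <cstdint>

void dequant_u32_4bit(uint32_t packed, float scale, float bias, float out[8]) {
  constexpr int bits = 4;
  constexpr uint32_t mask = (1u << bits) - 1;  // 0xF
  for (int j = 0; j < 8; ++j) {                // 32 / 4 = 8 values per word
    out[j] = float((packed >> (j * bits)) & mask) * scale + bias;
  }
}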

View File

@@ -1,10 +1,14 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
@@ -28,28 +32,57 @@ inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
} // namespace
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
out.set_data(allocator::malloc(out.nbytes()));
// Make sure the last two dims of x, w, scales, and biases are contiguous.
// This requirement should be relaxed for x.
array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
array biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
int K = x.shape(-1);
int M = non_batched ? x.size() / K : x.shape(-2);
int N = out.shape(-1);
qmm(x,
w,
scales,
biases,
out,
transpose_,
group_size_,
bits_,
M,
N,
K,
enc,
s);
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);

View File

@@ -24,4 +24,19 @@ void affine_dequantize(
cu::CommandEncoder& enc,
const Stream& s);
void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
array& out,
bool transpose_,
int group_size_,
int bits_,
int M,
int N,
int K,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -5,6 +5,8 @@
#include "mlx/backend/gpu/copy.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <cassert>

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
@@ -10,6 +11,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -55,7 +58,7 @@ __global__ void rms_norm(
const T* w,
T* out,
float eps,
uint32_t axis_size,
int32_t axis_size,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
@@ -70,8 +73,8 @@ __global__ void rms_norm(
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]);
normalizer += t * t;
@@ -83,14 +86,15 @@ __global__ void rms_norm(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
T xn[N_READS];
T wn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float y = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(y);
float norm = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm);
}
store_vector<N_READS>(out, index, xn, axis_size);
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
@@ -122,10 +126,13 @@ __global__ void rms_norm_vjp(
// Normalizer.
float2 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]);
float wi = wn[i];
@@ -142,9 +149,12 @@ __global__ void rms_norm_vjp(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = xn[i];
float wi = wn[i];
@@ -154,9 +164,9 @@ __global__ void rms_norm_vjp(
wn[i] = static_cast<T>(gi * xi * normalizer);
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
@@ -214,9 +224,9 @@ void RMSNorm::eval_gpu(
encoder.set_input_array(w);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr uint32_t N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
@@ -304,10 +314,11 @@ void RMSNormVJP::eval_gpu(
encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr int N_READS = 4;
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
auto kernel = cu::rms_norm_vjp<
DataType,
has_w_constant.value,

View File

@@ -1,781 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
#define PRAGMA_LOOP_UNROLL _Pragma("unroll")
struct AttnParams {
int B;
int H;
int D;
int qL;
int kL;
int gqa_factor;
float scale;
int64_t Q_strides[3];
int64_t K_strides[3];
int64_t V_strides[3];
int64_t O_strides[3];
};
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_1pass(
const T* Q,
const T* K,
const T* V,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = BN * int(params.K_strides[2]);
const int inner_v_stride = BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
// Read the query and zero the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -INFINITY;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = max_scores[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
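A minimal scalar sketch (not part of the diff) of the streaming-softmax update this 1-pass kernel applies per key; names mirror the kernel above and the base-2 exponent trick (the scale is pre-multiplied by log2(e)) is kept:
// s: the new dot-product score, already scaled by params.scale * log2(e).
float new_max = fmaxf(max_score, s);
float factor = exp2f(max_score - new_max); // rescale the old accumulators
float exp_s = exp2f(s - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_s;
for (int j = 0; j < v_per_thread; ++j) {
  o[j] = o[j] * factor + exp_s * v[j]; // v[j]: this key's value slice
}
// After all keys: o[j] *= 1.f / sum_exp_score.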
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
float* partials,
float* sums,
float* maxs,
__grid_constant__ const AttnParams params) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
typedef float U;
U q[v_per_thread];
U k[v_per_thread];
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z / blocks;
const int block_idx = blockIdx.z % blocks;
const int head_idx = blockIdx.x;
const int kv_head_idx = head_idx / params.gqa_factor;
const int q_seq_idx = blockIdx.y;
const int kv_seq_idx = block_idx * BN + warp_idx;
Q += batch_idx * params.Q_strides[0] + // Batch
head_idx * params.Q_strides[1] + // Head
q_seq_idx * params.Q_strides[2]; // Sequence
K += batch_idx * params.K_strides[0] + // Batch
kv_head_idx * params.K_strides[1] + // Head
kv_seq_idx * params.K_strides[2]; // Sequence
V += batch_idx * params.V_strides[0] + // Batch
kv_head_idx * params.V_strides[1] + // Head
kv_seq_idx * params.V_strides[2]; // Sequence
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s + // Sequence
block_idx; // Block
partials += p_offset * D;
sums += p_offset;
maxs += p_offset;
// Read the query and zero the output accumulator
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
}
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0.f;
}
U max_score = -1e9;
U sum_exp_score = 0.f;
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
bool use_key = true;
if constexpr (do_causal) {
use_key = i <= (params.kL - params.qL + q_seq_idx);
}
if (use_key) {
// Read the key
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
k[j] = K[v_per_thread * lane_idx + j];
}
// Compute the i-th score
U score = 0.f;
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
score += q[j] * k[j];
}
// Warp sum
score = cg::reduce(warp, score, cg::plus<U>());
// Update the accumulators
U new_max = max(max_score, score);
U factor = exp2f(max_score - new_max);
U exp_score = exp2f(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
PRAGMA_LOOP_UNROLL
for (int j = 0; j < v_per_thread; j++) {
o[j] = o[j] * factor +
exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
}
}
// Move the pointers to the next kv
K += inner_k_stride;
V += inner_v_stride;
}
if (lane_idx == 0) {
max_scores[warp_idx] = max_score;
sum_exp_scores[warp_idx] = sum_exp_score;
}
block.sync();
max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
// Write the sum and new max
if (warp_idx == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
auto ff = exp2f(max_scores[warp_idx] - new_max);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[warp_idx][lane_idx] = o[i] * ff;
block.sync();
if (warp_idx == 0) {
U ot = outputs[0][lane_idx];
PRAGMA_LOOP_UNROLL
for (int j = 1; j < BN; j++) {
ot += outputs[j][lane_idx];
warp.sync();
}
o[i] = ot;
}
block.sync();
}
if (warp_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
partials[v_per_thread * lane_idx + i] = o[i];
}
}
}
template <typename T, bool do_causal, int D>
__global__ void kernel_sdpav_2pass_2(
const float* partials,
const float* sums,
const float* maxs,
T* O,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int blocks = 32;
constexpr int v_per_thread = D / BD;
typedef float U;
U o[v_per_thread];
__shared__ U outputs[BN][BD + 1];
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
const int lane_idx = warp.thread_rank();
const int warp_idx = warp.meta_group_rank();
// Adjust to thread block and thread
const int batch_idx = blockIdx.z;
const int head_idx = blockIdx.x;
const int q_seq_idx = blockIdx.y;
const int p_stride_s = blocks;
const int p_stride_h = params.qL * p_stride_s;
const int p_stride_b = params.H * p_stride_h;
const int p_offset = batch_idx * p_stride_b + // Batch
head_idx * p_stride_h + // Head
q_seq_idx * p_stride_s; // Sequence
partials += p_offset * D + warp_idx * D;
sums += p_offset;
maxs += p_offset;
O += batch_idx * params.O_strides[0] + // Batch
head_idx * params.O_strides[1] + // Head
q_seq_idx * params.O_strides[2]; // Sequence
U max_score = maxs[lane_idx];
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
o[i] = partials[v_per_thread * lane_idx + i];
}
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
outputs[lane_idx][warp_idx] = o[i];
block.sync();
U ot = outputs[warp_idx][lane_idx] * factor;
o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
block.sync();
}
// And write the output
if (lane_idx == 0) {
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
}
}
}
} // namespace cu
namespace {
template <typename F>
void dispatch_headdim(int n, F&& f) {
switch (n) {
case 64:
f(std::integral_constant<int, 64>{});
break;
case 96:
f(std::integral_constant<int, 96>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
void sdpa_vector_1pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel =
cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
params);
});
});
});
}
void sdpa_vector_2pass_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
/* int D = */ q.shape(3),
/* int qL = */ q.shape(2),
/* int kL = */ k.shape(2),
/* int gqa_factor = */ q.shape(1) / k.shape(1),
/* float scale = */ scale,
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
// Allocate the intermediates
int blocks = 32;
Shape intermediate_shape;
intermediate_shape.reserve(o.ndim() + 1);
intermediate_shape.insert(
intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
intermediate_shape.push_back(blocks);
intermediate_shape.push_back(o.shape().back());
array intermediate(intermediate_shape, float32, nullptr, {});
intermediate_shape.pop_back();
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
sums.set_data(allocator::malloc(sums.nbytes()));
maxs.set_data(allocator::malloc(maxs.nbytes()));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
{
auto kernel = cu::
kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
dim3 grid_dim(params.H, params.qL, params.B * 32);
dim3 block_dim(8 * 32, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
params);
}
{
auto kernel = cu::
kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
encoder.set_input_array(intermediate);
encoder.set_input_array(sums);
encoder.set_input_array(maxs);
encoder.set_output_array(o);
dim3 grid_dim(params.H, params.qL, params.B);
dim3 block_dim(1024, 1, 1);
encoder.add_kernel_node(
kernel,
grid_dim,
block_dim,
0,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
o.data<DataType>(),
params);
}
});
});
});
}
void sdpa_vector_fallback(
const Stream& s,
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
}
}
} // namespace
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
const int key_sequence_length = k.shape(2);
const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
return has_arr_mask || !supported_config;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
array arr_copy = contiguous_copy_gpu(arr, s);
copies.push_back(std::move(arr_copy));
return copies.back();
} else {
return arr;
}
};
// We are in vector mode, i.e. a single query
if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
if (arr.flags().row_contiguous) {
return true;
}
auto& strides = arr.strides();
auto& shape = arr.shape();
if (shape[0] == 1 || shape[1] == 1) {
// If either the batch or head dimension is a singleton, the other can
// be transposed with the sequence dimension
auto bidx = shape[0] == 1 ? 1 : 0;
return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
(strides[bidx] == shape[3]);
}
return false;
};
auto kv_copy_unless = [](const array& arr) {
// keys and values should be copied if:
// - the last dimension is not contiguous
// - the batch and head dim are not contiguous
auto& strides = arr.strides();
auto& shape = arr.shape();
if (strides.back() != 1) {
return false;
}
if (shape[0] == 1 || shape[1] == 1) {
return true;
}
return (strides[0] == strides[1] * shape[1]);
};
const auto& q = copy_unless(q_copy_unless, q_pre);
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
} else {
int64_t str_oD = 1;
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
}
// Full attention mode should never reach here
else {
throw std::runtime_error("Doesn't support matrix yet.");
}
}
} // namespace fast
} // namespace mlx::core

View File

@@ -11,6 +11,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cassert>
@@ -44,21 +45,20 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = cast_to<AccT>(0);
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
#pragma unroll
for (int i = 0; i < N_READS; i++) {
normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
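A brief scalar sketch (not from the diff) of the online-normalizer recurrence linked above: whenever the running max grows, the running normalizer is rescaled before the new term is added.
float m = -INFINITY, d = 0.f;
for (int i = 0; i < axis_size; ++i) {
  float xi = static_cast<float>(in[i]);
  float m_new = fmaxf(m, xi);
  d = d * softmax_exp(m - m_new) + softmax_exp(xi - m_new);
  m = m_new;
}
// out[i] = softmax_exp(in[i] - m) / d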
@@ -95,11 +95,12 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
// Write output.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto vals = load_vector<N_READS>(in, index, axis_size, T(0));
T vals[N_READS];
cub::LoadDirectBlocked(index, in, vals, axis_size);
for (int i = 0; i < N_READS; i++) {
vals[i] = softmax_exp(static_cast<AccT>(vals[i]) - maxval) * normalizer;
}
store_vector<N_READS>(out, index, vals, axis_size);
cub::StoreDirectBlocked(index, out, vals, axis_size);
}
}
@@ -140,9 +141,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
constexpr int N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
if (precise) {
kernel = cu::softmax<DataType, float, block_dim(), N_READS>;

View File

@@ -13,6 +13,7 @@
#include <cub/device/device_segmented_sort.cuh>
#include <cassert>
#include <numeric>
namespace mlx::core {
@@ -26,6 +27,29 @@ struct ModOp {
}
};
// We cannot use ops in eval, so make a utility instead.
array swapaxes_in_eval(const array& in, int axis1, int axis2) {
std::vector<int> axes(in.ndim());
std::iota(axes.begin(), axes.end(), 0);
std::swap(axes[axis1], axes[axis2]);
// TODO: Share the code with Transpose::eval.
Shape shape(axes.size());
Strides strides(in.ndim());
for (size_t ax = 0; ax < axes.size(); ++ax) {
shape[ax] = in.shape()[axes[ax]];
strides[ax] = in.strides()[axes[ax]];
}
auto flags = in.flags();
if (flags.contiguous) {
auto [_, row_contiguous, col_contiguous] = check_contiguity(shape, strides);
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
}
array out(shape, in.dtype(), nullptr, {});
out.copy_shared_buffer(in, strides, flags, in.data_size());
return out;
}
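// Hypothetical illustration (not in the diff): for a row-major 2x3 input with
//   shape = {2, 3}, strides = {3, 1},
// swapaxes_in_eval(in, 0, 1) returns a view sharing the same buffer with
//   shape = {3, 2}, strides = {1, 3};
// only the metadata changes, no data is moved.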
struct OffsetTransform {
int nsort;

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#define MLX_UNROLL _Pragma("unroll")
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#define MLX_CUDA_SM_80_ENABLED
#endif

View File

@@ -1,101 +0,0 @@
#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"
namespace mlx::core::cu {
/**
* An example gemm written with the utils.
*
* Computes A @ B.T when A and B are both aligned with the block sizes.
*/
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
constexpr int WARPS_M = 2;
constexpr int WARPS_N = 2;
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
constexpr int WARP_STEP_M = BM / WARPS_M;
constexpr int WARP_STEP_N = BN / WARPS_N;
// Precompute some offsets for each thread
const int warpid = threadIdx.x / 32;
const int laneid = threadIdx.x % 32;
const int wm = warpid / WARPS_N;
const int wn = warpid % WARPS_N;
const int offset_m = wm * WARP_STEP_M;
const int offset_n = wn * WARP_STEP_N;
// Allocate shared memory
extern __shared__ char shmem[];
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
SharedTile<T, BN, BK>(&bs)[2] =
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
// Allocate registers for the MMA
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
RegisterTile<T, BM / WARPS_M, 16> A;
RegisterTile<T, BN / WARPS_N, 16> B;
// Move the global pointers to the tile
a += blockIdx.y * BM * K;
b += blockIdx.x * BN * K;
y += blockIdx.y * BM * N + blockIdx.x * BN;
// Zero the accumulators
C.fill(0);
// Start the SM pipeline
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
cp_async_commit();
int tic = 0;
for (int k_block = BK; k_block < K; k_block += BK) {
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
cp_async_commit();
cp_async_wait<1>();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
tic ^= 1;
}
// Empty the pipeline
cp_async_wait_all();
__syncthreads();
MLX_UNROLL
for (int k = 0; k < BK / 16; k++) {
A.load(
as[tic],
as[tic].base_addr(),
offset_m + laneid % 16,
k * 16 + laneid / 16 * 8);
B.load(
bs[tic],
bs[tic].base_addr(),
offset_n + laneid % 16,
k * 16 + laneid / 16 * 8);
mma_t(C, A, B);
}
C.store_global(y, N, offset_m, offset_n);
}
} // namespace mlx::core::cu
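A hedged host-side launch sketch (not part of this file): it sizes the dynamic shared memory for the two double-buffered A and B tiles, assuming SharedTile<T, R, C> stores exactly R*C elements; the tile sizes and the opt-in attribute call are illustrative choices.
using T = __half;
constexpr int BM = 128, BN = 128, BK = 32;
size_t smem_bytes = 2 * sizeof(T) * (BM * BK + BN * BK); // 2 stages of A and B tiles
auto kernel = ab_t_aligned<T, BM, BN, BK>;
// Only required when smem_bytes exceeds the default 48KB per-block limit.
cudaFuncSetAttribute(
    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_bytes));
dim3 grid(N / BN, M / BM); // assumes M and N are multiples of the tile sizes
kernel<<<grid, 4 * 32, smem_bytes>>>(a, b, y, N, K); // 4 warps per block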

View File

@@ -1,89 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/steel/defines.cuh"
namespace mlx::core::cu {
/**
* Copy bytes from the global memory address pointed to by x to the smem
* address pointed to by row_address.
*
* A simple wrapper over the PTX.
*/
template <int N, typename T>
__device__ inline void cp_async(uint32_t row_address, const T* x) {
static_assert(
N == 16 || N == 8 || N == 4,
"cp.async is only supported for N in {4, 8, 16}.");
#if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 16) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int4*>(x)));
} else if constexpr (N == 8) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int2*>(x)));
} else if constexpr (N == 4) {
asm volatile(
"cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address),
"l"(reinterpret_cast<const int*>(x)));
}
#endif
}
/**
* Submit all the previous async copies to be executed.
*/
__device__ inline void cp_async_commit() {
#if defined(MLX_CUDA_SM_80_ENABLED)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
/**
* Wait for all but N of the async copies to finish.
*/
template <int N>
__device__ inline void cp_async_wait() {
#if defined(MLX_CUDA_SM_80_ENABLED)
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
#endif
}
/**
* Wait for all the async copies to finish.
*/
__device__ inline void cp_async_wait_all() {
cp_async_wait<0>();
}
/**
* Extract ``bits`` bits from the 32 bit value.
*
* Single instruction shift and mask.
*/
template <int bits>
__device__ inline uint32_t extract_bits(uint32_t value, int start_bit) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"extract_bits only supports 2, 4, 8 for now.");
uint32_t result;
if constexpr (bits == 2) {
asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit));
} else if constexpr (bits == 4) {
asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit));
} else if constexpr (bits == 8) {
asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit));
}
return result;
}
} // namespace mlx::core::cu
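A small device-side example (not from the diff) of using extract_bits to unpack eight 4-bit quantized values from one 32-bit word; the affine scale/bias step is an assumption about the surrounding dequantization code:
uint32_t w_packed = /* one word of packed 4-bit weights */ 0u;
float scale = 1.f, bias = 0.f; // placeholders for this group's scale and bias
float vals[8];
#pragma unroll
for (int i = 0; i < 8; ++i) {
  uint32_t q = extract_bits<4>(w_packed, 4 * i);
  vals[i] = scale * static_cast<float>(q) + bias;
}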

View File

@@ -32,7 +32,7 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
AlignedVector<T, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]);
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i], c_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -76,7 +76,7 @@ __global__ void ternary_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx, c_idx] = elem_to_loc(
auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
index,
shape.data(),
a_strides.data(),
@@ -125,9 +125,12 @@ void ternary_op_gpu_inplace(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel =
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
kernel,
num_blocks,
block_dims,
0,
@@ -142,9 +145,11 @@ void ternary_op_gpu_inplace(
const_param<dims_constant()>(c_strides));
});
} else {
auto [num_blocks, block_dims] = get_launch_args(out, large());
auto kernel = cu::ternary_g<Op, DType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
cu::ternary_g<Op, DType, IdxT>,
kernel,
num_blocks,
block_dims,
0,
@@ -163,11 +168,18 @@ void ternary_op_gpu_inplace(
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
kernel,
num_blocks,
block_dims,
0,
@@ -195,7 +207,7 @@ void ternary_op_gpu(
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Select::eval_gpu");
nvtx3::scoped_range r("select::eval_gpu");
auto& s = out.primitive().stream();
ternary_op_gpu<cu::Select>(inputs, out, s);
}

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -30,7 +31,7 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
out_vec.val[i] = Op{}(in_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -47,7 +48,7 @@ __global__ void unary_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim);
auto idx = elem_to_loc_4d(index, shape.data(), strides.data(), ndim);
out[index] = Op{}(in[idx]);
}
}
@@ -129,10 +130,16 @@ void unary_op_gpu_inplace(
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
kernel,
out.data_size(),
out.shape(),
out.strides(),
large,
N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
kernel,
num_blocks,
block_dims,
0,
@@ -142,9 +149,10 @@ void unary_op_gpu_inplace(
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto [num_blocks, block_dims] = get_launch_args(out, large);
auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(
cu::unary_g<Op, InType, OutType, IdxT>,
kernel,
num_blocks,
block_dims,
0,

View File

@@ -17,35 +17,6 @@ CudaStream::~CudaStream() {
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
}
CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
other.handle_ = nullptr;
};
CudaGraphExec::~CudaGraphExec() {
reset();
}
void CudaGraphExec::instantiate(cudaGraph_t graph) {
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
}
void CudaGraphExec::reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
handle_ = nullptr;
}
}
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
void check_cuda_error(const char* name, cudaError_t err) {
if (err != cudaSuccess) {
throw std::runtime_error(

View File

@@ -4,7 +4,6 @@
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
@@ -33,34 +32,11 @@ class CudaStream {
cudaStream_t stream_;
};
// Move-able RAII handle of cudaGraphExec_t.
class CudaGraphExec {
public:
CudaGraphExec(cudaGraphExec_t handle = nullptr);
CudaGraphExec(CudaGraphExec&& other);
~CudaGraphExec();
CudaGraphExec(const CudaGraphExec&) = delete;
CudaGraphExec& operator=(const CudaGraphExec&) = delete;
void instantiate(cudaGraph_t graph);
void reset();
operator cudaGraphExec_t() const {
return handle_;
}
private:
cudaGraphExec_t handle_;
};
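A hypothetical usage sketch of this RAII wrapper (not taken from the diff); cudaGraphLaunch receives the raw handle through the implicit conversion operator:
cudaGraph_t graph = /* previously captured graph */ nullptr;
CudaGraphExec exec;
exec.instantiate(graph);
CHECK_CUDA_ERROR(cudaGraphLaunch(exec, stream));
// The executable graph is destroyed when exec goes out of scope.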
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Convert Dtype to CUDA C++ types.

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
@@ -11,10 +12,10 @@ Worker::Worker()
Worker::~Worker() {
{
std::lock_guard lock(mtx_);
std::lock_guard lock(worker_mutex_);
stop_ = true;
}
cond_.notify_one();
worker_event_.signal(batch_ + 1);
worker_.join();
}
@@ -22,41 +23,53 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task));
}
void Worker::signal(void* data) {
auto w = static_cast<Worker*>(data);
{
std::lock_guard lock(w->mtx_);
w->signaled_batch_++;
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
w->cond_.notify_one();
pending_tasks_.clear();
}
void Worker::end_batch() {
batch_++;
{
std::lock_guard lock(worker_mutex_);
worker_tasks_[batch_] = std::move(pending_tasks_);
}
uncommited_batches_++;
}
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) {
// Move pending tasks into tasks
if (pending_tasks_.empty()) {
if (uncommited_batches_ == 0) {
return;
}
{
std::lock_guard lock(mtx_);
// Move pending tasks into ready tasks
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
}
uncommited_batches_ = 0;
// Signal the |worker_event_| in |signal_stream_| after the kernels in
// |stream_| finish running.
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal, this);
worker_event_.signal(signal_stream_, batch_);
}
void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) {
uint64_t current_batch = 0;
uint64_t batch = worker_event_.value();
Tasks tasks;
{
std::unique_lock<std::mutex> lk(mtx_);
cond_.wait(lk, [this, &current_batch] {
return this->signaled_batch_ > current_batch || this->stop_;
});
current_batch = signaled_batch_;
auto end = worker_tasks_.upper_bound(current_batch);
std::lock_guard lock(worker_mutex_);
// Move tasks in signaled batches.
auto end = worker_tasks_.upper_bound(batch);
for (auto it = worker_tasks_.begin(); it != end; ++it) {
if (tasks.empty()) {
tasks = std::move(it->second);
@@ -72,6 +85,7 @@ void Worker::thread_fn() {
auto task = std::move(tasks[i]);
task();
}
worker_event_.wait(batch + 1);
}
}

View File

@@ -5,7 +5,6 @@
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <condition_variable>
#include <functional>
#include <map>
#include <mutex>
@@ -25,24 +24,38 @@ class Worker {
// Add a pending |task| that will run when consumed or committed.
void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch.
void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream|
// finish running.
void commit(cudaStream_t stream);
// Return how many batches have been added but not committed yet.
size_t uncommited_batches() const {
return uncommited_batches_;
}
private:
static void signal(void*);
void thread_fn();
std::mutex mtx_;
std::condition_variable cond_;
uint64_t committed_batch_{0};
uint64_t signaled_batch_{0};
uint64_t batch_{0};
size_t uncommited_batches_{0};
// Cuda stream and event for signaling kernel completion.
CudaStream signal_stream_;
CudaEvent signal_event_;
// Worker thread.
SharedEvent worker_event_;
std::thread worker_;
std::mutex worker_mutex_;
bool stop_{false};
// Tasks are put in |pending_tasks_| first, and then moved to
@@ -50,7 +63,6 @@ class Worker {
using Tasks = std::vector<std::function<void()>>;
Tasks pending_tasks_;
std::map<uint64_t, Tasks> worker_tasks_;
std::thread worker_;
};
} // namespace mlx::core::cu
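A hedged usage sketch (not from the diff) of how an eval loop might drive this API; when exactly end_batch and commit are called is an assumption:
cu::Worker worker;
worker.add_task([] { /* e.g. release a host buffer */ });
// Option 1: run the pending tasks immediately on the calling thread.
worker.consume_in_this_thread();
// Option 2: batch them and let the worker thread run them once the kernels
// already enqueued on `stream` have finished.
worker.add_task([] { /* another deferred cleanup */ });
worker.end_batch();
worker.commit(stream);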

View File

@@ -133,7 +133,6 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Pad::eval_gpu");
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];

View File

@@ -128,7 +128,8 @@ Buffer MetalAllocator::malloc(size_t size) {
auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure try to reclaim memory from the cache
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);

View File

@@ -60,16 +60,6 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
// TODO: Code is duplicated but it should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
@@ -104,7 +94,7 @@ struct CommandEncoder {
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() {
std::unordered_set<const void*> outputs() {
return all_outputs_;
};

View File

@@ -14,10 +14,6 @@ Event::Event(Stream stream) : stream_(stream) {
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
if (event_ == nullptr) {
throw std::runtime_error(
"[Event::Event] Failed to create Metal shared event.");
}
}
void Event::wait() {

View File

@@ -32,20 +32,15 @@ inline array ensure_row_contiguous_matrix(
const array& x,
metal::Device& d,
const Stream& s) {
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
} else {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
@@ -270,15 +265,9 @@ void qvm_split_k(
MTL::Size group_dims = MTL::Size(bk, 2, 1);
MTL::Size grid_dims = MTL::Size(M, N / bn, B);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
if (x_shape.size() == 1) {
x_shape.insert(x_shape.begin(), 1);
x_strides.insert(x_strides.begin(), 0);
}
int x_ndim = x_shape.size();
int x_batch_ndims = x_ndim - 2;
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
@@ -289,7 +278,7 @@ void qvm_split_k(
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x_ndim - 1] = split_D;
x_strides[x.ndim() - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
@@ -302,9 +291,6 @@ void qvm_split_k(
int final_block_size = K - (split_k - 1) * split_D;
auto temp_shape = out.shape();
if (temp_shape.size() == 1) {
temp_shape.insert(temp_shape.begin(), 1);
}
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));

View File

@@ -6,4 +6,3 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)

View File

@@ -2,11 +2,9 @@
#include <unordered_map>
#include <iostream>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed {
@@ -82,7 +80,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available() || nccl::is_available();
return mpi::is_available() || ring::is_available();
}
int Group::rank() const {
@@ -113,8 +111,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";

View File

@@ -3,6 +3,7 @@
#pragma once
#include <memory>
#include "mlx/array.h"
namespace mlx::core::distributed {

View File

@@ -1,8 +0,0 @@
if(MLX_BUILD_CUDA)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
find_package(NCCL REQUIRED)
target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
endif()

View File

@@ -1,359 +0,0 @@
#include <arpa/inet.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
namespace mlx::core::distributed::nccl {
#define CHECK_CUDA(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
fprintf( \
stderr, \
"CUDA error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
} while (0)
#define CHECK_NCCL(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
fprintf( \
stderr, \
"NCCL error %s:%d '%s'\n", \
__FILE__, \
__LINE__, \
ncclGetErrorString(r)); \
exit(1); \
} \
} while (0)
namespace detail {
inline void sendAll(int sock, const void* buf, size_t len) {
const char* ptr = reinterpret_cast<const char*>(buf);
while (len > 0) {
ssize_t sent = send(sock, ptr, len, 0);
if (sent <= 0) {
perror("send");
exit(1);
}
ptr += sent;
len -= sent;
}
}
inline void recvAll(int sock, void* buf, size_t len) {
char* ptr = reinterpret_cast<char*>(buf);
while (len > 0) {
ssize_t rec = recv(sock, ptr, len, 0);
if (rec <= 0) {
perror("recv");
exit(1);
}
ptr += rec;
len -= rec;
}
}
inline void bootstrap_unique_id(
ncclUniqueId& id,
int rank,
int size,
const std::string& initMethod) {
// Parse the init method to extract the host and port
if (initMethod.rfind("tcp://", 0) != 0) {
throw std::invalid_argument(
"[nccl] init method must be of the form tcp://<host>:<port>");
}
auto hostport = initMethod.substr(6);
auto colon = hostport.find(':');
std::string host = hostport.substr(0, colon);
int port = std::stoi(hostport.substr(colon + 1));
if (rank == 0) {
// Create a unique id on rank 0
CHECK_NCCL(ncclGetUniqueId(&id));
// create a socket to send the unique id to all other ranks
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] Couldn't create socket (error: " << errno << ")";
throw std::runtime_error(msg.str());
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
serv.sin_addr.s_addr = htonl(INADDR_ANY);
serv.sin_port = htons(port);
int reuse = 1;
// Without this, if rank 0 crashes or restarts quickly, the OS might refuse
// to allow binding to the same port again, so enable address reuse.
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) {
std::ostringstream msg;
msg << "[nccl] setsockopt() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (bind(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) < 0) {
std::ostringstream msg;
msg << "[nccl] bind() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
if (listen(sock, size - 1) < 0) {
std::ostringstream msg;
msg << "[nccl] listen() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
for (int peer = 1; peer < size; ++peer) {
int conn = accept(sock, nullptr, nullptr);
if (conn < 0) {
std::ostringstream msg;
msg << "[nccl] accept() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
sendAll(conn, &id, sizeof(id));
close(conn);
}
close(sock);
} else {
// Give rank 0 enough time to bind by retrying the connection up to a
// maximum number of attempts.
int sock = socket(AF_INET, SOCK_STREAM, 0);
if (sock < 0) {
std::ostringstream msg;
msg << "[nccl] socket() failed: " << strerror(errno);
throw std::runtime_error(msg.str());
}
hostent* he = gethostbyname(host.c_str());
if (!he) {
throw std::runtime_error("[nccl] lookup failed for host: " + host);
}
sockaddr_in serv = {};
serv.sin_family = AF_INET;
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
serv.sin_port = htons(port);
const int max_retries = 30;
int attempt = 0;
bool connected = false;
for (attempt = 0; attempt < max_retries; ++attempt) {
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
0) {
connected = true;
std::cout << "[Rank " << rank << "] Connected successfully on attempt "
<< attempt + 1 << std::endl;
break;
}
if (errno != ECONNREFUSED) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!connected) {
std::ostringstream msg;
msg << "[Rank " << rank << "] connect() failed after " << attempt
<< " retries: " << strerror(errno);
close(sock);
throw std::runtime_error(msg.str());
}
recvAll(sock, &id, sizeof(id));
close(sock);
}
}
template <typename T>
struct type_identity {
using type = T;
};
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
switch (arr.dtype()) {
case bool_:
throw std::invalid_argument("[nccl] Boolean arrays not supported");
case int8:
f(type_identity<int8_t>{}, ncclChar);
break;
case uint8:
f(type_identity<uint8_t>{}, ncclUint8);
break;
case int32:
f(type_identity<int32_t>{}, ncclInt);
break;
case uint32:
f(type_identity<uint32_t>{}, ncclUint32);
break;
case int64:
f(type_identity<int64_t>{}, ncclInt64);
break;
case uint64:
f(type_identity<uint64_t>{}, ncclUint64);
break;
case float16:
f(type_identity<float16_t>{}, ncclHalf);
break;
case bfloat16:
f(type_identity<bfloat16_t>{}, ncclBfloat16);
break;
case float32:
f(type_identity<float>{}, ncclFloat);
break;
case float64:
f(type_identity<double>{}, ncclDouble);
break;
default:
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
}
}
} // namespace detail
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
class NCCLGroup : public GroupImpl {
public:
NCCLGroup(int worldRank, int worldSize, const std::string initMethod)
: rank_(worldRank),
size_(worldSize),
comm_(nullptr),
initMethod_(initMethod) {
if (initialized_)
return;
int ndev;
CHECK_CUDA(cudaGetDeviceCount(&ndev));
CHECK_CUDA(cudaSetDevice(rank_ % ndev));
detail::bootstrap_unique_id(uniqueId_, rank_, size_, initMethod_);
CHECK_NCCL(ncclCommInitRank(&comm_, size_, uniqueId_, rank_));
initialized_ = true;
}
~NCCLGroup() {
ncclCommDestroy(comm_);
ncclGroupEnd();
initialized_ = false;
}
int rank() override {
return rank_;
}
int size() override {
return size_;
}
void all_sum(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
}
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[nccl] Group split not supported.");
}
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error(
"[nccl] All gather not supported in NCCL backend.");
}
void send(const array& input, int dst, Stream stream) override {
throw std::runtime_error("[nccl] Send not supported in NCCL backend.");
}
void recv(array& output, int src, Stream stream) override {
throw std::runtime_error("[nccl] Recv not supported in NCCL backend.");
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
}
template <typename T>
void all_reduce_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclAllReduce(
input.data<T>(),
output.data<T>(),
input.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
ncclComm_t comm_;
bool initialized_ = false;
};
bool is_available() {
return true;
}
namespace detail {
static std::string get_env_var_or_throw(const char* env_var_name) {
const char* value = std::getenv(env_var_name);
if (value == nullptr) {
std::ostringstream msg;
msg << "[nccl] Required environment variable '" << env_var_name
<< "' is not set. "
<< "Please set it before initializing the distributed backend.";
throw std::runtime_error(msg.str());
}
return std::string(value);
}
} // namespace detail
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
std::string host = detail::get_env_var_or_throw("NCCL_HOST_IP");
std::string port = detail::get_env_var_or_throw("NCCL_PORT");
std::string rank_str = detail::get_env_var_or_throw("MLX_RANK");
std::string n_nodes_str = detail::get_env_var_or_throw("MLX_WORLD_SIZE");
int rank = std::stoi(rank_str);
int n_nodes = std::stoi(n_nodes_str);
std::string init_method = "tcp://" + host + ":" + port;
return std::make_shared<NCCLGroup>(rank, n_nodes, init_method);
}
} // namespace mlx::core::distributed::nccl

View File

@@ -1,12 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available();
std::shared_ptr<GroupImpl> init(bool strict = false);
} // namespace mlx::core::distributed::nccl

View File

@@ -1,20 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/nccl/nccl.h"
namespace mlx::core::distributed::nccl {
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
bool is_available() {
return false;
}
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
if (strict) {
throw std::runtime_error("Cannot initialize nccl distributed backend.");
}
return nullptr;
}
} // namespace mlx::core::distributed::nccl

View File

@@ -31,7 +31,8 @@ array all_sum(
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
{x});
}

View File

@@ -975,6 +975,7 @@ class RingGroup : public GroupImpl {
int rank_;
int size_;
bool verbose_;
ThreadPool pool_;

Some files were not shown because too many files have changed in this diff.