Compare commits

..

1 Commits

Author SHA1 Message Date
Angelos Katharopoulos
870208eff5 Start sdpa vector 2025-06-16 17:38:39 -07:00
230 changed files with 4346 additions and 12201 deletions

View File

@@ -7,9 +7,15 @@ parameters:
nightly_build:
type: boolean
default: false
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
@@ -32,7 +38,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
@@ -64,9 +70,9 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
@@ -78,36 +84,37 @@ jobs:
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
uv pip install -e ".[dev]" -v
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
@@ -132,49 +139,51 @@ jobs:
- run:
name: Install dependencies
command: |
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
command: |
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
source .venv/bin/activate
source env/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -183,7 +192,7 @@ jobs:
- run:
name: Build small binary
command: |
source .venv/bin/activate
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -195,60 +204,37 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e .
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
python -m venv env
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source .venv/bin/activate
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
@@ -290,29 +276,21 @@ jobs:
command: |
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
@@ -329,100 +307,52 @@ jobs:
python_version:
type: string
default: "3.9"
build_env:
extra_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.9", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Build wheel
name: Upload package
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -434,6 +364,7 @@ workflows:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
@@ -441,16 +372,14 @@ workflows:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- cuda_build_and_test
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
@@ -533,25 +462,6 @@ workflows:
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -572,9 +482,6 @@ workflows:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -633,17 +540,11 @@ workflows:
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
- build_cuda_release
build_dev_release:
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.weekly_build >>
jobs:
- build_release:
matrix:
@@ -713,12 +614,14 @@ workflows:
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -19,7 +19,6 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -41,9 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -66,17 +64,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
endif()
else()
set(MLX_BUILD_METAL OFF)
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
# ----------------------------- Lib -----------------------------
@@ -243,16 +234,12 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
@@ -68,23 +68,18 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
```bash
**With `pip`**:
```
pip install mlx
```
To install the CUDA backend on Linux, run:
**With `conda`**:
```bash
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install mlx[cpu]
conda install -c conda-forge mlx
```
Checkout the

View File

@@ -192,22 +192,6 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {

View File

@@ -5,7 +5,6 @@ import os
import time
import torch
import torch.cuda
import torch.mps
@@ -45,10 +44,8 @@ def bench(f, *args):
def sync_if_needed(x):
if x.device == torch.device("mps"):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()
@torch.no_grad()
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -351,11 +340,7 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
device = "cpu" if args.cpu else "mps"
types = args.dtype
if not types:
@@ -475,8 +460,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -51,20 +51,6 @@ def time_maximum():
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()
time_negative()

View File

@@ -138,13 +138,13 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx
To install from PyPI your system must meet the following requirements:
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
@@ -23,39 +23,12 @@ To install from PyPI your system must meet the following requirements:
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
CUDA
^^^^
MLX has a CUDA backend which you can install with:
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
conda install conda-forge::mlx
Troubleshooting
@@ -92,8 +65,6 @@ Build Requirements
Python API
^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -105,20 +76,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell
pip install .
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
pip install -e ".[dev]"
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with:
@@ -136,8 +107,6 @@ IDE:
C++ API
^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start
@@ -216,7 +185,6 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -245,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -19,4 +19,3 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = tree_flatten(model.parameters(), destination={})
params = dict(tree_flatten(model.parameters()))
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream>
#include <sstream>
@@ -17,19 +16,6 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -181,15 +167,16 @@ void Axpby::eval_gpu(
}
// Resolve name of kernel (corresponds to axpby.metal)
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext", current_binary_dir());
auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname, lib);
auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** The name of primitive. */
const char* name() const override {
return "Axpby";
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.4.0
nanobind==2.2.0

View File

@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -10,7 +10,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -19,8 +18,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc

View File

@@ -14,8 +14,6 @@ void print_constant(std::ostream& os, const array& x) {
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case float64:
return print_float_constant<double>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
@@ -52,8 +50,6 @@ std::string get_type_string(Dtype d) {
return "float16_t";
case bfloat16:
return "bfloat16_t";
case float64:
return "double";
case complex64:
return "complex64_t";
case bool_:

View File

@@ -18,12 +18,8 @@ std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
if constexpr (std::is_same_v<T, double>) {
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
} else {
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
}
os << x.item<T>() << std::setprecision(old_precision);
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>

View File

@@ -12,11 +12,16 @@ namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}};
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
@@ -37,11 +42,17 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}};
// Get and check the shape for the batched dims
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Shape B_bshape{b.shape().begin(), b.shape().end() - 2};
Shape C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
throw std::runtime_error(msg.str());
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};

View File

@@ -5,9 +5,11 @@
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core

View File

@@ -1,20 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
std::filesystem::path current_binary_dir() {
static std::filesystem::path binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path();
}();
return binary_dir;
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
@@ -205,15 +199,12 @@ Dims get_2d_grid_dims_common(
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
@@ -228,31 +219,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
int ndim = x.ndim();
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
auto shape = x.shape();
std::swap(shape[axis1], shape[axis2]);
auto strides = x.strides();
std::swap(strides[axis1], strides[axis2]);
auto [data_size, row_contiguous, col_contiguous] =
check_contiguity(shape, strides);
bool contiguous = data_size == x.data_size();
array out(std::move(shape), x.dtype(), nullptr, {});
out.copy_shared_buffer(
x,
std::move(strides),
{contiguous, row_contiguous, col_contiguous},
x.data_size());
return out;
}
} // namespace mlx::core

View File

@@ -2,7 +2,6 @@
#pragma once
#include <filesystem>
#include <tuple>
#include <vector>
@@ -10,8 +9,7 @@
namespace mlx::core {
// Return the directory that contains current shared library.
std::filesystem::path current_binary_dir();
std::string get_primitive_string(Primitive* primitive);
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
@@ -196,11 +194,8 @@ void shared_buffer_reshape(
const Strides& out_strides,
array& out);
// Like the swapaxes op but safe to call in eval_gpu.
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}

View File

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// The decomposition is computed in place, so just copy the input to the
// output.
copy_cpu(
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -231,7 +231,7 @@ inline void build_kernel(
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
os << x.primitive().name();
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -288,14 +288,6 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;

View File

@@ -883,7 +883,7 @@ void explicit_gemm_conv_1D_cpu(
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
@@ -895,7 +895,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -920,7 +920,7 @@ void explicit_gemm_conv_1D_cpu(
// Materialize strided view
Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -938,13 +938,13 @@ void explicit_gemm_conv_1D_cpu(
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
copy(wt_transpose, gemm_wt, CopyType::General, stream);
temps.push_back(gemm_wt);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
@@ -991,7 +991,7 @@ void explicit_gemm_conv_1D_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
@@ -1029,7 +1029,7 @@ void explicit_gemm_conv_2D_cpu(
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
@@ -1044,7 +1044,7 @@ void explicit_gemm_conv_2D_cpu(
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
@@ -1065,7 +1065,7 @@ void explicit_gemm_conv_2D_cpu(
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -1076,7 +1076,7 @@ void explicit_gemm_conv_2D_cpu(
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
@@ -1116,7 +1116,7 @@ void explicit_gemm_conv_2D_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
@@ -1156,7 +1156,7 @@ void explicit_gemm_conv_ND_cpu(
// Fill with zeros
std::vector<array> temps = {array(0, conv_dtype)};
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = 0;
@@ -1173,7 +1173,7 @@ void explicit_gemm_conv_ND_cpu(
data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
temps.push_back(in_padded_slice);
// Make strided view
@@ -1212,7 +1212,7 @@ void explicit_gemm_conv_ND_cpu(
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
copy(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
@@ -1223,13 +1223,13 @@ void explicit_gemm_conv_ND_cpu(
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
copy(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
copy(gemm_wt, gemm_wt_, CopyType::Vector, stream);
temps.push_back(gemm_wt_);
// Calculate the total size of the spatial dimensions
@@ -1284,7 +1284,7 @@ void explicit_gemm_conv_ND_cpu(
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
copy_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}

View File

@@ -295,11 +295,7 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream) {
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
@@ -309,7 +305,7 @@ void copy_cpu_inplace(
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
@@ -319,10 +315,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_cpu_inplace(src, dst, ctype, stream);
copy_inplace(src, dst, ctype, stream);
}
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -377,10 +373,4 @@ void copy_cpu_inplace(
});
}
array contiguous_copy_cpu(const array& arr, Stream stream) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::General, stream);
return arr_copy;
}
} // namespace mlx::core

View File

@@ -10,14 +10,10 @@
namespace mlx::core {
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
const array& src,
array& dst,
CopyType ctype,
Stream stream);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_cpu_inplace(
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
@@ -30,7 +26,4 @@ void copy_cpu_inplace(
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_cpu(const array& arr, Stream stream);
} // namespace mlx::core

View File

@@ -13,7 +13,9 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
return {contiguous_copy_cpu(arr, stream), true};
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
};
@@ -32,7 +34,8 @@ void AllReduce::eval_cpu(
}
return in;
} else {
array arr_copy = contiguous_copy_cpu(in, s);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}

View File

@@ -135,7 +135,7 @@ void Eig::eval_cpu(
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -196,7 +196,7 @@ void Eigh::eval_cpu(
values.set_data(allocator::malloc(values.nbytes()));
copy_cpu(
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy_cpu(
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -517,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
@@ -686,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(src, out, ctype, stream());
copy(src, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);

View File

@@ -115,7 +115,7 @@ void inverse_impl(
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy_cpu(
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>

View File

@@ -87,7 +87,8 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}

View File

@@ -31,7 +31,7 @@ void luf_impl(
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_cpu_inplace(
copy_inplace(
a,
lu,
a.shape(),

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
@@ -53,58 +52,6 @@ inline void mask_matrix(
}
}
template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
Shape a_copy = a_shape;
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
for (int i = 0; i < num_segments; i++) {
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
a + k_start * a_strides[ndim - 1],
b + k_start * b_strides[ndim - 2],
out + i * M * N,
a_transposed,
b_transposed,
lda,
ldb,
N,
1.0,
0.0,
1,
a_copy,
a_strides,
b_copy,
b_strides);
}
}
} // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -124,20 +71,21 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_copy, CopyType::Vector, s);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
array arr_copy = contiguous_copy_cpu(arr, s);
return std::make_tuple(false, stx, arr_copy, true);
}
};
@@ -385,7 +333,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -489,121 +437,4 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporaries(std::move(temps));
}
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cpu::get_command_encoder(stream());
auto check_transpose = [&s, &encoder](const array& x) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
if (stx == x.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, x);
} else if (stx == 1 && sty == x.shape(-2)) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy_cpu(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);
}
};
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
auto& segments = inputs[2];
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(segments);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
segments = array::unsafe_weak_copy(segments),
out_ptr = out.data<void>(),
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb]() {
switch (a.dtype()) {
case float64:
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float32:
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float16:
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case bfloat16:
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
default:
throw std::invalid_argument(
"Segmented mm supports only real float types.");
}
});
}
} // namespace mlx::core

View File

@@ -81,7 +81,7 @@ void matmul_general(
return std::make_tuple(true, sty, arr);
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, stream);
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, temps.back());
}
@@ -142,7 +142,7 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}

View File

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(in, out, ctype, stream());
copy(in, out, ctype, stream());
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_cpu(in, out, CopyType::General, stream());
copy(in, out, CopyType::General, stream());
}
}
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy_cpu(in, out, ctype, stream());
copy(in, out, ctype, stream());
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy_cpu(val, out, CopyType::Scalar, stream());
copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_cpu_inplace(
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_cpu_inplace(in, tmp, CopyType::General, stream());
copy_inplace(in, tmp, CopyType::General, stream());
}
auto flags = out.flags();

View File

@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));

View File

@@ -529,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -579,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -712,7 +712,9 @@ void fast::AffineQuantize::eval_cpu(
if (arr.flags().row_contiguous) {
return std::make_pair(arr, false);
} else {
return std::make_pair(contiguous_copy_cpu(arr, s), true);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};

View File

@@ -325,15 +325,7 @@ struct MaxReduce {
};
template <int N, typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
};
@@ -350,15 +342,7 @@ struct MinReduce {
};
template <int N, typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
template <int N, typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
if (simd::any(x != x)) {
return static_cast<T>(NAN);
}
T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
};
@@ -543,10 +527,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);

View File

@@ -250,8 +250,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
in = contiguous_copy_cpu(in, stream());
encoder.add_temporary(in);
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc(out.nbytes()));

View File

@@ -131,7 +131,8 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
}
return x;
} else {
array x_copy = contiguous_copy_cpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}

View File

@@ -8,7 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -333,24 +333,45 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
dispatch_all_types(out.dtype(), [&](auto type_tag) {
sort<MLX_GET_TYPE(type_tag)>(out, axis);
});
});
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -405,10 +426,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@@ -31,7 +31,7 @@ void svd_impl(
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
copy_cpu(
copy(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

View File

@@ -6,59 +6,42 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation.
@@ -83,11 +66,6 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning, it is safe to set it to
# true but the warning wouldn't be suppressed.
@@ -101,18 +79,11 @@ endif()
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
@@ -144,27 +115,6 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Use the frontend APIs of cuDNN.
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -2,7 +2,7 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
#include "mlx/backend/cuda/worker.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
@@ -14,68 +14,14 @@ namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
SmallSizePool::SmallSizePool() {
auto num_blocks = small_pool_size / small_block_size;
buffer_ = new Block[num_blocks];
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {
curr->next = buffer_ + i;
curr = curr->next;
}
curr->next = nullptr;
}
SmallSizePool::~SmallSizePool() {
CHECK_CUDA_ERROR(cudaFree(data_));
delete[] buffer_;
}
CudaBuffer* SmallSizePool::malloc() {
if (next_free_ == nullptr) {
return nullptr;
}
Block* b = next_free_;
uint64_t i = next_free_ - buffer_;
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
return &b->buf;
}
void SmallSizePool::free(CudaBuffer* buf) {
auto b = reinterpret_cast<Block*>(buf);
b->next = next_free_;
next_free_ = b;
}
bool SmallSizePool::in_pool(CudaBuffer* buf) {
constexpr int num_blocks = (small_pool_size / small_block_size);
auto b = reinterpret_cast<Block*>(buf);
int64_t block_num = b - buffer_;
return block_num >= 0 && block_num < num_blocks;
}
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
getpagesize(),
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
[this](CudaBuffer* buf) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -85,37 +31,22 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
} else if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
int64_t mem_to_free =
get_active_memory() + get_cache_memory() + size - memory_limit_;
if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_to_free);
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
// Try the scalar pool first
if (size <= small_block_size) {
buf = scalar_pool_.malloc();
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
lock.lock();
}
@@ -126,6 +57,7 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
@@ -140,7 +72,9 @@ void CudaAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
cuda_free(buf);
lock.unlock();
cuda_free(buf->data);
delete buf;
}
}
@@ -152,14 +86,28 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size;
}
// This must be called with mutex_ aquired
void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
delete buf;
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
cudaFree(buf);
}
size_t CudaAllocator::get_active_memory() const {

View File

@@ -7,10 +7,13 @@
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
@@ -19,35 +22,21 @@ struct CudaBuffer {
size_t size;
};
class SmallSizePool {
private:
union Block {
Block* next;
CudaBuffer buf;
};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
public:
SmallSizePool();
~SmallSizePool();
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
@@ -58,18 +47,19 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
void cuda_free(CudaBuffer* buf);
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();

View File

@@ -1,55 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
} // namespace mlx::core

View File

@@ -1,8 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -44,11 +43,8 @@ struct ArgMin {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
@@ -77,11 +73,8 @@ struct ArgMax {
}
template <int N>
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
@@ -112,15 +105,16 @@ __global__ void arg_reduce_general(
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
@@ -157,30 +151,36 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
auto kernel =
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
using InType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{BLOCK_DIM, 1, 1};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,
BLOCK_DIM,
N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = &cu::arg_reduce_general<
InType,
cu::ArgMin<InType>,
BLOCK_DIM,
N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
});
}

View File

@@ -3,6 +3,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
@@ -16,86 +17,35 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (int i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = Op{}(a[0], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[i]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a[0], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = Op{}(a[0], b[index]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = Op{}(a[index], b[0]);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = Op{}(a[index], b[index]);
}
}
@@ -128,7 +78,7 @@ __global__ void binary_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
@@ -175,12 +125,13 @@ constexpr bool supports_binary_op() {
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
@@ -189,119 +140,126 @@ void binary_op_gpu_inplace(
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] =
get_launch_args(out, large());
encoder.add_kernel_node(
cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
cu::binary_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
bool large = a.data_size() > INT32_MAX ||
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
const_param<NDIM>(shape),
const_param<NDIM>(a_strides),
const_param<NDIM>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
std::vector<array> outputs{out};
binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, name(), s); \
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = outputs[0].primitive().stream(); \
binary_op_gpu<cu::func>(inputs, outputs, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
@@ -325,31 +283,33 @@ BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}

View File

@@ -1,333 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b[0]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
template <typename Op>
void binary_two_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const char* op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out_a.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) =
collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
encoder.add_kernel_node(
cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto [num_blocks, block_dims] =
get_launch_args(out_a, large());
encoder.add_kernel_node(
cu::binary_two_g<Op, InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
out_a.data_size(),
out_a.shape(),
out_a.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
}
template <typename Op>
void binary_two_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const char* op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
}
} // namespace mlx::core

View File

@@ -3,7 +3,6 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -53,10 +52,9 @@ struct FusedKernelBuilder {
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <typename IdxT = uint32_t>\n";
} else {
os +=
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
os += "template <int NDIM, typename IdxT = uint32_t>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
@@ -68,77 +66,12 @@ struct FusedKernelBuilder {
}
os += ") {\n";
// Index. For non contiguous kernels we create a separate index
// variable per variable otherwise everyone uses `index`.
// Index.
os +=
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
" IdxT index = cg::this_grid().thread_rank();\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
if (!contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " IdxT " + xname + "_idx = 0;\n";
}
os += " {\n";
os += " IdxT loc = index;\n";
os +=
" #pragma unroll\n"
" for (int i = NDIM - 1; i >= 0; i--) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
"_strides[i]);\n";
}
os +=
" loc /= shape[i];\n"
" }\n"
" }\n";
}
// Vectorized read loop
if (contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_scalar(x) || is_constant(i)) {
continue;
}
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
xname,
type);
}
}
// Create some space for the outputs
for (const auto& x : outputs) {
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
}
// Work loop
if (!contiguous) {
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
} else {
os +=
"\n"
" #pragma unroll\n"
" for (int i = 0; i < work_per_thread; i++) {\n";
}
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -153,11 +86,14 @@ struct FusedKernelBuilder {
} else if (is_scalar(x)) {
value = fmt::format("{}[0]", xname);
} else if (contiguous) {
value = fmt::format("vec_{}[i]", xname);
value = fmt::format("{}[index]", xname);
} else {
value = fmt::format("{}[{}_idx]", xname, xname);
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
@@ -169,40 +105,21 @@ struct FusedKernelBuilder {
value = fmt::format(
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
} else {
value = x.primitive().name();
std::ostringstream ss;
x.primitive().print(ss);
value = ss.str();
value += "{}(";
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
if (!contiguous) {
os += "\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
}
}
os += " }\n";
// Store the output to global memory
for (const auto& x : outputs) {
os += fmt::format(
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
namer.get_name(x));
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
os += "}\n";
@@ -228,15 +145,6 @@ void Compiled::eval_gpu(
nvtx3::scoped_range r("Compiled::eval_gpu");
auto& s = stream();
// Determine the work per thread for the vectorized reads/writes. We take it
// as 16 over the max itemsize for the outputs. Another heuristic could be
// over the max itemsize of all arrays.
int max_size = 1;
for (const auto& x : outputs) {
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
}
int work_per_thread = 16 / max_size;
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
// Build source code.
cu::FusedKernelBuilder builder{
@@ -249,24 +157,16 @@ void Compiled::eval_gpu(
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
}
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
@@ -278,7 +178,6 @@ void Compiled::eval_gpu(
// Whether to use large index.
bool large = compiled_use_large_index(inputs, outputs, contiguous);
cu::KernelArgs args;
// Put inputs.
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -286,42 +185,35 @@ void Compiled::eval_gpu(
continue;
}
const auto& x = inputs[i];
args.append(x);
mod.append_arg(x);
if (!contiguous && !is_scalar(x)) {
args.append_ptr(strides_vec[strides_index++].data());
mod.append_arg(strides_vec[strides_index++]);
}
}
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
for (auto& x : outputs) {
args.append(x);
mod.append_arg(x);
}
// Put shape and size.
if (!contiguous) {
args.append_ptr(shape.data());
mod.append_arg(shape);
}
if (large) {
args.append<int64_t>(outputs[0].data_size());
mod.append_arg<int64_t>(outputs[0].data_size());
} else {
args.append<uint32_t>(outputs[0].data_size());
}
// Choose work per thread
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
mod.append_arg<uint32_t>(outputs[0].data_size());
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name +=
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
kernel_name += fmt::format("_contiguous<{}>", index_type);
} else {
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
@@ -330,11 +222,9 @@ void Compiled::eval_gpu(
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, outputs[0], large);
});
}
} // namespace mlx::core

View File

@@ -1,546 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> padding_lo;
std::array<int, MAX_NDIM> padding_hi;
std::array<int, MAX_NDIM> dilation;
int groups;
bool flip;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
};
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
cache(/* capacity */ 128);
return cache;
}
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
array& x,
array& w,
array& y) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
x.data<void>(),
w.data<void>(),
y.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
const ConvCacheKey& cache_key,
cudnnBackendDescriptorType_t backend_type,
cudnn_frontend::EngineConfigList& configs,
const std::string& op_graph_tag,
array& x,
array& w,
array& y) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(plan)));
return true;
}
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return false;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
array& x,
array& w,
array& y,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding_lo_,
const std::vector<int>& padding_hi_,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
if (backend_type == CONV_BACKWARD_INPUT) {
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
convert_vector<int64_t>(input_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
padding_hi = padding_lo;
return std::make_tuple(
convert_vector<int64_t>(kernel_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_strides));
} else {
return std::make_tuple(
convert_vector<int64_t>(kernel_strides),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
}
}
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', x))
.setwDesc(build_tensor('w', w))
.setyDesc(build_tensor('y', y))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
return std::nullopt;
}
}
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array in,
array wt,
array out,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = swapaxes_in_eval(wt, 0, -1);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = swapaxes_in_eval(in, 0, -1);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
Shape shape(out.shape());
std::swap(shape.front(), shape.back());
Strides strides(shape.size(), 1);
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
array intermediate(std::move(shape), out.dtype(), nullptr, {});
intermediate.copy_shared_buffer(
out, std::move(strides), {true, true, false}, out.data_size());
out = intermediate;
}
// cuDNN requires contiguous input.
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& intermediate_out,
array& final_out) {
encoder.set_input_array(in);
encoder.set_input_array(wt);
encoder.set_output_array(final_out);
if (backend_type == CONV_BACKWARD_WEIGHT) {
// Turn |out| into a strided array, which will have C_in and C_out swapped
// in vjp and the final |grad_weight| will then be contiguous.
Strides strides = intermediate_out.strides();
std::swap(strides.front(), strides.back());
final_out.copy_shared_buffer(
intermediate_out,
std::move(strides),
{false, false, false},
intermediate_out.data_size());
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out_.size() == 0) {
return;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype),
fixed_vector(in.shape()),
fixed_vector(wt.shape()),
fixed_vector(kernel_strides_),
fixed_vector(padding_lo_),
fixed_vector(padding_hi_),
fixed_vector(kernel_dilation_),
groups_,
flip_,
get_alignment(in),
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!execute_plan(encoder, plan, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
return;
}
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
} else {
// Otherwise it could be backward weight convolution or forward convolution,
// mathematically there is no difference so we have to use heuristics.
// Empirically backward convolutions have large kernel dimensions, and
// usually have |in| and |wt| transposed.
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
wt.shape(2) > out.shape(2)) {
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
} else {
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
}
}
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,
x,
w,
y,
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_op_graph(
encoder,
try_backend,
dtype,
x,
w,
y,
stride,
padding_lo,
padding_hi,
dilation);
if (op_graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
break;
}
}
if (!op_graph) {
throw std::runtime_error("[conv] Can not build op graph.");
}
// Get ready to execute the graph.
register_args(encoder, backend_type, in, wt, out, out_);
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
auto tag = op_graph->getTag();
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, dtype, *op_graph);
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
return;
}
throw std::runtime_error("[conv] Unable to find a working engine.");
}
} // namespace mlx::core

View File

@@ -24,6 +24,7 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;

View File

@@ -10,6 +10,15 @@
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
__VA_ARGS__; \
}); \
})
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,

View File

@@ -10,43 +10,19 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int N_READS>
template <typename In, typename Out, typename IdxT>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = CastOp<In, Out>{}(in[0]);
}
}
template <typename In, typename Out, typename IdxT, int N_READS>
template <typename In, typename Out, typename IdxT>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = cast_to<Out>(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = cast_to<Out>(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
if (index < size) {
out[index] = CastOp<In, Out>{}(in[index]);
}
}
@@ -59,24 +35,17 @@ void copy_contiguous(
array& out,
int64_t in_offset,
int64_t out_offset) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());

View File

@@ -37,7 +37,7 @@ __global__ void copy_gg(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
}
@@ -55,53 +55,39 @@ void copy_general(
const Shape& shape,
const Strides& strides_in,
const Strides& strides_out) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
encoder.add_kernel_node(
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,
const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] =
get_launch_args(data_size, shape, out.strides(), large());
encoder.add_kernel_node(
cu::copy_gg<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
});
});
}

View File

@@ -41,7 +41,7 @@ __global__ void copy_gg_dynamic(
const int64_t* offset_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
auto [idx_in, idx_out] = elem_to_loc_4d(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
}
@@ -61,56 +61,43 @@ void copy_general_dynamic(
const Strides& strides_out,
const array& dynamic_offset_in,
const array& dynamic_offset_out) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
cu::copy_gg_dynamic_nd<
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
cu::copy_gg_dynamic<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
});
});
});
}

View File

@@ -34,7 +34,7 @@ __global__ void copy_g(
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
IdxT idx_in = elem_to_loc_4d(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
}
}
@@ -50,46 +50,37 @@ void copy_general_input(
int64_t offset_out,
const Shape& shape,
const Strides& strides_in) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
cu::copy_g<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
});
});
});
}

View File

@@ -1,41 +1,38 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include "mlx/backend/metal/metal.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <future>
#include <unordered_set>
namespace mlx::core::cu {
namespace mlx::core {
namespace {
namespace cu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
DeviceStream::DeviceStream(Device& device) : device_(device), stream_(device) {}
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void DeviceStream::synchronize() {
cudaStreamSynchronize(stream_);
}
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
cudaStream_t DeviceStream::schedule_cuda_stream() {
// TODO: Return a stream that maximizes parallelism.
return stream_;
}
cudaStream_t DeviceStream::last_cuda_stream() {
return stream_;
}
CommandEncoder& DeviceStream::get_encoder() {
if (!encoder_) {
encoder_ = std::make_unique<CommandEncoder>(*this);
}
return *encoder_;
}
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
}();
return cache_size;
}
} // namespace
Device::Device(int device) : device_(device) {
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
@@ -52,18 +49,11 @@ Device::Device(int device) : device_(device) {
}
// The cublasLt handle is used by matmul.
make_current();
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
// Initialize the jit module cache here ensures it is not
// unloaded before any evaluation is done
get_jit_module_cache();
cublasLtCreate(&lt_);
}
Device::~Device() {
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
cublasLtDestroy(lt_);
}
void Device::make_current() {
@@ -76,256 +66,45 @@ void Device::make_current() {
}
}
CommandEncoder& Device::get_command_encoder(Stream s) {
auto it = encoders_.find(s.index);
if (it == encoders_.end()) {
it = encoders_.try_emplace(s.index, *this).first;
DeviceStream& Device::get_stream(Stream s) {
auto it = streams_.find(s.index);
if (it == streams_.end()) {
it = streams_.try_emplace(s.index, *this).first;
}
return it->second;
}
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
if (discard) {
return;
}
enc.add_graph_node(graph);
}
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
: enc(enc) {
enc.in_concurrent_ = true;
}
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false;
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node);
enc.graph_key_ += from.id;
enc.graph_key_ += from.node_type;
enc.graph_key_ += empty.id;
enc.graph_key_ += empty.node_type;
}
// Insert the input -> concurrent node dependencies without updating output
// nodes
auto outputs = std::move(enc.active_outputs_);
enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));
// Update output node to be the empty node
for (auto o : outputs) {
enc.node_map_.emplace(o, empty).first->second = empty;
}
}
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++);
if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node));
} else {
std::vector<GraphNode> nodes;
nodes.push_back(std::move(node));
insert_graph_dependencies(std::move(nodes));
}
}
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
std::vector<GraphNode> deps;
{
// Dependencies must be added in the same order to produce a consistent
// topology
std::unordered_set<cudaGraphNode_t> set_deps;
for (auto d : active_deps_) {
if (auto it = node_map_.find(d); it != node_map_.end()) {
auto [_, inserted] = set_deps.insert(it->second.node);
if (inserted) {
deps.push_back(it->second);
}
}
}
}
active_deps_.clear();
for (auto o : active_outputs_) {
for (auto& node : nodes) {
node_map_.emplace(o, node).first->second = node;
}
}
active_outputs_.clear();
for (auto& from : deps) {
for (auto& to : nodes) {
from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node);
graph_key_ += from.id;
graph_key_ += from.node_type;
graph_key_ += to.id;
graph_key_ += to.node_type;
}
}
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
CommandEncoder::CommandEncoder(DeviceStream& s)
: device_(s.device()), stream_(s) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::set_input_array(const array& arr) {
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
void CommandEncoder::end_encoding() {
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
void CommandEncoder::set_output_array(const array& arr) {
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
active_outputs_.push_back(id);
}
// There is no kernel running, run completion handlers immediately.
if (!has_gpu_work_) {
worker_.consume_in_this_thread();
return;
}
has_gpu_work_ = false;
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
// Put completion handlers in a batch.
worker_.end_batch();
// Signaling kernel completion is expensive, delay until enough batches.
// TODO: This number is arbitrarily picked, profile for a better stragety.
if (worker_.uncommited_batches() > 8) {
commit();
}
}
void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
}
void CommandEncoder::add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x;
kernel_params.gridDimY = grid_dim.y;
kernel_params.gridDimZ = grid_dim.z;
kernel_params.blockDimX = block_dim.x;
kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
}
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
}
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
node_map_.clear();
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
// Put completion handlers in a batch.
worker_.commit(stream_);
}
void CommandEncoder::synchronize() {
cudaStreamSynchronize(stream_);
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); });
commit();
f.wait();
worker_.commit(stream_.last_cuda_stream());
}
Device& device(mlx::core::Device device) {
@@ -337,8 +116,14 @@ Device& device(mlx::core::Device device) {
return it->second;
}
CommandEncoder& get_command_encoder(Stream s) {
return device(s.device).get_command_encoder(s);
DeviceStream& get_stream(Stream s) {
return device(s.device).get_stream(s);
}
} // namespace mlx::core::cu
CommandEncoder& get_command_encoder(Stream s) {
return get_stream(s).get_encoder();
}
} // namespace cu
} // namespace mlx::core

View File

@@ -3,133 +3,45 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
#include <cublasLt.h>
#include <cuda.h>
#include <cudnn.h>
#include <thrust/execution_policy.h>
#include <unordered_map>
namespace mlx::core::cu {
class CommandEncoder {
class Device;
class CommandEncoder;
class DeviceStream {
public:
struct CaptureContext {
CaptureContext(CommandEncoder& enc);
~CaptureContext();
cudaGraph_t graph;
CommandEncoder& enc;
bool discard{false};
};
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc);
~ConcurrentContext();
CommandEncoder& enc;
};
explicit DeviceStream(Device& device);
explicit CommandEncoder(Device& d);
DeviceStream(const DeviceStream&) = delete;
DeviceStream& operator=(const DeviceStream&) = delete;
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
// Wait until kernels in the stream complete.
void synchronize();
CaptureContext capture_context() {
return CaptureContext{*this};
}
ConcurrentContext concurrent_context() {
return ConcurrentContext{*this};
}
// Return a cuda stream for launching kernels.
cudaStream_t schedule_cuda_stream();
void set_input_array(const array& arr);
void set_output_array(const array& arr);
// Return the last cuda stream used.
cudaStream_t last_cuda_stream();
template <typename F, typename... Params>
void add_kernel_node(
F* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
Params&&... params) {
constexpr size_t num = sizeof...(Params);
void* ptrs[num];
size_t i = 0;
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
std::forward<Params>(params)),
...);
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
}
void add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void maybe_commit();
void commit();
CommandEncoder& get_encoder();
Device& device() {
return device_;
}
CudaStream& stream() {
return stream_;
}
// Wait until kernels and completion handlers are finished
void synchronize();
private:
struct GraphNode {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G = subgraph
char node_type;
std::string id;
};
void insert_graph_dependencies(GraphNode node);
void insert_graph_dependencies(std::vector<GraphNode> nodes);
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
std::unique_ptr<CommandEncoder> encoder_;
};
class Device {
@@ -143,7 +55,7 @@ class Device {
// Make this device the current cuda device, required by some cuda calls.
void make_current();
CommandEncoder& get_command_encoder(Stream s);
DeviceStream& get_stream(Stream s);
int cuda_device() const {
return device_;
@@ -157,20 +69,70 @@ class Device {
cublasLtHandle_t lt_handle() const {
return lt_;
}
cudnnHandle_t cudnn_handle() const {
return cudnn_;
}
private:
int device_;
int compute_capability_major_;
int compute_capability_minor_;
cublasLtHandle_t lt_;
cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_;
std::unordered_map<int, DeviceStream> streams_;
};
class CommandEncoder {
public:
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
void set_input_array(const array& arr) {}
void set_output_array(const array& arr) {}
void add_temporary(const array& arr) {
temporaries_.push_back(arr.data_shared_ptr());
}
void add_completed_handler(std::function<void()> task);
void end_encoding();
void commit();
// Schedule a cuda stream for |fun| to launch kernels, and check error
// afterwards.
template <typename F>
void launch_kernel(F&& fun) {
launch_kernel(stream_.schedule_cuda_stream(), std::forward<F>(fun));
}
template <typename F>
void launch_kernel(cudaStream_t stream, F&& fun) {
device_.make_current();
fun(stream);
check_cuda_error("kernel launch", cudaGetLastError());
has_gpu_work_ = true;
}
Device& device() {
return device_;
}
DeviceStream& stream() {
return stream_;
}
bool has_gpu_work() const {
return has_gpu_work_;
}
private:
Device& device_;
DeviceStream& stream_;
Worker worker_;
bool has_gpu_work_{false};
std::vector<std::shared_ptr<array::Data>> temporaries_;
};
Device& device(mlx::core::Device device);
DeviceStream& get_stream(Stream s);
CommandEncoder& get_command_encoder(Stream s);
// Return an execution policy that does not sync for result.

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/backend/cuda/device/complex.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include <cuda/atomic>
@@ -48,16 +48,25 @@ inline __device__ void atomic_add(__half* out, __half val) {
atomicAdd(out, val);
}
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
atomic_add_general(out, val);
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
#if CCCL_VERSION >= 2008000
atomic_add_general(out, val);
#else
bool cccl_version_too_old_for_bfloat16_atomic_add = false;
assert(cccl_version_too_old_for_bfloat16_atomic_add);
#endif
#else
atomicAdd(out, val);
#endif
}
} // namespace mlx::core::cu

View File

@@ -1,7 +1,10 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
@@ -19,7 +22,7 @@ struct FloorDivide {
if constexpr (cuda::std::is_integral_v<T>) {
return x / y;
} else {
return truncf(x / y);
return trunc(x / y);
}
}
};
@@ -44,7 +47,7 @@ struct Remainder {
} else {
return x % y;
}
} else if constexpr (is_complex_v<T>) {
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return x % y;
} else {
T r = fmod(x, y);
@@ -66,12 +69,14 @@ struct Equal {
struct NaNEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (is_complex_v<T>) {
if constexpr (std::is_same_v<T, cuComplex>) {
return x == y ||
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
isnan(y.imag())) ||
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(cuCrealf(x) == cuCrealf(y) && isnan(cuCimagf(x)) &&
isnan(cuCimagf(y))) ||
(isnan(cuCrealf(x)) && isnan(cuCrealf(y)) &&
cuCimagf(x) == cuCimagf(y));
} else {
return x == y || (isnan(x) && isnan(y));
}
@@ -109,38 +114,36 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
isnan(y.imag())) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = x.real() > y.real() ? x : y;
auto min = x.real() < y.real() ? x : y;
auto min_real = min.real();
auto max_real = max.real();
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
} else {
return Log{}(Exp{}(min) + Exp{}(max));
}
} else {
return Log1p{}(Exp{}(min - max)) + max;
}
} else {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
constexpr float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
};
struct Maximum {
@@ -148,8 +151,8 @@ struct Maximum {
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return max(x, y);
} else if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag())) {
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x > y ? x : y;
@@ -167,8 +170,8 @@ struct Minimum {
__device__ T operator()(T x, T y) {
if constexpr (cuda::std::is_integral_v<T>) {
return min(x, y);
} else if constexpr (is_complex_v<T>) {
if (isnan(x.real()) || isnan(x.imag())) {
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x))) {
return x;
}
return x < y ? x : y;
@@ -191,8 +194,8 @@ struct Multiply {
struct NotEqual {
template <typename T>
__device__ bool operator()(T x, T y) {
if constexpr (is_complex_v<T>) {
return x.real() != y.real() || x.imag() != y.imag();
if constexpr (std::is_same_v<T, cuComplex>) {
return cuCrealf(x) != cuCrealf(y) || cuCimagf(x) != cuCimagf(y);
} else {
return x != y;
}
@@ -212,8 +215,19 @@ struct Power {
base *= base;
}
return res;
} else if constexpr (is_complex_v<T>) {
return pow(base, exp);
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
auto phase = exp.y * x_ln_r + exp.x * x_theta;
return make_cuFloatComplex(mag * cosf(phase), mag * sinf(phase));
} else {
return powf(base, exp);
}

View File

@@ -2,10 +2,7 @@
#pragma once
#include "mlx/backend/cuda/device/complex.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuComplex.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
@@ -20,48 +17,34 @@ struct CastOp {
}
};
// Castings between complex and boolean.
template <typename T>
struct CastOp<complex_t<T>, bool> {
static constexpr bool is_castable = true;
__device__ bool operator()(complex_t<T> x) {
return x.real() != 0 && x.imag() != 0;
}
};
template <typename T>
struct CastOp<bool, complex_t<T>> {
static constexpr bool is_castable = true;
__device__ complex_t<T> operator()(bool x) {
return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
}
};
// Converting a complex number to real number discards the imaginary part.
template <typename T, typename DstT>
struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
template <typename DstT>
struct CastOp<
cuComplex,
DstT,
cuda::std::enable_if_t<!cuda::std::is_same_v<cuComplex, DstT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<float, DstT>;
__device__ DstT operator()(complex_t<T> x) {
static_assert(!is_complex_v<DstT>);
return static_cast<DstT>(x.real());
__device__ DstT operator()(cuComplex x) {
static_assert(!cuda::std::is_same_v<cuComplex, DstT>);
return static_cast<DstT>(cuCrealf(x));
}
};
// Allow converting a real number to complex number.
template <typename SrcT, typename T>
struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
template <typename SrcT>
struct CastOp<
SrcT,
cuComplex,
cuda::std::enable_if_t<!cuda::std::is_same_v<SrcT, cuComplex>>> {
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, float>;
__device__ complex_t<T> operator()(SrcT x) {
static_assert(!is_complex_v<SrcT>);
return complex_t<T>{static_cast<T>(x), 0};
__device__ cuComplex operator()(SrcT x) {
static_assert(!cuda::std::is_same_v<SrcT, cuComplex>);
return cuComplex{static_cast<float>(x), 0};
}
};
// Do nothing when no casting is needed.
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
@@ -74,51 +57,9 @@ struct CastOp<
}
};
// In CUDA 11 the half types do not define conversions between some types,
// provide fallbacks here.
#if CUDART_VERSION < 12000
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
(cuda::std::is_same_v<DstT, __half> ||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
!cuda::std::is_same_v<DstT, __half> &&
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
(cuda::std::is_same_v<SrcT, __half> ||
cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
static constexpr bool is_castable = true;
__device__ DstT operator()(SrcT x) {
return DstT(static_cast<float>(x));
}
};
#endif // CUDART_VERSION < 12000
// Helper to deduce the SrcT.
template <typename DstT, typename SrcT>
inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
__host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;

View File

@@ -1,60 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
// Make multiplication and division faster.
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
#include <cuda/std/complex>
#include <cuda/std/type_traits>
namespace mlx::core::cu {
// TODO: Consider using a faster implementation as cuda::std::complex has to
// conform to C++ standard.
template <typename T>
using complex_t = cuda::std::complex<T>;
using complex64_t = complex_t<float>;
using complex128_t = complex_t<double>;
template <typename T>
struct is_complex : cuda::std::false_type {};
template <typename T>
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;
// cuda::std::complex is missing some operators.
template <typename T>
inline __host__ __device__ complex_t<T> operator%(
complex_t<T> a,
complex_t<T> b) {
T r = a.real() - floor(a.real() / b.real()) * b.real();
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
return complex_t<T>{r, i};
}
template <typename T>
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}
template <typename T>
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
return operator>(b, a);
}
template <typename T>
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
return !(a > b);
}
template <typename T>
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
return !(a < b);
}
} // namespace mlx::core::cu

View File

@@ -5,7 +5,7 @@
#pragma once
// The maximum dimensions of shape/strides passed as kernel parameters.
#define MAX_NDIM 10
#define MAX_NDIM 8
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
// warpSize variable exists, using it would prevent compile-time optimizations.

View File

@@ -0,0 +1,240 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2017-2024 The Simons Foundation, Inc.
//
// FINUFFT is licensed under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/flatironinstitute/finufft/blob/main/include/cufinufft/contrib/helper_math.h
#pragma once
#include <cuComplex.h>
// This header provides some helper functions for cuComplex types.
// It mainly wraps existing CUDA implementations to provide operator overloads
// e.g. cuAdd, cuSub, cuMul, cuDiv, cuCreal, cuCimag, cuCabs, cuCarg, cuConj are
// all provided by CUDA
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCadd(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCsub(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCmul(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, const cuDoubleComplex& b) {
return cuCdiv(a, b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator%(const cuDoubleComplex& a, const cuDoubleComplex& b) {
double r = cuCreal(a) - (floorf(cuCreal(a) / cuCreal(b)) * cuCreal(b));
double i = cuCimag(a) - (floorf(cuCimag(a) / cuCimag(b)) * cuCimag(b));
return make_cuDoubleComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return cuCreal(a) == cuCreal(b) && cuCimag(a) == cuCimag(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
double mag_a = sqrt(cuCreal(a) * cuCreal(a) + cuCimag(a) * cuCimag(a));
double mag_b = sqrt(cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuDoubleComplex& a,
const cuDoubleComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) + b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator+(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a + cuCreal(b), cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) - b, cuCimag(a));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator-(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a - cuCreal(b), -cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) * b, cuCimag(a) * b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator*(double a, const cuDoubleComplex& b) {
return make_cuDoubleComplex(a * cuCreal(b), a * cuCimag(b));
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(const cuDoubleComplex& a, double b) {
return make_cuDoubleComplex(cuCreal(a) / b, cuCimag(a) / b);
}
__forceinline__ __host__ __device__ cuDoubleComplex
operator/(double a, const cuDoubleComplex& b) {
double denom = cuCreal(b) * cuCreal(b) + cuCimag(b) * cuCimag(b);
return make_cuDoubleComplex(
(a * cuCreal(b)) / denom, (-a * cuCimag(b)) / denom);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCaddf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCsubf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCmulf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, const cuFloatComplex& b) {
return cuCdivf(a, b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator%(const cuFloatComplex& a, const cuFloatComplex& b) {
float r = cuCrealf(a) - (floorf(cuCrealf(a) / cuCrealf(b)) * cuCrealf(b));
float i = cuCimagf(a) - (floorf(cuCimagf(a) / cuCimagf(b)) * cuCimagf(b));
return make_cuFloatComplex(r, i);
}
__forceinline__ __host__ __device__ bool operator==(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return cuCrealf(a) == cuCrealf(b) && cuCimagf(a) == cuCimagf(b);
}
__forceinline__ __host__ __device__ bool operator!=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return !(a == b);
}
__forceinline__ __host__ __device__ bool operator>(
const cuFloatComplex& a,
const cuFloatComplex& b) {
float mag_a = sqrt(cuCrealf(a) * cuCrealf(a) + cuCimagf(a) * cuCimagf(a));
float mag_b = sqrt(cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b));
return mag_a > mag_b;
}
__forceinline__ __host__ __device__ bool operator>=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return a > b || a == b;
}
__forceinline__ __host__ __device__ bool operator<(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a;
}
__forceinline__ __host__ __device__ bool operator<=(
const cuFloatComplex& a,
const cuFloatComplex& b) {
return b > a || a == b;
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) + b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator+(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a + cuCrealf(b), cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) - b, cuCimagf(a));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator-(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a - cuCrealf(b), -cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) * b, cuCimagf(a) * b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator*(float a, const cuFloatComplex& b) {
return make_cuFloatComplex(a * cuCrealf(b), a * cuCimagf(b));
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(const cuFloatComplex& a, float b) {
return make_cuFloatComplex(cuCrealf(a) / b, cuCimagf(a) / b);
}
__forceinline__ __host__ __device__ cuFloatComplex
operator/(float a, const cuFloatComplex& b) {
float denom = cuCrealf(b) * cuCrealf(b) + cuCimagf(b) * cuCimagf(b);
return make_cuFloatComplex(
(a * cuCrealf(b)) / denom, (-a * cuCimagf(b)) / denom);
}

View File

@@ -14,6 +14,8 @@ struct Abs {
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
} else {
return abs(x);
}
@@ -74,8 +76,6 @@ struct Ceil {
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else if constexpr (is_complex_v<T>) {
return T{ceil(x.real()), ceil(x.imag())};
} else {
return ceil(x);
}
@@ -83,23 +83,34 @@ struct Ceil {
};
struct Conjugate {
template <typename T>
__device__ complex_t<T> operator()(complex_t<T> x) {
return conj(x);
__device__ cuComplex operator()(cuComplex x) {
return {cuCrealf(x), -cuCimagf(x)};
}
};
struct Cos {
template <typename T>
__device__ T operator()(T x) {
return cos(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cos(cuCrealf(x)) * cosh(cuCimagf(x)),
-sin(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return cos(x);
}
}
};
struct Cosh {
template <typename T>
__device__ T operator()(T x) {
return cosh(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
cosh(cuCrealf(x)) * cos(cuCimagf(x)),
sinh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return cosh(x);
}
}
};
@@ -132,7 +143,12 @@ struct ErfInv {
struct Exp {
template <typename T>
__device__ T operator()(T x) {
return exp(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
} else {
return exp(x);
}
}
};
@@ -154,8 +170,6 @@ struct Floor {
__device__ T operator()(T x) {
if constexpr (cuda::std::is_integral_v<T>) {
return x;
} else if constexpr (is_complex_v<T>) {
return T{floor(x.real()), floor(x.imag())};
} else {
return floor(x);
}
@@ -163,25 +177,30 @@ struct Floor {
};
struct Imag {
template <typename T>
__device__ auto operator()(complex_t<T> x) {
return x.imag();
__device__ float operator()(cuComplex x) {
return cuCimagf(x);
}
};
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto r = log(cuCrealf(Abs{}(x)));
auto i = atan2f(cuCimagf(x), cuCrealf(x));
return {r, i};
} else {
return log(x);
}
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
if constexpr (is_complex_v<T>) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
} else {
return log2(x);
}
@@ -191,31 +210,20 @@ struct Log2 {
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
return y;
} else {
return log10(x);
}
}
};
struct Log1p {
template <typename T>
__device__ T operator()(T z) {
if constexpr (is_complex_v<T>) {
float x = z.real();
float y = z.imag();
float zabs = Abs{}(z).real();
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
__device__ T operator()(T x) {
return log1p(x);
}
};
@@ -228,8 +236,8 @@ struct LogicalNot {
struct Negative {
template <typename T>
__device__ T operator()(T x) {
if constexpr (is_complex_v<T>) {
return T{0, 0} - x;
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return 0 - x;
} else {
return -x;
}
@@ -237,23 +245,29 @@ struct Negative {
};
struct Real {
template <typename T>
__device__ auto operator()(complex_t<T> x) {
return x.real();
__device__ float operator()(cuComplex x) {
return cuCrealf(x);
}
};
struct Round {
template <typename T>
__device__ T operator()(T x) {
if constexpr (is_complex_v<T>) {
return {rint(x.real()), rint(x.imag())};
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {rint(cuCrealf(x)), rint(cuCimagf(x))};
} else {
return rint(x);
}
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
};
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
@@ -267,8 +281,8 @@ struct Sign {
__device__ T operator()(T x) {
if constexpr (cuda::std::is_unsigned_v<T>) {
return x != 0;
} else if constexpr (is_complex_v<T>) {
if (x.real() == 0 && x.imag() == 0) {
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
return x;
} else {
return x / Abs()(x);
@@ -284,14 +298,26 @@ struct Sign {
struct Sin {
template <typename T>
__device__ T operator()(T x) {
return sin(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sin(cuCrealf(x)) * cosh(cuCimagf(x)),
cos(cuCrealf(x)) * sinh(cuCimagf(x))};
} else {
return sin(x);
}
}
};
struct Sinh {
template <typename T>
__device__ T operator()(T x) {
return sinh(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return {
sinh(cuCrealf(x)) * cos(cuCimagf(x)),
cosh(cuCrealf(x)) * sin(cuCimagf(x))};
} else {
return sinh(x);
}
}
};
@@ -309,28 +335,33 @@ struct Sqrt {
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
if constexpr (is_complex_v<T>) {
return 1.0f / Sqrt{}(x);
} else {
return rsqrt(x);
}
}
};
struct Tan {
template <typename T>
__device__ T operator()(T x) {
return tan(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tan_a = tan(cuCrealf(x));
float tanh_b = tanh(cuCimagf(x));
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
} else {
return tan(x);
}
}
};
struct Tanh {
template <typename T>
__device__ T operator()(T x) {
return tanh(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float tanh_a = tanh(cuCrealf(x));
float tan_b = tan(cuCimagf(x));
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
} else {
return tanh(x);
}
}
};

View File

@@ -8,9 +8,9 @@
#pragma once
#include "mlx/backend/cuda/device/complex.cuh"
#include "mlx/backend/cuda/device/config.h"
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
@@ -28,124 +28,6 @@ namespace mlx::core::cu {
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
// Vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
__device__ T& operator[](int i) {
return val[i];
}
__device__ T operator[](int i) const {
return val[i];
}
};
template <int N, typename T>
inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
if (is_aligned<N>(ptr)) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = ptr[offset * N + i];
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset,
SizeT size,
int64_t stride,
T fallback) {
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] =
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
}
return v;
}
}
template <int N, typename T>
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
if (is_aligned<N>(ptr)) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
@@ -196,20 +78,20 @@ struct Limits<
return cuda::std::numeric_limits<T>::infinity();
}
static constexpr __host__ __device__ T min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return -cuda::std::numeric_limits<float>::infinity();
#else
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return -cuda::std::numeric_limits<T>::infinity();
#else
return -cuda::std::numeric_limits<float>::infinity();
#endif
}
static constexpr __host__ __device__ T finite_max() {
return cuda::std::numeric_limits<T>::max();
}
static constexpr __host__ __device__ T finite_min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
return cuda::std::numeric_limits<float>::lowest();
#else
#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
return cuda::std::numeric_limits<T>::lowest();
#else
return cuda::std::numeric_limits<float>::lowest();
#endif
}
};
@@ -224,13 +106,13 @@ struct Limits<bool> {
}
};
template <typename T>
struct Limits<complex_t<T>> {
static constexpr __host__ __device__ complex_t<T> max() {
return {Limits<T>::max(), Limits<T>::max()};
template <>
struct Limits<cuComplex> {
static constexpr __host__ __device__ cuComplex max() {
return {Limits<float>::max(), Limits<float>::max()};
}
static constexpr __host__ __device__ complex_t<T> min() {
return {Limits<T>::min(), Limits<T>::min()};
static constexpr __host__ __device__ cuComplex min() {
return {Limits<float>::min(), Limits<float>::min()};
}
};
@@ -273,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@@ -293,16 +175,28 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}
// Optimized version when ndim is larger than 4.
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
inline __host__ __device__ IdxT
elem_to_loc_4d(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
@@ -312,15 +206,15 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
}
template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT elem,
const int* shape,
const int64_t* a_strides,
@@ -332,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@@ -444,4 +338,21 @@ struct LoopedElemToLoc<1, false, OffsetT> {
}
};
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu

View File

@@ -19,6 +19,8 @@ void new_stream(Stream s) {
cudaFree(nullptr);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
}
void eval(array& arr) {
@@ -35,17 +37,22 @@ void eval(array& arr) {
}
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) {
// Except for the donated one.
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
encoder.add_temporary(in);
if (encoder.has_gpu_work()) {
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
}
for (auto& s : arr.siblings()) {
encoder.add_temporary(s);
}
encoder.maybe_commit();
encoder.end_encoding();
}
void finalize(Stream s) {
@@ -55,7 +62,7 @@ void finalize(Stream s) {
void synchronize(Stream s) {
nvtx3::scoped_range r("gpu::synchronize");
cu::get_command_encoder(s).synchronize();
cu::get_stream(s).synchronize();
}
} // namespace mlx::core::gpu

View File

@@ -61,9 +61,7 @@ void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
wait(cu::get_stream(s).last_cuda_stream());
}
}
@@ -76,9 +74,7 @@ void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
record(cu::get_stream(s).last_cuda_stream());
}
}
@@ -90,6 +86,8 @@ bool CudaEvent::completed() const {
// SharedEvent implementations
///////////////////////////////////////////////////////////////////////////////
namespace {
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
@@ -110,26 +108,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
} // namespace
SharedEvent::SharedEvent() {
buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr);
delete ptr;
});
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
// Allocate cuda::atomic on managed memory.
Atomic* ac;
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
new (ac) Atomic(0);
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
ptr->~Atomic();
allocator().cuda_free(ptr);
});
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(to_atomic(buf_), value);
event_wait(ac_.get(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,19 +136,21 @@ void SharedEvent::wait(Stream s, uint64_t value) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.commit();
wait(encoder.stream(), value);
encoder.add_completed_handler([buf = buf_]() {});
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { wait(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(to_atomic(buf_), value);
event_signal(ac_.get(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,20 +162,22 @@ void SharedEvent::signal(Stream s, uint64_t value) {
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.commit();
signal(encoder.stream(), value);
encoder.add_completed_handler([buf = buf_]() {});
encoder.launch_kernel(
encoder.stream().last_cuda_stream(),
[this, value](cudaStream_t stream) { signal(stream, value); });
encoder.add_completed_handler([ac = ac_]() {});
encoder.end_encoding();
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return to_atomic(buf_)->load() >= value;
return ac_->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return to_atomic(buf_)->load();
return ac_->load();
}
} // namespace cu

View File

@@ -2,7 +2,6 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/stream.h"
#include <cuda_runtime.h>
@@ -56,8 +55,12 @@ class SharedEvent {
bool is_signaled(uint64_t value) const;
uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private:
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
std::shared_ptr<Atomic> ac_;
};
} // namespace mlx::core::cu

View File

@@ -1,73 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu {
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
a_it.step();
b_it.step();
}
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha,
beta);
a_it.step();
b_it.step();
c_it.step();
}
}
} // namespace mlx::core::cu

View File

@@ -1,208 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -1,284 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
#include <fmt/format.h>
namespace mlx::core::cu {
struct CublasPreference {
CublasPreference(Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
~CublasPreference() {
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
}
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride)
: handle_(device.lt_handle()),
pref_(cublas_preference(device)),
M_(a_rows),
N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
Matmul::Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: Matmul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void Matmul::run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha /* = 1 */,
float beta /* = 0 */) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
handle_,
matmul_desc_,
a_desc_,
b_desc_,
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
auto capture = encoder.capture_context();
CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_,
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace_ptr,
heuristic_.workspaceSize,
encoder.stream()));
}
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -1,100 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
public:
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* c,
float alpha = 1,
float beta = 0);
uint64_t M_;
uint64_t N_;
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu

View File

@@ -1,173 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto g_idx = block.group_index();
auto t_idx = block.thread_index();
int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
}
}
sum = cg::reduce(warp, sum, cg::plus<float>{});
if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum);
}
}
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void
gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) {
gemv_impl<T, rows_per_block, n_per_thread>(mat, vec, out, rows, cols);
}
template <typename T, int rows_per_block, int n_per_thread>
__global__ void gemv_batched(
const T* mat,
const T* vec,
T* out,
int rows,
int cols,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides mat_batch_strides,
const __grid_constant__ Strides vec_batch_strides,
int batch_ndim) {
auto block = cg::this_thread_block();
auto batch_idx = block.group_index().y;
auto [vec_offset, mat_offset] = elem_to_loc(
batch_idx,
batch_shape.data(),
vec_batch_strides.data(),
mat_batch_strides.data(),
batch_ndim);
gemv_impl<T, rows_per_block, n_per_thread>(
mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols);
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
}
}
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat;
const DataType* vec;
int rows;
int cols = K;
auto mat_strides = const_param(a_batch_strides);
auto vec_strides = const_param(b_batch_strides);
if (M == 1) {
mat = b.data<DataType>();
vec = a.data<DataType>();
rows = N;
std::swap(mat_strides, vec_strides);
} else {
mat = a.data<DataType>();
vec = b.data<DataType>();
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
});
}
} // namespace mlx::core::cu

View File

@@ -1,24 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
namespace mlx::core::cu {
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
void gemv(
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
uint32_t batch_count,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -3,16 +3,13 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "cuda_jit_sources.h"
#include <cuda.h>
#include <fmt/format.h>
#include <nvrtc.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
@@ -25,31 +22,31 @@ namespace {
constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};
void append_indices_arg(
cu::KernelArgs& args,
cu::JitModule& mod,
const std::vector<array>& inputs,
int nidx,
int idx_ndim) {
SmallVector<const void*> indices(nidx);
std::vector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
}
args.append(std::move(indices));
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
mod.append_arg(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].shape().begin(),
idx_ndim,
indices_shape.data() + i * idx_ndim);
}
args.append(std::move(indices_shape));
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
mod.append_arg(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].strides().begin(),
idx_ndim,
indices_strides.data() + i * idx_ndim);
}
args.append(std::move(indices_strides));
mod.append_arg(std::move(indices_strides));
}
} // namespace
@@ -97,21 +94,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_gather, std::move(kernel_names));
});
cu::KernelArgs args;
args.append(src);
args.append(out);
mod.append_arg(src);
mod.append_arg(out);
if (large) {
args.append<int64_t>(out.size());
mod.append_arg<int64_t>(out.size());
} else {
args.append<int32_t>(out.size());
mod.append_arg<int32_t>(out.size());
}
args.append_ndim(src.shape());
args.append_ndim(src.strides());
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
mod.append_ndim_arg(src.shape());
mod.append_ndim_arg(src.strides());
mod.append_arg<int32_t>(src.ndim());
mod.append_ndim_arg(slice_sizes_);
mod.append_arg(slice_size);
mod.append_arg(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
"mlx::core::cu::gather<{}, {}, {}, {}, {}>",
@@ -126,10 +122,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, out, large);
});
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -192,27 +187,26 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
return std::make_pair(jit_source_scatter, std::move(kernel_names));
});
cu::KernelArgs args;
args.append(upd);
args.append(out);
mod.append_arg(upd);
mod.append_arg(out);
if (large) {
args.append<int64_t>(upd.size());
mod.append_arg<int64_t>(upd.size());
} else {
args.append<int32_t>(upd.size());
mod.append_arg<int32_t>(upd.size());
}
args.append_ndim(upd.shape());
args.append_ndim(upd.strides());
args.append<int32_t>(upd.ndim());
mod.append_ndim_arg(upd.shape());
mod.append_ndim_arg(upd.strides());
mod.append_arg<int32_t>(upd.ndim());
if (large) {
args.append<int64_t>(upd_post_idx_size);
mod.append_arg<int64_t>(upd_post_idx_size);
} else {
args.append<int32_t>(upd_post_idx_size);
mod.append_arg<int32_t>(upd_post_idx_size);
}
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
mod.append_ndim_arg(out.shape());
mod.append_ndim_arg(out.strides());
mod.append_arg<int32_t>(out.ndim());
mod.append_arg(axes_);
append_indices_arg(mod, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
"mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
@@ -228,9 +222,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, upd, large);
});
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -281,26 +275,25 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
size_t idx_size_axis = idx.shape(axis_);
cu::KernelArgs args;
args.append(src);
args.append(idx);
args.append(out);
mod.append_arg(src);
mod.append_arg(idx);
mod.append_arg(out);
if (large) {
args.append<int64_t>(idx_size_pre);
args.append<int64_t>(idx_size_axis);
args.append<int64_t>(idx_size_post);
mod.append_arg<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
args.append<int32_t>(idx_size_pre);
args.append<int32_t>(idx_size_axis);
args.append<int32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
args.append(remove_index(idx.shape(), axis_));
args.append(remove_index(src.strides(), axis_));
args.append(remove_index(idx.strides(), axis_));
args.append<int32_t>(axis_);
args.append(src.shape(axis_));
args.append(src.strides(axis_));
args.append(idx.strides(axis_));
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(src.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_);
mod.append_arg(src.shape(axis_));
mod.append_arg(src.strides(axis_));
mod.append_arg(idx.strides(axis_));
std::string kernel_name = fmt::format(
"mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
@@ -316,9 +309,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, idx, large);
});
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -384,26 +377,25 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
size_t idx_size_axis = idx.shape(axis_);
cu::KernelArgs args;
args.append(upd);
args.append(idx);
args.append(out);
mod.append_arg(upd);
mod.append_arg(idx);
mod.append_arg(out);
if (large) {
args.append<int64_t>(idx_size_pre);
args.append<int64_t>(idx_size_axis);
args.append<int64_t>(idx_size_post);
mod.append_arg<int64_t>(idx_size_pre);
mod.append_arg<int64_t>(idx_size_axis);
mod.append_arg<int64_t>(idx_size_post);
} else {
args.append<int32_t>(idx_size_pre);
args.append<int32_t>(idx_size_axis);
args.append<int32_t>(idx_size_post);
mod.append_arg<int32_t>(idx_size_pre);
mod.append_arg<int32_t>(idx_size_axis);
mod.append_arg<int32_t>(idx_size_post);
}
args.append(remove_index(idx.shape(), axis_));
args.append(remove_index(upd.strides(), axis_));
args.append(remove_index(idx.strides(), axis_));
args.append<int32_t>(axis_);
args.append(out.shape(axis_));
args.append(upd.strides(axis_));
args.append(idx.strides(axis_));
mod.append_arg(remove_index(idx.shape(), axis_));
mod.append_arg(remove_index(upd.strides(), axis_));
mod.append_arg(remove_index(idx.strides(), axis_));
mod.append_arg<int32_t>(axis_);
mod.append_arg(out.shape(axis_));
mod.append_arg(upd.strides(axis_));
mod.append_arg(idx.strides(axis_));
std::string kernel_name = fmt::format(
"mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
@@ -420,9 +412,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
encoder.launch_kernel([&](cudaStream_t stream) {
mod.launch_kernel(stream, kernel_name, idx, large);
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,121 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <cuda/std/utility>
#include "mlx/backend/cuda/kernel_utils.cuh"
namespace mlx::core::cu {
// Iterating non-contiguous array.
template <typename Iterator, typename IdxT = int64_t>
class general_iterator
: public thrust::
iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<general_iterator<Iterator, IdxT>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides)
: super_t(it),
index_(index),
ndim_(ndim),
shape_(cuda::std::move(shape)),
strides_(cuda::std::move(strides)) {}
__host__ __device__ IdxT index() const {
return index_;
}
__host__ __device__ const Shape& shape() const {
return shape_;
}
__host__ __device__ const Strides& strides() const {
return strides_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const general_iterator& other) const {
return this->base() == other.base() && this->index() == other.index();
}
__host__ __device__ void advance(difference_type n) {
this->index_ += n;
}
__host__ __device__ void increment() {
this->index_ += 1;
}
__host__ __device__ void decrement() {
this->index_ -= 1;
}
__host__ __device__ difference_type
distance_to(const general_iterator& other) const {
_CCCL_ASSERT(
this->base() == other.base(),
"Underlying iterator must point to same base iterator");
return other.index() - this->index();
}
// The dereference is device-only to avoid accidental running in host.
__device__ typename super_t::reference dereference() const {
IdxT offset = elem_to_loc(index_, shape_.data(), strides_.data(), ndim_);
return *(this->base() + offset);
}
IdxT index_;
int ndim_;
Shape shape_;
Strides strides_;
};
template <typename IdxT, typename Iterator>
__host__ __device__ auto make_general_iterator(
Iterator it,
IdxT index,
int ndim,
Shape shape,
Strides strides) {
return general_iterator<Iterator, IdxT>(
it, index, ndim, cuda::std::move(shape), cuda::std::move(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterator(
Iterator it,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
return make_general_iterator<IdxT>(
it, 0, shape.size(), const_param(shape), const_param(strides));
}
template <typename IdxT, typename Iterator>
auto make_general_iterators(
Iterator it,
IdxT size,
const std::vector<int32_t>& shape,
const std::vector<int64_t>& strides) {
auto ndim = shape.size();
auto shape_arg = const_param(shape);
auto strides_arg = const_param(strides);
return std::make_pair(
make_general_iterator<IdxT>(it, 0, ndim, shape_arg, strides_arg),
make_general_iterator<IdxT>(it, size, ndim, shape_arg, strides_arg));
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <thrust/iterator/iterator_adaptor.h>
#include <thrust/iterator/iterator_facade.h>
namespace mlx::core::cu {
// RandomAccessIterator for strided access to array entries.
template <typename Iterator, typename Stride = int64_t>
class strided_iterator
: public thrust::
iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator> {
public:
using super_t =
thrust::iterator_adaptor<strided_iterator<Iterator, Stride>, Iterator>;
using reference = typename super_t::reference;
using difference_type = typename super_t::difference_type;
__host__ __device__ strided_iterator(Iterator it, Stride stride)
: super_t(it), stride_(stride) {}
__host__ __device__ Stride stride() const {
return stride_;
}
private:
friend class thrust::iterator_core_access;
__host__ __device__ bool equal(const strided_iterator& other) const {
return this->base() == other.base();
}
__host__ __device__ void advance(difference_type n) {
this->base_reference() += n * stride_;
}
__host__ __device__ void increment() {
this->base_reference() += stride_;
}
__host__ __device__ void decrement() {
this->base_reference() -= stride_;
}
__host__ __device__ difference_type
distance_to(const strided_iterator& other) const {
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(
dist % stride() == 0,
"Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
Stride stride_;
};
} // namespace mlx::core::cu

View File

@@ -2,17 +2,16 @@
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/version.h"
#include "cuda_jit_sources.h"
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <fmt/format.h>
#include <nvrtc.h>
#include <unistd.h>
namespace mlx::core::cu {
@@ -27,74 +26,47 @@ void check_nvrtc_error(const char* name, nvrtcResult err) {
}
}
// Return the location of the CUDA toolkit.
const std::string& cuda_home() {
static std::string home = []() -> std::string {
const char* home = std::getenv("CUDA_HOME");
if (home) {
return home;
}
home = std::getenv("CUDA_PATH");
if (home) {
return home;
}
#if defined(__linux__)
home = "/usr/local/cuda";
if (std::filesystem::exists(home)) {
return home;
}
#endif
throw std::runtime_error(
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
}();
return home;
#define CHECK_CU_ERROR(cmd) check_cu_error(#cmd, (cmd))
void check_cu_error(const char* name, CUresult err) {
if (err != CUDA_SUCCESS) {
const char* err_str = "Unknown error";
cuGetErrorString(err, &err_str);
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
}
}
// Return the location of CCCL headers shipped with the distribution.
const std::string& cccl_dir() {
static std::string dir = []() {
std::filesystem::path path;
#if defined(MLX_CCCL_DIR)
// First search the install dir if defined.
path = MLX_CCCL_DIR;
if (std::filesystem::exists(path)) {
return path.string();
}
// Return the location of the CUDA toolkit.
const char* cuda_home() {
const char* home = std::getenv("CUDA_HOME");
if (home) {
return home;
}
home = std::getenv("CUDA_PATH");
if (home) {
return home;
}
#if defined(__linux__)
home = "/usr/local/cuda";
if (std::filesystem::exists(home)) {
return home;
}
#endif
// Then search dynamically from the dir of libmlx.so file.
path = current_binary_dir().parent_path() / "include" / "cccl";
if (std::filesystem::exists(path)) {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
return std::string();
}();
return dir;
throw std::runtime_error(
"Environment variable CUDA_HOME or CUDA_PATH is not set.");
}
// Get the cache directory for storing compiled results.
const std::filesystem::path& ptx_cache_dir() {
static std::filesystem::path cache = []() -> std::filesystem::path {
std::filesystem::path cache;
if (auto c = std::getenv("MLX_PTX_CACHE_DIR"); c) {
cache = c;
} else {
cache =
std::filesystem::temp_directory_path() / "mlx" / version() / "ptx";
bool get_ptx_cache_dir(std::filesystem::path* result) {
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
if (!std::filesystem::is_directory(path)) {
std::error_code error;
if (!std::filesystem::create_directories(path, error)) {
return false;
}
if (!std::filesystem::exists(cache)) {
std::error_code error;
if (!std::filesystem::create_directories(cache, error)) {
return std::filesystem::path();
}
}
return cache;
}();
return cache;
}
*result = path;
return true;
}
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
@@ -103,10 +75,6 @@ bool read_cached_ptx(
const std::string& module_name,
std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error);
@@ -136,12 +104,7 @@ void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
return;
}
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size());
@@ -150,9 +113,6 @@ void write_cached_ptx(
for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl;
}
std::ofstream source_file(cache_dir / (module_name + ".cu"));
source_file << source_code;
}
// Return if |device|'s version is not newer than |major|.|minor| version.
@@ -192,7 +152,7 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "complex.cuh",
INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh",
INCLUDE_PREFIX "indexing.cuh",
INCLUDE_PREFIX "scatter_ops.cuh",
@@ -208,7 +168,7 @@ constexpr const char* g_headers[] = {
jit_source_binary_ops,
jit_source_cast_op,
jit_source_config,
jit_source_complex,
jit_source_cucomplex_math,
jit_source_fp16_math,
jit_source_indexing,
jit_source_scatter_ops,
@@ -224,9 +184,11 @@ JitModule::JitModule(
const std::string& module_name,
const KernelBuilder& builder) {
// Check cache.
std::filesystem::path cache_dir;
std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
if (!get_ptx_cache_dir(&cache_dir) ||
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
// Create program.
auto [source_code, kernel_names] = builder();
nvrtcProgram prog;
@@ -245,24 +207,16 @@ JitModule::JitModule(
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
std::string include = fmt::format("--include-path={}/include", cuda_home());
const char* args[] = {compute.c_str(), include.c_str()};
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
nvrtcCompileProgram(prog, std::size(args), args);
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
@@ -292,8 +246,7 @@ JitModule::JitModule(
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
}
// Load module.
@@ -311,13 +264,60 @@ JitModule::JitModule(
// Load kernels.
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
CHECK_CU_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel;
}
}
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
CHECK_CU_ERROR(cuModuleUnload(module_));
}
void JitModule::launch_kernel(
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread) {
CUfunction kernel = get_kernel(kernel_name);
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
int _, block_dim;
CHECK_CU_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
if (block_dim > nthreads) {
block_dim = nthreads;
}
Dims num_blocks{1, 1, 1};
if (large) {
num_blocks =
get_2d_grid_dims_common(arr.shape(), arr.strides(), work_per_thread);
std::get<0>(num_blocks) =
(std::get<0>(num_blocks) + block_dim - 1) / block_dim;
} else {
std::get<0>(num_blocks) = (nthreads + block_dim - 1) / block_dim;
}
launch_kernel(stream, kernel, num_blocks, Dims{block_dim, 1, 1});
}
void JitModule::launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims) {
CHECK_CU_ERROR(cuLaunchKernel(
kernel,
std::get<0>(num_blocks),
std::get<1>(num_blocks),
std::get<2>(num_blocks),
std::get<0>(block_dims),
std::get<1>(block_dims),
std::get<2>(block_dims),
0,
stream,
args_.data(),
nullptr));
args_.clear();
storage_.clear();
}
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
@@ -329,16 +329,15 @@ CUfunction JitModule::get_kernel(const std::string& kernel_name) {
return it->second;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
static std::unordered_map<std::string, JitModule> map;
return map;
void JitModule::append_ptr_arg(const void* v) {
args_.push_back(const_cast<void*>(v));
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder) {
auto& map = get_jit_module_cache();
static std::unordered_map<std::string, JitModule> map;
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first;

View File

@@ -4,7 +4,6 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include <deque>
@@ -24,59 +23,6 @@ using KernelBuilderResult = std::pair<
/* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
struct KernelArgs {
void** args() {
return args_.data();
}
void append(const array& a) {
append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
}
template <typename T>
void append(T val) {
storage_.emplace_back(val);
append_ptr(&storage_.back());
}
template <typename T>
void append(SmallVector<T> vec) {
storage_.emplace_back(std::move(vec));
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
}
vec.resize(NDIM);
append(std::move(vec));
}
void append_ptr(const void* v) {
args_.push_back(const_cast<void*>(v));
}
private:
std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store
// temporary values untill kernel is launched.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
int32_t,
uint32_t,
int64_t,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
std::deque<Arg> storage_;
};
class JitModule {
public:
JitModule(
@@ -87,14 +33,77 @@ class JitModule {
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
void append_arg(const array& a) {
append_arg(reinterpret_cast<CUdeviceptr>(a.data<void>()));
}
template <typename T>
void append_arg(T val) {
storage_.emplace_back(val);
append_ptr_arg(&storage_.back());
}
template <typename T>
void append_arg(std::vector<T> vec) {
if (vec.empty()) {
// The nullptr can not be used as arg, pass something not null.
append_arg(std::monostate{});
} else {
append_ptr_arg(vec.data());
storage_.emplace_back(std::move(vec));
}
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim_arg(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
}
std::vector<T> copied(NDIM);
std::copy(vec.begin(), vec.end(), copied.data());
append_arg(std::move(copied));
}
// Launch kernel with |kernel_name| that each thread works on
// |work_per_thread| elements of |arr|.
void launch_kernel(
CUstream stream,
const std::string& kernel_name,
const array& arr,
bool large,
int work_per_thread = 1);
void launch_kernel(
CUstream stream,
CUfunction kernel,
Dims num_blocks,
Dims block_dims);
CUfunction get_kernel(const std::string& kernel_name);
private:
void append_ptr_arg(const void* v);
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
};
std::vector<void*> args_;
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
// The cuLaunchKernel API requires passing pointers to arguments so store
// temporary values untill kernel is launched.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
int32_t,
uint32_t,
int64_t,
std::vector<const void*>,
std::vector<int32_t>,
std::vector<int64_t>>;
std::deque<Arg> storage_;
};
JitModule& get_jit_module(
const mlx::core::Device& device,

View File

@@ -30,25 +30,4 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core

View File

@@ -6,12 +6,10 @@
#pragma once
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
@@ -19,46 +17,60 @@
namespace mlx::core {
template <typename F>
void dispatch_1_2_3(int n, F&& f) {
switch (n) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
// Convert a number between 1~3 to constexpr.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
switch (N) { \
case 1: { \
constexpr int NDIM = 1; \
__VA_ARGS__; \
break; \
} \
case 2: { \
constexpr int NDIM = 2; \
__VA_ARGS__; \
break; \
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
}
}
template <typename F>
void dispatch_bool(bool v, F&& f) {
if (v) {
f(std::true_type{});
} else {
f(std::false_type{});
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
}
}
template <typename F>
void dispatch_block_dim(int threads, F&& f) {
if (threads <= WARP_SIZE) {
f(std::integral_constant<int, WARP_SIZE>{});
} else if (threads <= WARP_SIZE * 2) {
f(std::integral_constant<int, WARP_SIZE * 2>{});
} else if (threads <= WARP_SIZE * 4) {
f(std::integral_constant<int, WARP_SIZE * 4>{});
} else if (threads <= WARP_SIZE * 8) {
f(std::integral_constant<int, WARP_SIZE * 8>{});
} else if (threads <= WARP_SIZE * 16) {
f(std::integral_constant<int, WARP_SIZE * 16>{});
} else {
f(std::integral_constant<int, WARP_SIZE * 32>{});
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
{ \
uint32_t _num_threads = NUM_THREADS; \
if (_num_threads <= WARP_SIZE) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 2) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 4) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 8) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 16) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
__VA_ARGS__; \
} else { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
__VA_ARGS__; \
} \
}
}
// Maps CPU types to CUDA types.
template <typename T>
@@ -78,7 +90,7 @@ struct CTypeToCudaType<bfloat16_t> {
template <>
struct CTypeToCudaType<complex64_t> {
using type = cu::complex64_t;
using type = cuComplex;
};
template <typename T>
@@ -90,18 +102,14 @@ inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type traits for detecting complex numbers.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
cuda::std::is_same_v<T, complex128_t>;
// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
inline constexpr bool is_inexact_v =
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@@ -120,19 +128,47 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
CHECK_CUDA_ERROR(cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1);
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"
#include "mlx/backend/gpu/copy.h"
@@ -10,6 +11,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -72,11 +75,9 @@ __global__ void layer_norm(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceT{block, temp}.Sum(sum);
@@ -87,18 +88,11 @@ __global__ void layer_norm(
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
normalizer += t * t;
}
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -107,15 +101,17 @@ __global__ void layer_norm(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
#pragma unroll
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(b, b_stride), bn, axis_size);
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
store_vector<N_READS>(out, index, xn, axis_size);
cub::StoreDirectBlocked(index, out, xn, axis_size);
}
}
@@ -148,11 +144,9 @@ __global__ void layer_norm_vjp(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
@@ -162,28 +156,19 @@ __global__ void layer_norm_vjp(
// Normalizer.
float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -195,10 +180,12 @@ __global__ void layer_norm_vjp(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
@@ -208,9 +195,9 @@ __global__ void layer_norm_vjp(
wn[i] = gi * xi;
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
cub::StoreDirectBlocked(index, gx, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
cub::StoreDirectBlocked(index, gw, wn, axis_size);
}
}
}
@@ -250,7 +237,8 @@ void LayerNorm::eval_gpu(
}
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
@@ -270,24 +258,22 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(w);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride,
b_stride);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
out.data<DataType>(),
eps_,
axis_size,
w_stride,
b_stride);
});
});
});
}
@@ -302,23 +288,21 @@ void LayerNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&s](const array& x, bool& copied) {
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
copied = false;
return x;
return {x, false};
}
copied = true;
return contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true};
};
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
bool copied;
auto x = check_input(inputs[0], copied);
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& b = inputs[2];
bool g_copied;
auto g = check_input(inputs[3], g_copied);
auto [g, g_copied] = check_input(inputs[3]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
@@ -349,59 +333,47 @@ void LayerNormVJP::eval_gpu(
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
bool g_in_gw = false;
if (has_w) {
if (!g_in_gx && donate_g) {
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
encoder.add_temporary(gw_temp);
}
}
gw.set_data(allocator::malloc(gw.nbytes()));
gb.set_data(allocator::malloc(gb.nbytes()));
// The gradient for b in case we had a b.
bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size);
if (has_gb) {
// Finish with the gradient for b in case we had a b.
if (gb.ndim() == 1 && gb.size() == axis_size) {
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
col_reduce(encoder, g, gb, Reduce::ReduceType::Sum, {0}, plan);
}
// Insert dependency if `g` was donated
if ((g_in_gx || g_in_gw) && has_gb) {
encoder.set_input_array(gb);
}
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(g);
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::layer_norm_vjp<
DataType,
has_w_constant.value,
block_dim(),
N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
});
});
});

View File

@@ -43,19 +43,20 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
auto index = r * BLOCK_DIM + block.thread_rank();
auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
prevmax = maxval;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
normalizer = normalizer + softmax_exp(vals[i] - maxval);
}
}
@@ -107,7 +108,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
@@ -141,19 +143,15 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
});
});
}

View File

@@ -1,159 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
namespace mlx::core {
template <
typename K,
typename V,
template <typename...> typename M = std::unordered_map>
class LRUCache {
public:
using value_type = std::pair<K, V>;
using list_type = std::list<value_type>;
using iterator = typename list_type::iterator;
using const_iterator = typename list_type::const_iterator;
using map_type = M<K, iterator>;
explicit LRUCache(size_t capacity) : capacity_(capacity) {
if (capacity == 0) {
throw std::runtime_error("LRUCache requires capacity > 0.");
}
}
size_t size() const {
return map_.size();
}
size_t capacity() const {
return capacity_;
}
bool empty() const {
return vlist_.empty();
}
void resize(size_t new_capacity) {
capacity_ = new_capacity;
trim();
}
iterator begin() {
return vlist_.begin();
}
const_iterator begin() const {
return vlist_.begin();
}
iterator end() {
return vlist_.end();
}
const_iterator end() const {
return vlist_.end();
}
void clear() {
map_.clear();
vlist_.clear();
}
iterator find(const K& key) {
auto it = map_.find(key);
if (it == map_.end())
return end();
vlist_.splice(vlist_.begin(), vlist_, it->second);
return it->second;
}
template <typename U>
std::pair<iterator, bool> emplace(const K& key, U&& value) {
auto it = map_.find(key);
if (it != map_.end()) {
vlist_.splice(vlist_.begin(), vlist_, it->second);
return {it->second, false};
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
trim();
return {vlist_.begin(), true};
}
iterator erase(iterator pos) {
map_.erase(pos->first);
return vlist_.erase(pos);
}
V& operator[](const K& key) {
auto it = find(key);
if (it == end()) {
it = emplace(key, V{}).first;
}
return it->second;
}
private:
void trim() {
while (map_.size() > capacity_) {
auto last = std::prev(vlist_.end());
map_.erase(last->first);
vlist_.pop_back();
}
}
list_type vlist_;
map_type map_;
size_t capacity_;
};
// Turn a POD struct into a container key by doing bytes compare.
template <typename T>
struct BytesKey {
T pod;
static_assert(std::is_standard_layout_v<T>, "T is not POD");
BytesKey(T pod) : pod(std::move(pod)) {}
BytesKey(const BytesKey& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
BytesKey(BytesKey&& other) {
memcpy(&pod, &other.pod, sizeof(T));
}
bool operator==(const BytesKey& other) const {
auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
return memcmp(ptr1, ptr2, sizeof(T)) == 0;
}
};
// Compute hash according to the bytes value of T.
template <typename T>
struct BytesHash {
static_assert(std::is_standard_layout_v<T>, "T is not POD");
size_t operator()(const T& pod) const {
auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < sizeof(T); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return value;
}
};
template <typename K, typename V>
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;
template <typename K, typename V>
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
} // namespace mlx::core

View File

@@ -2,19 +2,274 @@
#include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/gemms/gemv.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <cublasLt.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <numeric>
namespace mlx::core {
namespace cu {
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
void check_cublas_error(const char* name, cublasStatus_t err) {
if (err != CUBLAS_STATUS_SUCCESS) {
// TODO: Use cublasGetStatusString when it is widely available.
throw std::runtime_error(
fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
}
}
class MatMul {
public:
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
sizeof(cublasOperation_t)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
uint64_t MiB = 1024 * 1024;
uint64_t workspace_size =
device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
pref_,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
sizeof(uint64_t)));
}
MatMul(
Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
uint64_t a_cols,
int64_t lda,
bool b_transposed,
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: MatMul(
device,
dtype,
a_transposed,
a_rows,
a_cols,
lda,
b_transposed,
b_rows,
b_cols,
ldb,
batch_count,
a_batch_stride,
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
}
~MatMul() {
cublasLtMatrixLayoutDestroy(a_desc_);
cublasLtMatrixLayoutDestroy(b_desc_);
cublasLtMatrixLayoutDestroy(c_desc_);
cublasLtMatrixLayoutDestroy(out_desc_);
cublasLtMatmulDescDestroy(matmul_desc_);
}
void run(
cu::CommandEncoder& encoder,
void* out,
void* a,
void* b,
void* c = nullptr,
float alpha = 1,
float beta = 0) {
if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
int ret = 0;
CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
encoder.device().lt_handle(),
matmul_desc_,
a_desc_,
b_desc_,
out_desc_,
out_desc_,
pref_,
1,
&heuristic_,
&ret));
if (ret == 0) {
throw std::runtime_error("Can not find algorithm for matmul.");
}
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
encoder.device().lt_handle(),
matmul_desc_,
&alpha,
a,
a_desc_,
b,
b_desc_,
&beta,
c ? c : out,
c ? c_desc_ : out_desc_,
out,
out_desc_,
&heuristic_.algo,
workspace.data<void>(),
workspace.nbytes(),
stream));
});
}
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUBLAS_COMPUTE_32F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
case float64:
case complex64:
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case float16:
return CUDA_R_16F;
case bfloat16:
return CUDA_R_16BF;
case float32:
return CUDA_R_32F;
case float64:
return CUDA_R_64F;
case complex64:
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
}
}
cublasLtMatrixLayout_t create_matrix_layout(
cudaDataType_t type,
uint64_t rows,
uint64_t cols,
bool transposed,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order =
transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_stride,
sizeof(int64_t)));
}
return desc;
}
cublasLtMatmulDesc_t matmul_desc_{nullptr};
cublasLtMatmulPreference_t pref_{nullptr};
cublasLtMatrixLayout_t a_desc_{nullptr};
cublasLtMatrixLayout_t b_desc_{nullptr};
cublasLtMatrixLayout_t c_desc_{nullptr};
cublasLtMatrixLayout_t out_desc_{nullptr};
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace cu
namespace {
std::tuple<bool, int64_t, array>
check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
check_transpose(std::vector<array>& copies, const Stream& s, const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1 && stx == arr.shape(-1)) {
@@ -22,8 +277,9 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy = contiguous_copy_gpu(arr, s);
enc.add_temporary(arr_copy);
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy);
}
}
@@ -57,8 +313,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -79,26 +340,11 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
cu::device(s.device),
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
@@ -112,13 +358,17 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc);
a_it.step();
b_it.step();
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -129,7 +379,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto c = inputs[2];
auto& c_pre = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -140,25 +392,13 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
std::vector<array> copies;
auto [a_transposed, lda, a] = check_transpose(copies, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(copies, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(copies, s, c_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
for (auto& temp : copies) {
encoder.add_temporary(temp);
}
/////////////////////////////////////////////////////////////////////////////
@@ -186,8 +426,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
cu::device(s.device),
cu::MatMul matmul(
encoder.device(),
a.dtype(),
a_transposed,
M,
@@ -197,28 +437,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
c.data<int8_t>() + c.itemsize() * c_it.loc,
alpha_,
beta_);
a_it.step();
b_it.step();
c_it.step();
}
matmul.run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha_,
beta_);
}
} // namespace mlx::core

View File

@@ -1,55 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -0,0 +1,97 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(stream),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
});
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(ArgPartition)
NO_GPU(BlockMaskedMM)
NO_GPU(Convolution)
NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -1,331 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr float eps = 1e-7;
constexpr int simd_size = WARP_SIZE;
constexpr float n_bins = (1 << bits) - 1;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = pack_factor / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t in_index = offset * values_per_reduce;
if (in_index >= size) {
return;
}
size_t out_index = power_of_2_bits
? offset * writes_per_pack
: offset * bytes_per_pack / writes_per_reduce;
float w_thread[values_per_reduce];
float w_min = Limits<float>::max();
float w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
float val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
cg::greater<float> max_op;
cg::less<float> min_op;
auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());
w_min = cg::reduce(warp, w_min, min_op);
w_max = cg::reduce(warp, w_max, max_op);
float scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
float edge = side ? w_min : w_max;
float q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
float bias = at_zero ? 0 : edge;
// Write out the scales and biases
size_t gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = static_cast<T>(scale);
biases[gindex] = static_cast<T>(bias);
}
using OutType = std::conditional_t<bits == 5, uint64_t, uint32_t>;
OutType output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output |= val << (bits * (i % pack_factor));
}
if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
out[out_index + i / pack_factor] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 1; j < writes_per_reduce; j++) {
uint8_t sval = warp.shfl_down(val, j);
output |= static_cast<OutType>(sval)
<< (bits * (j * values_per_reduce + i));
}
}
}
if constexpr (bits == 3 || bits == 6) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
}
} else if constexpr (bits == 5) {
if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
out[out_index] = output & 0xff;
out[out_index + 1] = (output & 0xff00) >> 8;
out[out_index + 2] = (output & 0xff0000) >> 16;
out[out_index + 3] = (output & 0xff000000) >> 24;
out[out_index + 4] = (output & 0xff00000000) >> 32;
}
} else {
if constexpr (writes_per_reduce > 0) {
if (out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
}
}
template <typename T, int group_size, int bits>
__global__ void affine_dequantize(
const uint8_t* w,
const T* scales,
const T* biases,
T* out,
size_t size) {
auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x =
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
constexpr int pack_factor = get_pack_factor<bits, 8>();
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
size_t offset = tidx + grid_dim_x * size_t(tidy);
size_t oindex = offset * pack_factor;
if (oindex >= size) {
return;
}
size_t gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
out += oindex;
if constexpr (bits == 3) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x7) * scale + bias;
out[1] = static_cast<T>((w[0] & 0x38) >> 3) * scale + bias;
out[2] = (static_cast<T>((w[0] & 0xc0) >> 6) +
static_cast<T>((w[1] & 0x1) << 2)) *
scale +
bias;
out[3] = static_cast<T>((w[1] & 0xe) >> 1) * scale + bias;
out[4] = static_cast<T>((w[1] & 0x70) >> 4) * scale + bias;
out[5] = (static_cast<T>((w[1] & 0x80) >> 7) +
static_cast<T>((w[2] & 0x3) << 1)) *
scale +
bias;
out[6] = static_cast<T>((w[2] & 0x1c) >> 2) * scale + bias;
out[7] = static_cast<T>((w[2] & 0xe0) >> 5) * scale + bias;
} else if constexpr (bits == 5) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x1f) * scale + bias;
out[1] = (static_cast<T>((w[0] & 0xe0) >> 5) +
static_cast<T>((w[1] & 0x3) << 3)) *
scale +
bias;
out[2] = static_cast<T>((w[1] & 0x7c) >> 2) * scale + bias;
out[3] = (static_cast<T>((w[1] & 0x80) >> 7) +
static_cast<T>((w[2] & 0xf) << 1)) *
scale +
bias;
out[4] = (static_cast<T>((w[2] & 0xf0) >> 4) +
static_cast<T>((w[3] & 0x1) << 4)) *
scale +
bias;
out[5] = static_cast<T>((w[3] & 0x3e) >> 1) * scale + bias;
out[6] = (static_cast<T>((w[3] & 0xc0) >> 6) +
static_cast<T>((w[4] & 0x7) << 2)) *
scale +
bias;
out[7] = static_cast<T>((w[4] & 0xf8) >> 3) * scale + bias;
} else if constexpr (bits == 6) {
w += offset * bytes_per_pack;
out[0] = static_cast<T>(w[0] & 0x3f) * scale + bias;
out[1] = (static_cast<T>((w[0] >> 6) & 0x03) +
static_cast<T>((w[1] & 0x0f) << 2)) *
scale +
bias;
out[2] = (static_cast<T>((w[1] >> 4) & 0x0f) +
static_cast<T>((w[2] & 0x03) << 4)) *
scale +
bias;
out[3] = static_cast<T>((w[2] >> 2) & 0x3f) * scale + bias;
} else {
uint val = w[offset];
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[i] = scale * static_cast<T>(d) + bias;
}
}
}
} // namespace cu
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate the number of elements per thread
int per_thread = group_size_ / WARP_SIZE;
size_t size = w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.size());
});
});
});
}
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
});
});
});
}
} // namespace mlx::core

View File

@@ -1,80 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
} // namespace
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -1,59 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More