Compare commits

...

87 Commits

Author SHA1 Message Date
Awni Hannun
8ce49cd39e fix quantized vjp for mxfp4 (#2555) 2025-08-29 10:06:15 -07:00
Awni Hannun
9c68b50853 version bump (#2554) 2025-08-29 06:54:17 -07:00
Awni Hannun
111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
827003d568 fix METAL quantization in JIT (#2553) 2025-08-28 18:26:25 -07:00
Awni Hannun
d363a76aa4 Bump xcode in circle (#2551)
* bump xcode in circle

* bump xcode in circle

* bump xcode in circle
2025-08-28 13:13:34 -07:00
Awni Hannun
70560b6bd5 Add mode parameter for quantization (#2499)
* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
7ef8a6f2d5 [CUDA] fix sort (#2550)
* [CUDA] fix sort

* fix test
2025-08-27 19:48:43 -07:00
Cheng
31c6f6e33f [CUDA] Use ConcurrentContext in concatenate_gpu (#2549) 2025-08-28 09:30:08 +09:00
Awni Hannun
584d48458e link with nccl (#2546) 2025-08-27 10:01:07 -07:00
Cheng
5cf984ca87 Separate cpu compilation cache by versions (#2548) 2025-08-27 11:25:15 +09:00
Cheng
a9bac3d9e5 Run CPP tests for CUDA build in CI (#2544) 2025-08-27 08:06:46 +09:00
Awni Hannun
5458d43247 add load with path tests (#2543) 2025-08-26 14:24:47 -07:00
Awni Hannun
a4dba65220 Enable cuda graph toggle (#2545)
* enable cuda graph toggle

* increase cache size
2025-08-26 12:50:38 -07:00
Awni Hannun
3dcb286baf Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default

* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9 [CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives

* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Awni Hannun
2ca75bb529 Remove nccl install in release (#2542) 2025-08-25 15:20:18 -07:00
Awni Hannun
db14e29a0b allow pathlib.Path to save/load functions (#2541) 2025-08-25 14:58:49 -07:00
Awni Hannun
d2f540f4e0 Use nccl header only when nccl is not present (#2539)
* use nccl header only when nccl is not present

* larger machine for cuda build
2025-08-25 14:17:25 -07:00
Cheng
333ffea273 [CUDA] Remove thrust in arange (#2535) 2025-08-24 16:22:36 +09:00
Cheng
f55b6f1f2f Enable COMPILE_WARNING_AS_ERROR for linux builds in CI (#2534) 2025-08-24 15:33:08 +09:00
Awni Hannun
30561229c7 Fix allocation bug in NCCL (#2530) 2025-08-22 14:39:43 -07:00
Awni Hannun
068a4612e9 nccl default for backend=any (#2528)
* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
2025-08-22 12:24:27 -07:00
Andrey Portnoy
5722c147de [CUDA] Update calls to cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 (#2525)
* [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13

These functions' signatures changed in CUDA 13, so we differentiate
between CUDA 13 and preceding releases at compile time.

* Mention NVIDIA in ACKNOWLEDGMENTS.md
2025-08-21 19:57:20 -07:00
Cheng
f6819a1f26 Fix warning 186-D from nvcc (#2527) 2025-08-22 10:29:55 +09:00
Awni Hannun
f93f87c802 nccl dep + default for cuda (#2526) 2025-08-21 17:57:49 -07:00
Anastasiia Filippova
9392fc3f88 NCCL backend (#2476) 2025-08-21 11:56:15 -07:00
Awni Hannun
e843c4d8d5 fix power (#2523) 2025-08-21 06:46:01 -07:00
Angelos Katharopoulos
0c5fc63a36 Fix docs omission (#2524) 2025-08-20 17:56:06 -07:00
Angelos Katharopoulos
e397177f6e Custom cuda kernel (#2517) 2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) 2025-08-21 08:55:26 +09:00
Angelos Katharopoulos
25c1e03205 Fix overflow in large filter small channels (#2520) 2025-08-20 08:03:29 -07:00
russellizadi
512281781c Remove state return from function example in compile documentation (#2518) 2025-08-20 00:45:05 -07:00
Cheng
ac85ddfdb7 [CUDA] Add GEMM-based fallback convolution kernels (#2511)
* Add gemm_conv

* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Cheng
65d0d40232 Split cuDNN helpers into a separate header (#2491)
* Add RAII managed CudaGraph class

* Implement forward rms_norm with cuDNN

* Revert back to old rms norm kernel
2025-08-20 09:29:28 +09:00
Awni Hannun
cea9369610 fix lapack svd (#2515) 2025-08-18 15:07:59 -07:00
Awni Hannun
e7c6e1db82 no segfault with uninitialized array.at (#2514) 2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b fix custom kernel test (#2510) 2025-08-18 06:45:59 -07:00
Angelos Katharopoulos
1df9887998 Ensure no oob read in gemv_masked (#2508) 2025-08-17 08:42:33 -07:00
Angelos Katharopoulos
73f22d6226 Ensure small sort doesn't use indices if not argsort (#2506) 2025-08-17 08:42:20 -07:00
Cheng
c422050ca7 Update cuDNN Frontend to v1.14 (#2505) 2025-08-17 19:13:01 +09:00
Cheng
1ba18ff7d9 [CUDA] Fix conv grads with groups (#2495)
* Put reshape utils in one file

* [CUDA] Fix conv grads with groups

* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Cheng
37b440faa8 Clean up code handling both std::vector and SmallVector (#2493) 2025-08-16 09:01:10 +09:00
Cheng
888b13ed63 Remove the hack around SmallVector in cpu compile (#2494) 2025-08-16 08:17:24 +09:00
Cheng
4abb218d21 The naive_conv_2d is no longer used (#2496) 2025-08-16 07:57:30 +09:00
Awni Hannun
6441c21a94 Faster general unary op (#2472)
* faster general unary op

* faster general ops + reorg

* fix + comment

* binary two

* copy general
2025-08-15 15:04:12 -07:00
Cheng
dfb5022eab Rename cu::Matmul to CublasGemm (#2488) 2025-08-13 09:37:40 +09:00
Daniel Yeh
ac207ce7aa make code blocks copyable (#2480)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-08-12 12:29:02 -07:00
Abe Leininger
fce53b61d6 Fix reduce sum/prod overflow (#2477) 2025-08-12 00:05:33 -07:00
Angelos Katharopoulos
8ae4a76308 Use CMake <4.1 to avoid the nvpl error (#2489) 2025-08-12 00:03:42 -07:00
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
Awni Hannun
56be773610 version (#2470) 2025-08-07 00:36:04 -07:00
Jagrit Digani
a9bdd67baa Add CUDA sdpa vector (#2468) 2025-08-06 21:40:26 -07:00
Angelos Katharopoulos
f2adb5638d Fix typo in metal command encoder (#2471) 2025-08-06 16:58:23 -07:00
Luca Vivona
728d4db582 Support destination arg in tree flatten/unflatten (#2450) 2025-08-06 15:34:59 -07:00
Awni Hannun
db5c7efcf6 revert default cuda install (#2465)
* revert default cuda install

* revert default cuda install
2025-08-06 06:19:12 -07:00
Awni Hannun
7bb96e4249 fix cublas on h100 (#2466) 2025-08-06 06:18:58 -07:00
Awni Hannun
fa89f0b150 faster gather qmm sorted test (#2463) 2025-08-05 06:27:40 -07:00
Awni Hannun
ca973d1e83 fix install tags (#2464) 2025-08-04 20:01:23 -07:00
Cheng
828c5f1137 Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
2025-08-05 09:41:03 +09:00
Gaétan Lepage
7d86a5c108 Feat: add USE_SYSTEM_FMT CMake option (#2219) 2025-08-04 16:36:11 -07:00
Awni Hannun
0b807893a7 fix wraps compile (#2461) 2025-08-04 16:14:18 -07:00
Awni Hannun
6ad0889c8a default install cuda on linux (#2462) 2025-08-04 15:33:05 -07:00
Zamderax
737dd6d1ac Add missing <algorithm> header to jit_compiler.cpp (#2460)
Fixes compilation error on Linux where std::find_if is used on line 121
but the <algorithm> header was not included. While this might work on
some platforms due to transitive includes, it's not guaranteed by the
C++ standard.

Resolves issue #2459
2025-08-04 14:00:46 -07:00
Cheng
aaf78f4c6b Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph

* Remove unused destructor
2025-08-02 21:28:57 +09:00
Angelos Katharopoulos
8831064493 Fix arctan2 grads (#2453) 2025-08-01 21:06:04 -07:00
Angelos Katharopoulos
be9bc96da4 [CUDA] Matmul utils initial commit (#2441) 2025-08-01 14:22:25 -07:00
Angelos Katharopoulos
86258f292f [CUDA] Vectorize generated kernels (#2444) 2025-07-31 18:18:57 -07:00
Cheng
b26d88591c [CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
2025-08-01 10:16:06 +09:00
Cheng
86c6a15571 [CUDA] Backward convolution (#2431) 2025-08-01 09:54:05 +09:00
junpeiz
8b25ce62d5 Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export.

* Skip quantization related test for CUDA backend.
2025-07-31 11:06:26 -07:00
Awni Hannun
da5912e4f2 fix custom metal extension (#2446) 2025-07-31 06:25:36 -07:00
Cheng
daafee676f Fix wrong graph key when using concurrent context (#2447) 2025-07-31 06:01:05 -07:00
Awni Hannun
d32519c8ee fix gemv regression (#2445) 2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249 fix circular reference (#2443) 2025-07-30 09:37:44 -07:00
Angelos Katharopoulos
3bf81ed1bd [CUDA] Quantized refactoring (#2442) 2025-07-30 08:27:20 -07:00
Cheng
2204182bba Make CI faster (#2440) 2025-07-30 02:26:36 -07:00
Cheng
3628e5d497 Use load_vector in arg_reduce (#2439) 2025-07-30 17:40:26 +09:00
Cheng
a0ae49d397 Move arange to its own file (#2438) 2025-07-30 13:05:51 +09:00
Cheng
254476718b Remove the kernel arg from get_launch_args (#2437) 2025-07-30 11:43:02 +09:00
Awni Hannun
3adba92ebe Cuda faster softmax (#2435)
* faster softmax and logsumexp

* faster softmax and logsumexp

* format
2025-07-29 17:18:12 -07:00
Awni Hannun
ef631d63af faster rms norm (#2433) 2025-07-29 13:12:00 -07:00
Cheng
970dbe8e25 Use ccache in CI (#2414)
* Detect ccache

* Use ccache in CI

* Separate cache for different images

* Test both 12.2 and 12.9 for PRs
2025-07-29 08:43:22 +09:00
Awni Hannun
641be9463b Add more CUDA architectures for PyPi package (#2427)
* add cuda sm 90

* add more archs
2025-07-28 12:35:15 -07:00
Awni Hannun
ab0e608862 [CUDA] More sizes for gemv (#2429)
* route more to gemv

* route more sizes to custom gemv
2025-07-28 12:35:01 -07:00
Awni Hannun
1588659062 no occupancy query for launch params (#2426) 2025-07-28 09:09:41 -07:00
Awni Hannun
b9e88fb976 [CUDA] Fix segfault on exit (#2424)
* fix cuda segfault on exit

* comment
2025-07-27 08:08:13 -07:00
264 changed files with 13393 additions and 3487 deletions

View File

@@ -18,13 +18,14 @@ jobs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -81,23 +82,25 @@ jobs:
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get upgrade -y
pip install --upgrade cmake
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
pip install -e ".[dev]"
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
@@ -105,6 +108,7 @@ jobs:
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
@@ -116,7 +120,7 @@ jobs:
parameters:
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
@@ -124,39 +128,37 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
source env/bin/activate
uv venv --python 3.9
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
@@ -165,16 +167,17 @@ jobs:
- run:
name: Build example extension
command: |
source env/bin/activate
source .venv/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
@@ -183,7 +186,7 @@ jobs:
- run:
name: Build small binary
command: |
source env/bin/activate
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
@@ -195,12 +198,13 @@ jobs:
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
@@ -212,22 +216,56 @@ jobs:
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install Python package
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python3 -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source env/bin/activate
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
@@ -236,7 +274,7 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
build_env:
type: string
default: ""
@@ -245,7 +283,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
@@ -253,11 +291,15 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -267,19 +309,19 @@ jobs:
- run:
name: Install Python package
command: |
source env/bin/activate
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
@@ -289,7 +331,7 @@ jobs:
- run:
name: Build common package
command: |
source env/bin/activate
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
@@ -298,7 +340,7 @@ jobs:
- run:
name: Upload package
command: |
source env/bin/activate
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
@@ -323,14 +365,10 @@ jobs:
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get upgrade -y
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo apt-get install -y apt-utils
sudo apt-get install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -375,7 +413,7 @@ jobs:
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
resource_class: xlarge
steps:
- checkout
- run:
@@ -422,7 +460,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
@@ -447,68 +485,7 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
@@ -550,11 +527,14 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
@@ -566,53 +546,7 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
@@ -631,68 +565,7 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:

View File

@@ -25,6 +25,11 @@ MLX was developed with contributions from the following individuals:
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -41,7 +41,9 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests -------------------------
message(
@@ -68,6 +70,15 @@ else()
set(MLX_BUILD_METAL OFF)
endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
endif()
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -129,6 +140,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -232,12 +249,16 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
if(USE_SYSTEM_FMT)
find_package(fmt REQUIRED)
else()
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)

View File

@@ -78,13 +78,13 @@ pip install mlx
To install the CUDA backend on Linux, run:
```bash
pip install "mlx[cuda]"
pip install mlx[cuda]
```
To install a CPU-only Linux package, run:
```bash
pip install "mlx[cpu]"
pip install mlx[cpu]
```
Checkout the

54
cmake/FindNCCL.cmake Normal file
View File

@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -1,4 +1,5 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx

View File

@@ -18,6 +18,7 @@ release = version
# -- General configuration ---------------------------------------------------
extensions = [
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",

View File

@@ -127,7 +127,8 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
source=source,
ensure_row_contiguous=False,
)
def exp_elementwise(a: mx.array):
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]

View File

@@ -394,14 +394,14 @@ below.
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
std::stream kname;
kname = "axpby_general_" + type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext");
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), lib);
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers

View File

@@ -30,7 +30,7 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install "mlx[cuda]"
pip install mlx[cuda]
To install the CUDA package from PyPi your system must meet the following
requirements:
@@ -49,7 +49,7 @@ For a CPU-only version of MLX that runs on Linux use:
.. code-block:: shell
pip install "mlx[cpu]"
pip install mlx[cpu]
To install the CPU-only package from PyPi your system must meet the following
requirements:
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag

9
docs/src/python/cuda.rst Normal file
View File

@@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For

View File

@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]

View File

@@ -7,17 +7,17 @@ Exporting Functions
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.
Basics of Exporting
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
x = mx.array(1.0)
y = mx.array(1.0)
# Both arguments to fun are positional
mx.export_function("add.mlxfn", fun, x, y)
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters()``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
# Set the model's parameters to the input parameters
model.update(tree_unflatten(list(params.items())))
return model(x)
params = dict(tree_flatten(model.parameters()))
params = tree_flatten(model.parameters(), destination={})
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:
# Ok
out, = imported_abs(mx.array(-1.0))
# Also ok
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
def fun(x, y=None):
constant = mx.array(3.0)
if y is not None:
x += y
x += y
return x + constant
with mx.exporter("fun.mlxfn", fun) as exporter:
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
print(out)
In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.
Transformations with Imported Functions
---------------------------------------
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
# Prints: array(1, dtype=float32)
print(dfdx(x))
# Compile the imported function
# Compile the imported function
mx.compile(imported_fun)
# Prints: array(0, dtype=float32)
print(compiled_fun(x)[0])
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
// Prints: array(2, dtype=float32)
std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2025 Apple Inc.
#include <dlfcn.h>
#include <iostream>
#include <sstream>
@@ -16,6 +17,19 @@
namespace my_ext {
// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
static std::string binary_dir = []() {
Dl_info info;
if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
throw std::runtime_error("Unable to get current binary dir.");
}
return std::filesystem::path(info.dli_fname).parent_path().string();
}();
return binary_dir;
}
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
}
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
std::string kname = "axpby_";
kname += (contiguous_kernel ? "contiguous_" : "general_");
kname += type_to_name(out);
// Load the metal library
auto lib = d.get_library("mlx_ext");
auto lib = d.get_library("mlx_ext", current_binary_dir());
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), lib);
auto kernel = d.get_kernel(kname, lib);
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.2.0
nanobind==2.4.0

View File

@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

View File

@@ -10,6 +10,7 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@@ -18,8 +19,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc

View File

@@ -197,7 +197,7 @@ void shared_buffer_reshape(
array& out);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}

View File

@@ -15,6 +15,7 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -94,7 +95,11 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir = std::filesystem::temp_directory_path();
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
@@ -157,10 +162,12 @@ inline void build_kernel(
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
os << "void " << kernel_name
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
@@ -175,8 +182,8 @@ inline void build_kernel(
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
os << " const int64_t* " << xname << "_strides = strides["
<< strides_index++ << "];" << std::endl;
}
}
@@ -186,10 +193,8 @@ inline void build_kernel(
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
// Add output size
if (contiguous) {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
@@ -290,7 +295,6 @@ void Compiled::eval_cpu(
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
@@ -298,9 +302,6 @@ void Compiled::eval_cpu(
const auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
args.push_back(strides[strides_index++].data());
}
}
// Get the kernel name from the lib
@@ -335,16 +336,20 @@ void Compiled::eval_cpu(
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
if (!contiguous) {
args.push_back((void*)shape.data());
} else {
if (contiguous) {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
encoder.dispatch([fun,
args = std::move(args),
strides = std::move(strides),
shape = std::move(shape)]() mutable { fun(args.data()); });
shape = std::move(shape)]() mutable {
SmallVector<int64_t*> strides_ptrs;
for (auto& s : strides) {
strides_ptrs.push_back(s.data());
}
fun(shape.data(), strides_ptrs.data(), args.data());
});
}
} // namespace mlx::core

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>

View File

@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)

View File

@@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -13,6 +11,35 @@ namespace mlx::core {
namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
@@ -407,6 +434,231 @@ void _qmm_dispatch(
}
}
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@@ -513,115 +765,198 @@ void _bs_qmm_dispatch(
}
}
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& lhs_indices = inputs[inputs.size() - 2];
auto& rhs_indices = inputs[inputs.size() - 1];
std::vector<array> temps;
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous_last_dims = [s = stream(),
&temps](const array& arr) {
&encoder](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
}
template <typename T, typename U>
@@ -705,7 +1040,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
void fast::Quantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -764,7 +1099,7 @@ void fast::AffineQuantize::eval_cpu(
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
}
});
}

View File

@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
switch (in.dtype()) {
case bool_:
case uint8:
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:

View File

@@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
@@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
// Raising an integer to a negative power is undefined
if (any(exp < 0)) {
return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1;
}
return res;

View File

@@ -8,7 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -333,47 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
dispatch_all_types(out.dtype(), [&](auto type_tag) {
sort<MLX_GET_TYPE(type_tag)>(out, axis);
});
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {

View File

@@ -81,9 +81,7 @@ void svd_impl(
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
@@ -91,30 +89,20 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
@@ -136,20 +124,13 @@ void svd_impl(
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
@@ -167,13 +148,6 @@ void svd_impl(
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
});
encoder.add_temporary(in);

View File

@@ -6,8 +6,8 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
@@ -16,8 +16,13 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
@@ -29,7 +34,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
@@ -38,22 +43,26 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
@@ -105,11 +114,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
@@ -145,7 +154,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.12.1
GIT_TAG v1.14.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -30,8 +30,15 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = 0;
#else
int loc = 0;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {

View File

@@ -0,0 +1,69 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
}
}
} // namespace cu
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
constexpr int N_WRITES = 16 / sizeof(OutType);
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
encoder.add_kernel_node(
cu::arange<OutType, IdxT, N_WRITES>,
num_blocks,
block_dims,
0,
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
});
}
} // namespace mlx::core

View File

@@ -44,8 +44,11 @@ struct ArgMin {
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
@@ -74,8 +77,11 @@ struct ArgMax {
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
@@ -166,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel,
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
out.size(),

View File

@@ -0,0 +1,21 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Add)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(ArcTan2)
} // namespace mlx::core

View File

@@ -28,7 +28,7 @@ __global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b[0]);
out_vec[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -49,7 +49,7 @@ __global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
out_vec[i] = Op{}(a[0], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -70,7 +70,7 @@ __global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
out_vec[i] = Op{}(a_vec[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -92,46 +92,96 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -209,36 +259,61 @@ void binary_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@@ -248,8 +323,7 @@ void binary_op_gpu_inplace(
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -259,16 +333,12 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
@@ -306,54 +376,4 @@ void binary_op_gpu(
binary_op_gpu<cu::func>(inputs, out, name(), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Divide)
} // namespace mlx::core

View File

@@ -0,0 +1,15 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Greater)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(GreaterEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Less)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LessEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogAddExp)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalAnd)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(LogicalOr)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Maximum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Minimum)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Multiply)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(NotEqual)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Power)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Remainder)
} // namespace mlx::core

View File

@@ -0,0 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/binary/binary.cuh"
namespace mlx::core {
BINARY_GPU(Subtract)
} // namespace mlx::core

View File

@@ -33,8 +33,8 @@ binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -60,9 +60,9 @@ binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
auto out = Op{}(a[0], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -88,9 +88,9 @@ binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
auto out = Op{}(a_vec[i], b[0]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -117,9 +117,9 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
auto out = Op{}(a_vec[i], b_vec[i]);
out_a_vec[i] = out[0];
out_b_vec[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
@@ -127,45 +127,99 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
template <
typename Op,
typename In,
typename Out,
typename IdxT,
int NDIM,
int N_READS>
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto a_stride_x = a_strides[NDIM - 1];
auto b_stride_x = b_strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out, typename IdxT>
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto a_stride_x = a_strides[ndim - 1];
auto b_stride_x = b_strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [a_idx, b_idx] = elem_to_loc(
index_rest * shape_x,
shape.data(),
a_strides.data(),
b_strides.data(),
ndim);
auto a_vec =
load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
auto b_vec =
load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec_a;
AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec[i], b_vec[i]);
out_vec_a[i] = out[0];
out_vec_b[i] = out[1];
}
store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}
template <typename Op, typename In, typename Out>
@@ -225,6 +279,17 @@ void binary_two_op_gpu_inplace(
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out_a.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
@@ -232,35 +297,46 @@ void binary_two_op_gpu_inplace(
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
dims_constant(),
1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant(),
4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
rest,
const_param(shape),
const_param(a_strides),
const_param(b_strides),
@@ -270,8 +346,7 @@ void binary_two_op_gpu_inplace(
} else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
@@ -281,7 +356,6 @@ void binary_two_op_gpu_inplace(
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
@@ -291,6 +365,7 @@ void binary_two_op_gpu_inplace(
kernel,
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),

View File

@@ -104,10 +104,41 @@ struct FusedKernelBuilder {
" }\n";
}
// Vectorized read loop
if (contiguous) {
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
if (is_scalar(x) || is_constant(i)) {
continue;
}
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
xname,
type);
}
}
// Create some space for the outputs
for (const auto& x : outputs) {
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
os += fmt::format(
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
}
// Work loop
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
if (!contiguous) {
os +=
"\n"
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
} else {
os +=
"\n"
" #pragma unroll\n"
" for (int i = 0; i < work_per_thread; i++) {\n";
}
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
@@ -122,7 +153,7 @@ struct FusedKernelBuilder {
} else if (is_scalar(x)) {
value = fmt::format("{}[0]", xname);
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
value = fmt::format("vec_{}[i]", xname);
} else {
value = fmt::format("{}[{}_idx]", xname, xname);
}
@@ -150,25 +181,30 @@ struct FusedKernelBuilder {
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
}
// End of work loop
os +=
"\n"
" index++;\n";
if (!contiguous) {
os += "\n";
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
if (is_scalar(x) || is_constant(i)) {
continue;
}
os += " " + xname + "_idx += " + xname + "_strides[NDIM - 1];\n";
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
}
}
os += " }\n";
// Store the output to global memory
for (const auto& x : outputs) {
os += fmt::format(
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
namer.get_name(x));
}
os += "}\n";
}
};
@@ -192,6 +228,15 @@ void Compiled::eval_gpu(
nvtx3::scoped_range r("Compiled::eval_gpu");
auto& s = stream();
// Determine the work per thread for the vectorized reads/writes. We take it
// as 16 over the max itemsize for the outputs. Another heuristic could be
// over the max itemsize of all arrays.
int max_size = 1;
for (const auto& x : outputs) {
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
}
int work_per_thread = 16 / max_size;
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
// Build source code.
cu::FusedKernelBuilder builder{
@@ -205,29 +250,25 @@ void Compiled::eval_gpu(
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names;
for (auto work_per_thread : std::array<int, 2>{1, 4}) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
lib_name(),
work_per_thread));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_contiguous<int64_t, {}>",
lib_name(),
work_per_thread));
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t, {}>",
lib_name(),
i,
work_per_thread));
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, int64_t, {}>",
lib_name(),
i,
work_per_thread));
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
}
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
return std::make_tuple(
false, std::move(builder.os), std::move(kernel_names));
});
// Collapse contiguous dims to route to a faster kernel if possible. Also
@@ -269,7 +310,6 @@ void Compiled::eval_gpu(
}
// Choose work per thread
int work_per_thread = 4;
if (!contiguous && shape.back() % work_per_thread != 0) {
work_per_thread = 1;
}
@@ -294,8 +334,8 @@ void Compiled::eval_gpu(
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] =
get_launch_args(kernel, outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
get_launch_args(outputs[0], large, work_per_thread);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
} // namespace mlx::core

View File

@@ -1,257 +1,214 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
#include <cassert>
#include <numeric>
namespace mlx::core {
namespace {
// Not all engines support it so can not use this API now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
struct ConvCacheKey {
int device_id;
cudnnBackendDescriptorType_t backend_type;
cudnnDataType_t cudnn_type;
cudnnDataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> filter_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> padding_lo;
std::array<int, MAX_NDIM> padding_hi;
std::array<int, MAX_NDIM> stride;
std::array<int, MAX_NDIM> dilation;
int groups;
bool flip;
uint8_t input_alignment;
uint8_t filter_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
};
auto& conv_cache() {
static LRUBytesKeyCache<ConvCacheKey, cudnn_frontend::ExecutionPlan> cache(
/* capacity */ 128);
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128);
return cache;
}
template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
array& x,
array& w,
array& y,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding_lo_,
const std::vector<int>& padding_hi_,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
template <typename T>
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(shape, strides);
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
if (backend_type == CONV_BACKWARD_INPUT) {
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
convert_vector<int64_t>(input_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
padding_hi = padding_lo;
return std::make_tuple(
convert_vector<int64_t>(kernel_dilation),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_strides));
} else {
return std::make_tuple(
convert_vector<int64_t>(kernel_strides),
std::move(padding_lo),
std::move(padding_hi),
convert_vector<int64_t>(kernel_dilation));
}
return alignment;
}
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
cudnn_frontend::EngineConfigList get_engine_configs(
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = false) {
cudnn_frontend::GeneratorSource source;
if (use_fallback) {
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
};
} else {
source = [](cudnn_frontend::OperationGraph& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
};
}
cudnn_frontend::EngineConfigGenerator generator(1, &source);
auto configs = generator.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
bool execute_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
const array& in,
const array& wt,
array& out) {
int workspace_size = plan.getWorkspaceSize();
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
int64_t uids[3] = {'x', 'w', 'y'};
void* data_ptrs[3] = {
const_cast<void*>(in.data<void>()),
const_cast<void*>(wt.data<void>()),
out.data<void>(),
};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_cudnn_tensor_nchw('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y))
.setcDesc(conv_desc)
.build();
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
cudaGraph_t graph;
cudaGraphCreate(&graph, 0);
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
if (cudnnBackendPopulateCudaGraph(
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
encoder.add_graph_node(graph);
#else
auto capture = encoder.capture_context();
if (cudnnBackendExecute(
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
#endif
encoder.add_temporary(workspace);
return true;
}
bool try_engines(
cu::CommandEncoder& encoder,
cudnn_frontend::EngineConfigList& configs,
const ConvCacheKey& cache_key,
const std::string& op_graph_tag,
const array& in,
const array& wt,
array& out) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(encoder.device().cudnn_handle())
.setEngineConfig(config, op_graph_tag)
.build();
if (execute_plan(encoder, plan, in, wt, out)) {
conv_cache().emplace(cache_key, std::move(plan));
return true;
}
} catch (cudnn_frontend::cudnnException&) {
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
return std::nullopt;
}
return false;
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out.size() == 0) {
return;
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
array group_transpose(
const array& x,
int groups,
int group_dim,
int axis1,
int axis2,
Stream s) {
if (groups == 1) {
return swapaxes_in_eval(x, axis1, axis2);
}
int ndim = x.ndim();
if (group_dim < 0) {
group_dim += ndim;
}
if (axis1 < 0) {
axis1 += ndim;
}
if (axis2 < 0) {
axis2 += ndim;
}
if (group_dim <= axis1) {
axis1 += 1;
}
if (group_dim <= axis2) {
axis2 += 1;
}
auto shape = x.shape();
shape.insert(shape.begin() + group_dim, groups);
shape[group_dim + 1] = shape[group_dim + 1] / groups;
array x_trans = reshape_in_eval(x, std::move(shape), s);
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
return x_trans;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array in,
array wt,
array out,
int groups,
Stream s) {
// Transpose the args depending on the backend type.
// TODO: Handle groups.
if (backend_type == CONV_BACKWARD_INPUT) {
wt = group_transpose(wt, groups, 0, 0, -1, s);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
in = group_transpose(in, groups, -1, 0, -1, s);
wt = swapaxes_in_eval(wt, 0, -1);
// Create a contiguous array that shares the data with |out|, but with dim
// C_in and C_out swapped.
Shape shape(out.shape());
std::swap(shape.front(), shape.back());
Strides strides(shape.size(), 1);
for (int i = shape.size() - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
array intermediate(std::move(shape), out.dtype(), nullptr, {});
intermediate.copy_shared_buffer(
out, std::move(strides), {true, true, false}, out.data_size());
out = intermediate;
}
// cuDNN requires contiguous input.
// TODO: Handle NCHW format specially.
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
@@ -261,80 +218,201 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporary(wt);
}
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& intermediate_out,
array& final_out) {
encoder.set_input_array(in);
encoder.set_input_array(wt);
encoder.set_output_array(out);
encoder.set_output_array(final_out);
auto backend_type = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
auto cudnn_type = dtype_to_cudnn_type(in.dtype());
if (backend_type == CONV_BACKWARD_WEIGHT) {
// Turn |out| into a strided array, which will have C_in and C_out swapped
// in vjp and the final |grad_weight| will then be contiguous.
Strides strides = intermediate_out.strides();
std::swap(strides.front(), strides.back());
final_out.copy_shared_buffer(
intermediate_out,
std::move(strides),
{false, false, false},
intermediate_out.data_size());
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
nvtx3::scoped_range r("Convolution::eval_gpu");
if (out_.size() == 0) {
return;
}
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
encoder.device().cuda_device(),
backend_type,
cudnn_type,
fixed_vector(in.shape()),
fixed_vector(wt.shape()),
fixed_vector(padding_lo_),
fixed_vector(padding_hi_),
fixed_vector(kernel_strides_),
fixed_vector(kernel_dilation_),
dtype_to_cudnn_type(dtype),
vector_key(in.shape()),
vector_key(wt.shape()),
vector_key(kernel_strides_),
vector_key(padding_lo_),
vector_key(padding_hi_),
vector_key(kernel_dilation_),
groups_,
flip_,
get_alignment(in),
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
if (!execute_plan(encoder, it->second, in, wt, out)) {
throw std::runtime_error("Cached convolution plan failed to execute.");
auto& [backend_type, plan] = it->second;
if (plan) {
// Run cached plan.
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
} else {
// Run fallback kernel.
gemm_conv(
encoder,
in,
wt,
out,
kernel_strides_,
padding_lo_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
s);
}
return;
}
// Build operation graph.
auto compute_data_type = (in.dtype() == float16 || in.dtype() == bfloat16)
? CUDNN_DATA_FLOAT
: cudnn_type;
auto stride = convert_vector<int64_t>(kernel_strides_);
auto padding_lo = convert_vector<int64_t>(padding_lo_);
auto padding_hi = convert_vector<int64_t>(padding_hi_);
auto dilation = convert_vector<int64_t>(kernel_dilation_);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_data_type)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_tensor('x', in))
.setwDesc(build_tensor('w', wt))
.setyDesc(build_tensor('y', out))
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
// Try to run plans based on heuristics.
auto configs = get_engine_configs(backend_type, in.dtype(), op_graph);
auto op_graph_tag = op_graph.getTag();
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
return;
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
} else {
// Otherwise it could be backward weight convolution or forward convolution,
// mathematically there is no difference so we have to use heuristics.
// Empirically backward convolutions have large kernel dimensions, and
// usually have |in| and |wt| transposed.
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
wt.shape(2) > out.shape(2)) {
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
} else {
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
}
}
// Then try fallback plans.
configs = get_engine_configs(backend_type, in.dtype(), op_graph);
if (try_engines(encoder, configs, cache_key, op_graph_tag, in, wt, out)) {
return;
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend,
x,
w,
y,
kernel_strides_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_conv_op_graph(
encoder,
try_backend,
dtype,
x,
w,
y,
stride,
padding_lo,
padding_hi,
dilation);
if (op_graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
break;
}
}
throw std::runtime_error("Unable to find an engine for convolution.");
if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) {
throw std::runtime_error("[conv] Unable to find an execution plan.");
}
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
// Use fallback kernel for settings not supported by cuDNN.
gemm_conv(
encoder,
in,
wt,
out,
kernel_strides_,
padding_lo_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
s);
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
}
} // namespace mlx::core

View File

@@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
namespace mlx::core {
template <int NDIM>
struct ConvParams {
int N; // Batch size
int C; // In channels
int O; // Out channels
int strides[NDIM];
int padding[NDIM];
int kernel_dilation[NDIM];
int input_dilation[NDIM];
int groups;
bool flip;
int in_spatial_dims[NDIM];
int wt_spatial_dims[NDIM];
int out_spatial_dims[NDIM];
int64_t in_strides[NDIM + 2];
ConvParams(
const array& in,
const array& wt,
const array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip)
: N(in.shape(0)),
C(in.shape(-1)),
O(wt.shape(0)),
groups(groups),
flip(flip) {
std::copy_n(strides.begin(), NDIM, this->strides);
std::copy_n(padding.begin(), NDIM, this->padding);
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
}
};
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s);
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s);
inline void gemm_conv(
cu::CommandEncoder& encoder,
array in,
array wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
if (!in.flags().row_contiguous) {
in = contiguous_copy_gpu(in, s);
encoder.add_temporary(in);
}
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
encoder.add_temporary(wt);
}
if (groups == 1) {
gemm_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
flip,
s);
} else {
gemm_grouped_conv(
encoder,
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip,
s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,217 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_unfold_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array unfold_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = mat_K / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_unfold_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
int mat_N = params.O; // O
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded =
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
wt_reshaped.copy_shared_buffer(
wt,
{1, mat_K},
{false, false, /* col_contiguous */ true},
wt.data_size());
// Single batch.
Shape batch_shape{1};
Strides a_batch_strides{0};
Strides b_batch_strides{0};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
1, // groups
flip);
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core

View File

@@ -0,0 +1,231 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/conv/conv.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T, int NDIM>
__global__ void naive_grouped_unfold_transpose_nd(
const T* in,
T* out,
int filter_size,
int out_pixels,
const __grid_constant__ ConvParams<NDIM> params) {
auto block = cg::this_thread_block();
auto tid = block.group_index();
auto lid = block.thread_index();
int index_batch = tid.z / out_pixels; // [0, N)
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
int index_wt_spatial =
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
if (index_wt_spatial >= filter_size / params.C) {
return;
}
in += tid.y; // [0, C)
out += tid.z * filter_size + tid.y * (filter_size / params.C);
bool valid = index_batch < params.N;
// Get the coordinates in input.
int index_in[NDIM] = {};
int wt_stride = 1;
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int index_out = index_out_spatial % params.out_spatial_dims[i];
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
out += index_wt * wt_stride;
if (params.flip) {
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
}
int index = index_out * params.strides[i] - params.padding[i] +
index_wt * params.kernel_dilation[i];
int index_max =
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
valid &= (index >= 0) && (index < index_max) &&
(index % params.input_dilation[i] == 0);
index_in[i] = index / params.input_dilation[i];
index_out_spatial /= params.out_spatial_dims[i];
index_wt_spatial /= params.wt_spatial_dims[i];
wt_stride *= params.wt_spatial_dims[i];
}
if (valid) {
int in_offset = index_batch * params.in_strides[0];
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
in_offset += index_in[i] * params.in_strides[i + 1];
}
*out = in[in_offset];
} else {
*out = T{0};
}
}
} // namespace cu
template <int NDIM>
array grouped_unfold_transpose_inputs_nd(
cu::CommandEncoder& encoder,
const array& in,
int mat_M,
int mat_K,
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
encoder.add_temporary(unfolded);
int filter_size = params.C;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
filter_size *= params.wt_spatial_dims[i];
}
int out_pixels = 1;
#pragma unroll
for (int i = 0; i < NDIM; ++i) {
out_pixels *= params.out_spatial_dims[i];
}
int wt_spatial_size = (mat_K * params.groups) / params.C;
dim3 block_dims;
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
dim3 num_blocks;
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
num_blocks.y = params.C;
num_blocks.z = mat_M;
encoder.set_input_array(in);
encoder.set_output_array(unfolded);
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
encoder.add_kernel_node(
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
filter_size,
out_pixels,
params);
});
return unfolded;
}
template <int NDIM>
void gemm_grouped_conv_nd(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
ConvParams<NDIM>& params,
Stream s) {
// Get gemm shapes.
int C_per_group = params.C / params.groups;
int O_per_group = params.O / params.groups;
int mat_M = out.size() / params.O; // N * H_out * W_out
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
int mat_N = O_per_group; // O_per_group
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
encoder, in, mat_M, mat_K, mat_N, params);
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
array wt_view(
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
// Batch with size of groups.
Shape batch_shape{params.groups};
Strides a_batch_strides{mat_K};
Strides b_batch_strides{mat_N * mat_K};
// Run matmul.
CublasGemm gemm(
encoder.device(),
in.dtype(),
false, // a_transposed
mat_M, // a_rows
mat_K, // a_cols
mat_K * params.groups, // lda
true, // b_transposed
mat_K, // b_rows
mat_N, // b_cols
mat_K, // ldb
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.set_out(
out.dtype(),
false, // out_transposed
mat_M, // out_rows
mat_N, // out_cols
mat_N * params.groups, // out_ld
params.groups, // batch_count
mat_N); // batch_stride
gemm.run(
encoder,
out,
in_unfolded,
wt_reshaped,
batch_shape,
a_batch_strides,
b_batch_strides);
}
void gemm_grouped_conv(
cu::CommandEncoder& encoder,
const array& in,
const array& wt,
array& out,
const std::vector<int>& strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation,
int groups,
bool flip,
Stream s) {
int conv_ndim = in.ndim() - 2;
if (conv_ndim < 1 || conv_ndim > 3) {
throw std::runtime_error(
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
}
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
ConvParams<ndim_constant()> params(
in,
wt,
out,
strides,
padding,
kernel_dilation,
input_dilation,
groups,
flip);
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
});
}
} // namespace mlx::core

View File

@@ -15,8 +15,8 @@ void copy_gpu_inplace(
int64_t offset_out,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) {
std::optional<array> dynamic_offset_in,
std::optional<array> dynamic_offset_out) {
if (out.size() == 0) {
return;
}
@@ -44,6 +44,16 @@ void copy_gpu_inplace(
strides_vec[0]);
} else {
if (dynamic_offset_in || dynamic_offset_out) {
if (!dynamic_offset_in) {
dynamic_offset_in = array(0, int64);
encoder.add_temporary(*dynamic_offset_in);
}
if (!dynamic_offset_out) {
dynamic_offset_out = array(0, int64);
encoder.add_temporary(*dynamic_offset_out);
}
encoder.set_input_array(*dynamic_offset_in);
encoder.set_input_array(*dynamic_offset_out);
copy_general_dynamic(
encoder,
ctype,
@@ -54,8 +64,8 @@ void copy_gpu_inplace(
shape_collapsed,
strides_vec[0],
strides_vec[1],
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
*dynamic_offset_in,
*dynamic_offset_out);
} else {
copy_general(
encoder,

View File

@@ -22,7 +22,7 @@ __global__ void copy_s(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in[0]);
out_vec[i] = cast_to<Out>(in[0]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -43,7 +43,7 @@ __global__ void copy_v(const In* in, Out* out, IdxT size) {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = cast_to<Out>(in_vec.val[i]);
out_vec[i] = cast_to<Out>(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
@@ -65,23 +65,18 @@ void copy_contiguous(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
constexpr int N_READS = 16 / sizeof(InType);
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());

View File

@@ -10,37 +10,80 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_gg_nd(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index, shape.data(), strides_in.data(), strides_out.data());
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto in_stride_x = strides_in[NDIM - 1];
auto out_stride_x = strides_out[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data());
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_gg(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides_out,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [idx_in, idx_out] = elem_to_loc(
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto in_stride_x = strides_in[ndim - 1];
auto out_stride_x = strides_out[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto [idx_in, idx_out] = elem_to_loc(
index_rest * shape_x,
shape.data(),
strides_in.data(),
strides_out.data(),
ndim);
auto in_vec =
load_vector<N_READS>(in + idx_in, index_x, shape_x, in_stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x);
}
} // namespace cu
@@ -69,34 +112,52 @@ void copy_general(
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = data_size / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant(), 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
data_size,
rest,
const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
auto kernel = cu::copy_gg<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_gg<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
data_size,
rest,
const_param(shape),
const_param(strides_in),
const_param(strides_out),

View File

@@ -74,14 +74,16 @@ void copy_general_dynamic(
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
cu::copy_gg_dynamic_nd<
InType,
OutType,
IdxT,
dims_constant()>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
@@ -92,13 +94,12 @@ void copy_general_dynamic(
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto [num_blocks, block_dims] = get_launch_args(out, large());
encoder.add_kernel_node(
kernel,
cu::copy_gg_dynamic<InType, OutType, IdxT>,
num_blocks,
block_dims,
0,
in_ptr,
out_ptr,
out.size(),

View File

@@ -10,33 +10,67 @@ namespace cu {
namespace cg = cooperative_groups;
template <typename In, typename Out, typename IdxT, int NDIM>
template <typename In, typename Out, typename IdxT, int NDIM, int N_READS>
__global__ void copy_g_nd(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
out[index] = CastOp<In, Out>{}(in[idx_in]);
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[NDIM - 1];
auto stride_x = strides[NDIM - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc_nd<NDIM>(index_rest * shape_x, shape.data(), strides.data());
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename In, typename Out, typename IdxT>
template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_g(
const In* in,
Out* out,
IdxT size,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides_in,
const __grid_constant__ Strides strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
out[index] = CastOp<In, Out>{}(in[idx_in]);
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = CastOp<In, Out>{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
} // namespace cu
@@ -61,33 +95,49 @@ void copy_general_input(
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in_ptr,
out_ptr,
out.size(),
rest,
const_param(shape),
const_param(strides_in),
ndim);

View File

@@ -0,0 +1,272 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plans
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
encoder.add_temporary(workspace);
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core

View File

@@ -0,0 +1,164 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
}
}
return alignment;
}
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
//
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
} // namespace mlx::core

View File

@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
@@ -28,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}();
return cache_size;
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
return use_graphs;
}
} // namespace
Device::Device(int device) : device_(device) {
@@ -54,6 +62,10 @@ Device::Device(int device) : device_(device) {
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
// The cudnn handle is used by Convolution.
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
// Initialize the jit module cache here ensures it is not
// unloaded before any evaluation is done
get_jit_module_cache();
}
Device::~Device() {
@@ -81,34 +93,22 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
if (!use_cuda_graphs()) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
if (discard) {
if (!use_cuda_graphs()) {
return;
}
// Extract and add as single kernel node when possible.
size_t num_nodes;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, NULL, &num_nodes));
if (num_nodes == 1) {
cudaGraphNode_t captured_node;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, &captured_node, &num_nodes));
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(captured_node, &type));
if (type == cudaGraphNodeTypeKernel) {
CUDA_KERNEL_NODE_PARAMS params;
CHECK_CUDA_ERROR(cuGraphKernelNodeGetParams(captured_node, &params));
enc.add_kernel_node(params);
return;
}
graph.end_capture(enc.stream());
if (discard) {
return;
}
// Otherwise add the captured graph as subgraph.
enc.add_graph_node(graph);
}
@@ -119,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false;
if (!use_cuda_graphs()) {
return;
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -196,31 +199,29 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
}
}
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
}
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
for (auto& [_, graph_exec] : graphs) {
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
}
graphs.clear();
}
CommandEncoder::~CommandEncoder() {
clear_graphs(graph_cache_);
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
void CommandEncoder::set_output_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
active_outputs_.push_back(id);
@@ -236,12 +237,19 @@ void CommandEncoder::add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
}
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
kernel_params.blockDim = block_dim;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
}
@@ -249,7 +257,24 @@ void CommandEncoder::add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
smem_bytes,
stream(),
params,
nullptr));
return;
}
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x;
@@ -259,6 +284,7 @@ void CommandEncoder::add_kernel_node(
kernel_params.blockDimY = block_dim.y;
kernel_params.blockDimZ = block_dim.z;
kernel_params.kernelParams = params;
kernel_params.sharedMemBytes = smem_bytes;
add_kernel_node(kernel_params);
}
@@ -275,19 +301,32 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
}
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}
if (node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
graph_,
from_nodes_.data(),
to_nodes_.data(),
#if CUDART_VERSION >= 13000
nullptr, // edgeData
#endif // CUDART_VERSION >= 13000
from_nodes_.size()));
}
graph_key_ += ".";
@@ -297,7 +336,7 @@ void CommandEncoder::commit() {
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
@@ -311,31 +350,24 @@ void CommandEncoder::commit() {
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
graph_exec = nullptr;
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
CHECK_CUDA_ERROR(
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy
if (graph_cache_.size() > cuda_graph_cache_size()) {
clear_graphs(graph_cache_);
}
// Reset state
node_count_ = 0;
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
node_map_.clear();
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
graph_ = CudaGraph(device_);
}
// Put completion handlers in a batch.

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
@@ -20,7 +21,7 @@ class CommandEncoder {
struct CaptureContext {
CaptureContext(CommandEncoder& enc);
~CaptureContext();
cudaGraph_t graph;
CudaGraph graph;
CommandEncoder& enc;
bool discard{false};
};
@@ -31,7 +32,6 @@ class CommandEncoder {
};
explicit CommandEncoder(Device& d);
~CommandEncoder();
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -47,29 +47,35 @@ class CommandEncoder {
void set_output_array(const array& arr);
template <typename F, typename... Params>
void
add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
void add_kernel_node(
F* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
Params&&... params) {
constexpr size_t num = sizeof...(Params);
void* ptrs[num];
size_t i = 0;
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
std::forward<Params>(params)),
...);
add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
}
void add_kernel_node(
CUfunction func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
void
add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
void add_kernel_node(
void* func,
dim3 grid_dim,
dim3 block_dim,
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
@@ -92,6 +98,9 @@ class CommandEncoder {
void synchronize();
private:
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
struct GraphNode {
cudaGraphNode_t node;
// K = kernel
@@ -106,7 +115,7 @@ class CommandEncoder {
Device& device_;
CudaStream stream_;
cudaGraph_t graph_;
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
@@ -117,7 +126,7 @@ class CommandEncoder {
std::string graph_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;

View File

@@ -1,15 +0,0 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core::cu {
template <typename T>
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const {
return start + i * step;
}
};
} // namespace mlx::core::cu

View File

@@ -49,11 +49,7 @@ inline __device__ void atomic_add(__half* out, __half val) {
}
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {

View File

@@ -204,6 +204,12 @@ struct Power {
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
// Raising an integer to a negative power is undefined
if constexpr (cuda::std::is_signed_v<T>) {
if (exp < 0) {
return 0;
}
}
while (exp) {
if (exp & 1) {
res *= base;

View File

@@ -6,7 +6,6 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
@@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu

View File

@@ -32,36 +32,136 @@ using Strides = cuda::std::array<int64_t, MAX_NDIM>;
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
T val[N];
__device__ T& operator[](int i) {
return val[i];
}
__device__ T operator[](int i) const {
return val[i];
}
};
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
inline __host__ __device__ bool is_aligned(T* x) {
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
const T* ptr,
uint32_t offset) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
}
template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset) {
if (is_aligned<N>(ptr)) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = ptr[offset * N + i];
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
}
return v;
}
}
template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N> load_vector(
const T* ptr,
uint32_t offset,
SizeT size,
int64_t stride,
T fallback) {
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
return from[offset];
} else {
AlignedVector<T, N> v;
#pragma unroll
for (int i = 0; i < N; ++i) {
v[i] =
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
}
return v;
}
}
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
}
// Helper for accessing strided data.
template <typename T>
struct StridedIterator {
T it;
int64_t stride;
__host__ __device__ StridedIterator(T it, int64_t stride)
: it(it), stride(stride) {}
__host__ __device__ auto operator[](int i) const {
return it[i * stride];
template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
if (is_aligned<N>(ptr)) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
};
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[offset * N + i] = vec[i];
}
}
}
template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
T* ptr,
uint32_t offset,
const AlignedVector<T, N>& vec,
SizeT size,
int64_t stride) {
if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
to[offset] = vec;
} else {
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
ptr[stride * (offset * N + i)] = vec[i];
}
}
}
///////////////////////////////////////////////////////////////////////////////
// Type limits utils

View File

@@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace mlx::core::distributed

View File

@@ -36,18 +36,15 @@ void eval(array& arr) {
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
// Except for the donated one.
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
encoder.add_temporary(in);
}
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
encoder.add_temporary(s);
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
encoder.maybe_commit();
}

View File

@@ -1,206 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core::cu {
namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
{static_cast<int>(batch_count * 3)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void Matmul::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
float alpha,
float beta) {
auto batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)},
uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
static_cast<int64_t>(M_) * N_,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
run_impl(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core::cu

View File

@@ -7,10 +7,12 @@
#include <fmt/format.h>
namespace mlx::core::cu {
namespace mlx::core {
namespace {
struct CublasPreference {
CublasPreference(Device& device) {
CublasPreference(cu::Device& device) {
// The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
// for Hopper+:
// https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
cublasLtMatmulPreference_t pref_{nullptr};
};
cublasLtMatmulPreference_t cublas_preference(Device& device) {
cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
static CublasPreference pref(device);
return pref.pref_;
}
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return CUBLAS_COMPUTE_64F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
}
}
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
return CUDA_C_32F;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
}
}
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
return desc;
}
Matmul::Matmul(
Device& device,
} // namespace
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
}
Matmul::Matmul(
Device& device,
CublasGemm::CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
int64_t a_batch_stride,
int64_t b_batch_stride,
int64_t c_batch_stride)
: Matmul(
: CublasGemm(
device,
dtype,
a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
Matmul::~Matmul() {
CublasGemm::~CublasGemm() {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,92 @@ Matmul::~Matmul() {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
}
void Matmul::run_impl(
void CublasGemm::set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder,
out,
a,
b,
c,
batch_shape,
a_batch_strides,
b_batch_strides,
c_batch_strides,
alpha,
beta);
return;
}
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c.data<void>(),
alpha,
beta);
}
void CublasGemm::execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -213,7 +302,7 @@ void Matmul::run_impl(
matmul_desc_,
a_desc_,
b_desc_,
out_desc_, // TODO should that be c_desc is it's set?
c ? c_desc_ : out_desc_,
out_desc_,
pref_,
1,
@@ -226,8 +315,10 @@ void Matmul::run_impl(
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
allocator::malloc(heuristic_.workspaceSize),
allocator::malloc(nbytes),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
@@ -254,29 +345,4 @@ void Matmul::run_impl(
encoder.stream()));
}
void Matmul::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c /* = std::nullopt */,
float alpha /* = 1 */,
float beta /* = 0 */) {
encoder.set_input_array(a);
encoder.set_input_array(b);
if (c) {
encoder.set_input_array(*c);
}
encoder.set_output_array(out);
run_impl(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
c ? c->data<void>() : nullptr,
alpha,
beta);
}
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -5,13 +5,13 @@
#include "mlx/backend/cuda/device.h"
#include <cublasLt.h>
#include <optional>
namespace mlx::core::cu {
class Matmul {
namespace mlx::core {
class CublasGemm {
public:
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
int64_t a_batch_stride,
int64_t b_batch_stride);
Matmul(
Device& device,
CublasGemm(
cu::Device& device,
Dtype dtype,
bool a_transposed,
uint64_t a_rows,
@@ -42,25 +42,50 @@ class Matmul {
int64_t b_batch_stride,
int64_t c_batch_stride);
~Matmul();
~CublasGemm();
// The output's descriptor is inferred from inputs by default, use this method
// for unusual output.
void set_out(
Dtype dtype,
bool transposed,
uint64_t rows,
uint64_t cols,
int64_t ld,
int32_t batch_count,
int64_t batch_stride);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const std::optional<array>& c = std::nullopt,
float alpha = 1,
float beta = 0);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides);
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
void run_batched(
cu::CommandEncoder& encoder,
@@ -68,15 +93,14 @@ class Matmul {
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta);
private:
void run_impl(
void execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
@@ -97,4 +121,4 @@ class Matmul {
cublasLtMatmulHeuristicResult_t heuristic_;
};
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -4,16 +4,16 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
namespace mlx::core::cu {
namespace mlx::core {
void Matmul::run_batched(
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) {
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
}
}
void Matmul::run_batched(
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides,
const mlx::core::Strides& c_batch_strides,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) {
run_impl(
execute(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
}
}
} // namespace mlx::core::cu
} // namespace mlx::core

View File

@@ -0,0 +1,327 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include <cooperative_groups.h>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ Shape batch_shape,
const __grid_constant__ Strides a_batch_strides,
const __grid_constant__ Strides b_batch_strides,
const __grid_constant__ Strides c_batch_strides,
int64_t batch_stride,
int batch_ndim,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data(),
batch_ndim);
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
} // namespace cu
namespace {
void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
}
} // namespace
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(void*) * 3),
{batch_count * 3},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_mm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
}
void CublasGemm::run_batched(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& c,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
const Strides& c_batch_strides,
float alpha,
float beta) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{batch_count * 4},
uint64);
encoder.add_temporary(pointers);
encoder.set_output_array(pointers);
int block_dims = std::min(batch_count, 256);
int num_blocks = cuda::ceil_div(batch_count, block_dims);
int64_t batch_stride = M_ * N_;
int item_size = out.itemsize();
int ndim = batch_shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_nd<ndim_constant()>,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto c_pointers = b_pointers + batch_count;
auto out_pointers = c_pointers + batch_count;
execute(
encoder,
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
reinterpret_cast<void*>(c_pointers),
alpha,
beta);
}
} // namespace mlx::core

View File

@@ -11,7 +11,6 @@ namespace mlx::core::cu {
namespace cg = cooperative_groups;
static constexpr int n_per_thread = 4;
static constexpr int rows_per_block = 8;
template <typename T, int rows_per_block, int n_per_thread>
@@ -28,12 +27,13 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
float sum = 0.0f;
for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) {
auto local_mat = load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = load_vector<n_per_thread>(vec + col, 0);
auto local_mat =
unsafe_load_vector<n_per_thread>(mat + row * cols + col, 0);
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
for (int j = 0; j < n_per_thread; ++j) {
sum += static_cast<float>(local_mat.val[j]) *
static_cast<float>(local_vec.val[j]);
sum +=
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
}
}
@@ -74,8 +74,22 @@ __global__ void gemv_batched(
}
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) {
return K % (WARP_SIZE * n_per_thread) == 0 &&
((M == 1 && b_transposed) || (N == 1 && !a_transposed));
return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed));
}
template <typename F>
void dispatch_n_per_thread(int n_per_thread, F&& f) {
switch (n_per_thread) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
}
}
void gemv(
@@ -114,33 +128,45 @@ void gemv(
rows = M;
}
uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols);
int n_per_t;
if (K % 128 == 0 && is_aligned<4>(mat) && is_aligned<4>(vec)) {
n_per_t = 4;
} else if (K % 64 == 0 && is_aligned<2>(mat) && is_aligned<2>(vec)) {
n_per_t = 2;
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
n_per_t = 1;
}
dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) {
if (batch_count == 1) {
auto kernel = gemv_single<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
num_blocks_x,
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols);
} else {
auto kernel = gemv_batched<DataType, rows_per_block, n_per_thread()>;
encoder.add_kernel_node(
kernel,
dim3{num_blocks_x, batch_count},
block_dims,
0,
mat,
vec,
out.data<DataType>(),
rows,
cols,
const_param(batch_shape),
mat_strides,
vec_strides,
batch_shape.size());
}
});
});
}

View File

@@ -29,12 +29,12 @@ void append_indices_arg(
const std::vector<array>& inputs,
int nidx,
int idx_ndim) {
std::vector<const void*> indices(nidx);
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
}
args.append(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim);
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].shape().begin(),
@@ -42,7 +42,7 @@ void append_indices_arg(
indices_shape.data() + i * idx_ndim);
}
args.append(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim);
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].strides().begin(),
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_gather, std::move(kernel_names));
return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -128,8 +128,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
auto [num_blocks, block_dims] = get_launch_args(out, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_scatter, std::move(kernel_names));
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -229,8 +229,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
auto [num_blocks, block_dims] = get_launch_args(upd, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
return std::make_tuple(
false, jit_source_gather_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;
@@ -317,8 +318,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
return std::make_tuple(
false, jit_source_scatter_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;
@@ -421,8 +423,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
encoder.set_output_array(out);
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
auto [num_blocks, block_dims] = get_launch_args(idx, large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}
} // namespace mlx::core

View File

@@ -9,7 +9,6 @@
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <unordered_map>
#include <fmt/format.h>
#include <nvrtc.h>
@@ -68,9 +67,11 @@ const std::string& cccl_dir() {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
path = env;
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
}
return std::string();
}();
@@ -102,8 +103,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
@@ -118,15 +119,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) {
return false;
}
ptx->resize(ptx_size);
ptx_file.read(ptx->data(), ptx_size);
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
if (tab != std::string::npos) {
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
}
}
return true;
@@ -136,7 +137,7 @@ bool read_cached_ptx(
void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
@@ -218,85 +219,85 @@ constexpr const char* g_headers[] = {
jit_source_utils,
};
} // namespace
JitModule::JitModule(
void compile(
Device& device,
const std::string& module_name,
const KernelBuilder& builder) {
// Check cache.
std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
// Create program.
auto [source_code, kernel_names] = builder();
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
source_code.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
g_include_names));
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
&prog,
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size, 0);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
const std::string& source,
const std::vector<std::string>& kernel_names,
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
// Create the program
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
source.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
g_include_names));
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
&prog,
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
}
void load_module(
const std::string& module_name,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -313,31 +314,85 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel;
kernels[name] = std::make_pair(kernel, false);
}
}
} // namespace
JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
} else {
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
}
// If requested save them in the file cache for the next launch
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
}
// Load the module
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) {
throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name));
}
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
if (configure_kernel) {
configure_kernel(it->second.first);
}
it->second.second = true;
}
return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
static std::unordered_map<std::string, JitModule> map;
return map;
}
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder) {
static std::unordered_map<std::string, JitModule> map;
const KernelBuilder& builder,
bool cache) {
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first;
it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
}
return it->second;
}

View File

@@ -19,7 +19,8 @@ namespace mlx::core::cu {
class Device;
using KernelBuilderResult = std::pair<
using KernelBuilderResult = std::tuple<
/* precompiled */ bool,
/* source code */ std::string,
/* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -40,19 +41,19 @@ struct KernelArgs {
}
template <typename T>
void append(std::vector<T> vec) {
if (vec.empty()) {
// The nullptr can not be used as arg, pass something not null.
append(std::monostate{});
} else {
append_ptr(vec.data());
storage_.emplace_back(std::move(vec));
}
void append(SmallVector<T> vec) {
storage_.emplace_back(std::move(vec));
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
}
template <typename T>
void append(const std::vector<T>& vec) {
append(SmallVector<T>(vec.begin(), vec.end()));
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(std::vector<T> vec) {
void append_ndim(SmallVector<T> vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@@ -68,17 +69,19 @@ struct KernelArgs {
private:
std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store
// temporary values untill kernel is launched.
// The cuGraphAddKernelNode API requires passing pointers to arguments so
// store temporary values until the node is created.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
bool,
int32_t,
uint32_t,
int64_t,
std::vector<const void*>,
std::vector<int32_t>,
std::vector<int64_t>>;
float,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
std::deque<Arg> storage_;
};
@@ -87,21 +90,27 @@ class JitModule {
JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder);
const KernelBuilder& builder,
bool cache);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name);
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder);
const KernelBuilder& builder,
bool use_disk_cache = true);
} // namespace mlx::core::cu

View File

@@ -30,4 +30,25 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
}
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
} // namespace mlx::core

View File

@@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@@ -120,53 +120,19 @@ dim3 get_2d_grid_dims(
size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
std::tuple<dim3, uint> get_launch_args(
size_t size,
const Shape& shape,
const Strides& strides,
bool large,
int work_per_thread = 1) {
size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = max_occupancy_block_dim(kernel);
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks;
if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
} else {
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
}
return std::make_tuple(num_blocks, block_dim);
}
int work_per_thread = 1);
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
T kernel,
const array& arr,
bool large,
int work_per_thread = 1) {
inline std::tuple<dim3, uint>
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
return get_launch_args(
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}
} // namespace mlx::core

View File

@@ -10,8 +10,6 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core {
@@ -74,9 +72,11 @@ __global__ void layer_norm(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
}
sum = BlockReduceT{block, temp}.Sum(sum);
@@ -87,11 +87,18 @@ __global__ void layer_norm(
float normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
normalizer += t * t;
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
normalizer += t * t;
}
}
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
@@ -100,17 +107,15 @@ __global__ void layer_norm(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T bn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(b, b_stride), bn, axis_size);
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
auto bn = load_vector<N_READS>(b, index, axis_size, b_stride, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float norm = (static_cast<float>(xn[i]) - mean) * normalizer;
xn[i] = wn[i] * static_cast<T>(norm) + bn[i];
}
cub::StoreDirectBlocked(index, out, xn, axis_size);
store_vector<N_READS>(out, index, xn, axis_size);
}
}
@@ -143,9 +148,11 @@ __global__ void layer_norm_vjp(
float sum = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS] = {};
cub::LoadDirectBlocked(index, x, xn, axis_size);
sum += static_cast<float>(cub::ThreadReduce(xn, cuda::std::plus<>{}));
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
sum += static_cast<float>(xn[i]);
}
}
sum = BlockReduceF{block, temp.f}.Sum(sum);
@@ -155,19 +162,28 @@ __global__ void layer_norm_vjp(
// Normalizer.
float3 factors = {};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T xn[N_READS];
T wn[N_READS] = {};
T gn[N_READS] = {};
auto index = r * BLOCK_DIM + block.thread_rank();
cub::LoadDirectBlocked(index, x, xn, axis_size, mean);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
if ((index + 1) * N_READS <= axis_size) {
auto xn = load_vector<N_READS>(x, index);
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
} else {
for (int i = index * N_READS; i < axis_size; ++i) {
float t = static_cast<float>(x[i]) - mean;
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f3(factors, {wg, wg * t, t * t});
}
}
}
factors = BlockReduceF3{block, temp.f3}.Reduce(factors, plus_f3, {});
@@ -179,12 +195,10 @@ __global__ void layer_norm_vjp(
// Outputs.
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
auto index = r * BLOCK_DIM + block.thread_rank();
T xn[N_READS];
T wn[N_READS];
T gn[N_READS];
cub::LoadDirectBlocked(index, x, xn, axis_size);
cub::LoadDirectBlocked(index, g, gn, axis_size);
cub::LoadDirectBlocked(index, StridedIterator(w, w_stride), wn, axis_size);
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
for (int i = 0; i < N_READS; i++) {
float xi = (static_cast<float>(xn[i]) - mean) * normalizer;
float wi = wn[i];
@@ -194,9 +208,9 @@ __global__ void layer_norm_vjp(
wn[i] = gi * xi;
}
}
cub::StoreDirectBlocked(index, gx, xn, axis_size);
store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) {
cub::StoreDirectBlocked(index, gw, wn, axis_size);
store_vector<N_READS>(gw, index, wn, axis_size);
}
}
}
@@ -257,14 +271,15 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
constexpr uint32_t N_READS = 4;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
b.data<DataType>(),
@@ -364,10 +379,10 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gw_temp);
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::layer_norm_vjp<
DataType,
has_w_constant.value,
@@ -377,6 +392,7 @@ void LayerNormVJP::eval_gpu(
kernel,
n_rows,
block_dim(),
0,
x.data<DataType>(),
w.data<DataType>(),
g.data<DataType>(),

View File

@@ -43,20 +43,19 @@ __global__ void logsumexp(const T* in, T* out, int axis_size) {
AccT maxval = Limits<AccT>::finite_min();
AccT normalizer = 0;
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
AccT vals[N_READS];
cub::LoadDirectBlocked(
r * BLOCK_DIM + block.thread_rank(),
make_cast_iterator<AccT>(in),
vals,
axis_size,
Limits<AccT>::min());
auto index = r * BLOCK_DIM + block.thread_rank();
auto vals = load_vector<N_READS>(in, index, axis_size, Limits<T>::min());
prevmax = maxval;
maxval = max_op(maxval, cub::ThreadReduce(vals, max_op));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
maxval = max_op(maxval, static_cast<AccT>(vals[i]));
}
// Online normalizer calculation for softmax:
// https://github.com/NVIDIA/online-softmax
normalizer = normalizer * softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer = normalizer + softmax_exp(vals[i] - maxval);
normalizer =
normalizer + softmax_exp(static_cast<AccT>(vals[i]) - maxval);
}
}
@@ -143,14 +142,15 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
constexpr int N_READS = 4;
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
in.data<DataType>(),
out.data<DataType>(),
axis_size);

View File

@@ -2,6 +2,7 @@
#pragma once
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
@@ -20,7 +21,11 @@ class LRUCache {
using const_iterator = typename list_type::const_iterator;
using map_type = M<K, iterator>;
explicit LRUCache(size_t capacity) : capacity_(capacity) {}
explicit LRUCache(size_t capacity) : capacity_(capacity) {
if (capacity == 0) {
throw std::runtime_error("LRUCache requires capacity > 0.");
}
}
size_t size() const {
return map_.size();
@@ -84,6 +89,14 @@ class LRUCache {
return vlist_.erase(pos);
}
V& operator[](const K& key) {
auto it = find(key);
if (it == end()) {
it = emplace(key, V{}).first;
}
return it->second;
}
private:
void trim() {
while (map_.size() > capacity_) {

View File

@@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b);
return;
}
matmul.run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
cu::Matmul matmul(
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
@@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back(),
c_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
matmul.run(encoder, out, a, b, c, alpha_, beta_);
return;
}
matmul.run_batched(
gemm.run(
encoder,
out,
a,

View File

@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
#include "mlx/fast.h"
namespace mlx::core::cu {
namespace mlx::core {
namespace cu {
bool is_available() {
return false;
}
} // namespace mlx::core::cu
} // namespace cu
namespace fast {
CustomKernelFunction cuda_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool,
int) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
std::vector<array> precompiled_cuda_kernel(
const std::string&,
const std::string&,
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
const std::vector<ScalarArg>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
} // namespace fast
} // namespace mlx::core

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(BlockMaskedMM)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -1,103 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/arange.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/distributed/primitives.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cassert>
namespace mlx::core {
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Arange::eval_gpu");
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
});
}
bool fast::ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
return true;
}
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
#define NO_GPU_USE_FALLBACK(func) \
bool func::use_fallback(Stream s) { \
return true; \
} \
NO_GPU_MULTI(func)
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no CUDA implementation."); \
}
NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU(Load)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -2,30 +2,17 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
template <typename T, int group_size, int bits>
__global__ void
affine_quantize(const T* w, uint8_t* out, T* scales, T* biases, size_t size) {
@@ -240,144 +227,102 @@ __global__ void affine_dequantize(
}
} // namespace cu
namespace {
inline array ensure_row_contiguous(
const array& x,
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
} // namespace
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
auto w = ensure_row_contiguous(w_pre, enc, s);
enc.set_input_array(w);
if (dequantize_) {
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(out);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
enc.set_output_array(out);
enc.set_output_array(scales);
enc.set_output_array(biases);
}
auto dtype = dequantize_ ? outputs[0].dtype() : inputs[0].dtype();
// Treat uint32 as uint8 in kernel
int uint8_per_uint32 = 4;
int packs_per_int = (bits_ == 3 || bits_ == 5) ? 8
: bits_ == 6 ? 4
: 8 / bits_;
int per_thread = dequantize_ ? packs_per_int : group_size_ / WARP_SIZE;
size_t size =
dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;
// Calculate the number of elements per thread
int per_thread = group_size_ / WARP_SIZE;
size_t size = w.size() / per_thread;
// Calculate the thread grid that we need to launch
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() /= per_thread;
if (dequantize_) {
grid_shape.back() *= uint8_per_uint32;
} else {
grid_shape.back() /= per_thread;
}
dispatch_float_types(dtype, "affine_quantize", [&](auto type_tag) {
enc.set_input_array(w);
enc.set_output_array(wq);
enc.set_output_array(scales);
enc.set_output_array(biases);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (dequantize_) {
auto kernel =
cu::affine_dequantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<uint8_t>(),
inputs[1].data<DataType>(),
inputs[2].data<DataType>(),
out.data<DataType>(),
out.size());
} else {
auto kernel =
cu::affine_quantize<DataType, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
w.data<DataType>(),
out.data<uint8_t>(),
outputs[1].data<DataType>(),
outputs[2].data<DataType>(),
w.size());
}
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_quantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
w.data<T>(),
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.size());
});
});
});
}
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s) {
// Calculate how many numbers we pack together. For 2, 4, 8 bits we pack in
// one uint8, for 3, 6 in 3 uint8 and for 5 in 5 uint8.
constexpr int uint8_per_uint32 = 4;
int packs_per_int;
switch (bits_) {
case 3:
case 5:
packs_per_int = 8;
break;
case 6:
packs_per_int = 4;
break;
default:
packs_per_int = 8 / bits_;
}
size_t size = w.size() / packs_per_int;
bool large = size > UINT_MAX;
auto grid_shape = w.shape();
grid_shape.back() *= uint8_per_uint32;
enc.set_input_array(wq);
enc.set_input_array(scales);
enc.set_input_array(biases);
enc.set_output_array(w);
dispatch_float_types(w.dtype(), "affine_quantize", [&](auto type_tag) {
dispatch_groups(group_size_, [&](auto group_size) {
dispatch_bits(bits_, [&](auto bits) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto kernel = cu::affine_dequantize<T, group_size.value, bits.value>;
auto [num_blocks, block_dims] =
get_launch_args(size, grid_shape, w.strides(), large);
enc.add_kernel_node(
kernel,
num_blocks,
block_dims,
0,
wq.data<uint8_t>(),
scales.data<T>(),
biases.data<T>(),
w.data<T>(),
w.size());
});
});
});

View File

@@ -0,0 +1,80 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
inline array ensure_row_contiguous(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (!x.flags().row_contiguous) {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
} else {
return x;
}
}
inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
} // namespace
void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("Quantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);
if (dequantize_) {
auto wq = ensure_row_contiguous(inputs[0], enc, s);
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto biases = ensure_row_contiguous(inputs[2], enc, s);
auto& w = outputs[0];
w.set_data(allocator::malloc(w.nbytes()));
affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
} else {
auto w = ensure_row_contiguous(inputs[0], enc, s);
auto& wq = outputs[0];
auto& scales = outputs[1];
auto& biases = outputs[2];
wq.set_data(allocator::malloc(wq.nbytes()));
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
namespace mlx::core {
void affine_quantize(
const array& w,
array& wq,
array& scales,
array& biases,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
void affine_dequantize(
const array& wq,
const array& scales,
const array& biases,
array& w,
int group_size_,
int bits_,
cu::CommandEncoder& enc,
const Stream& s);
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2025 Apple Inc.
namespace mlx::core {
namespace cu {
template <int bits, int wsize = 8>
inline constexpr __device__ short get_pack_factor() {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
template <int bits, int wsize = 8>
inline constexpr __device__ short get_bytes_per_pack() {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
} // namespace cu
template <typename F>
void dispatch_groups(int group_size, F&& f) {
switch (group_size) {
case 32:
f(std::integral_constant<int, 32>{});
break;
case 64:
f(std::integral_constant<int, 64>{});
break;
case 128:
f(std::integral_constant<int, 128>{});
break;
}
}
template <typename F>
void dispatch_bits(int bits, F&& f) {
switch (bits) {
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
case 4:
f(std::integral_constant<int, 4>{});
break;
case 5:
f(std::integral_constant<int, 5>{});
break;
case 6:
f(std::integral_constant<int, 6>{});
break;
case 8:
f(std::integral_constant<int, 8>{});
break;
}
}
} // namespace mlx::core

View File

@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbitsc,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
cu::rbits,
grid,
block,
0,
keys.data<uint32_t>(),
out.data<uint8_t>(),
grid_dims,

Some files were not shown because too many files have changed in this diff Show More