Compare commits

..

92 Commits

Author SHA1 Message Date
Awni Hannun
4bce5f9b2d suppress gcc 10.1 warnings (#2679)
* suppress gcc 10.1 warnings

* suppress gcc 10.1 warnings
2025-10-17 12:09:21 -07:00
Anastasiia Filippova
e9eab527eb Nccl timeout (#2673)
* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
2025-10-14 12:29:54 -07:00
Awni Hannun
36ca62dba8 remove unused unary file (#2672) 2025-10-13 19:36:26 -07:00
Manuel Villanueva
9cbb1b0148 Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-13 14:36:45 -07:00
Fabrizio Milo
9bfc476d72 Normalize README bullet formatting (#2671) 2025-10-13 12:13:30 -07:00
Awni Hannun
25e2356316 speed up scalars (#2669) 2025-10-13 12:10:15 -07:00
Awni Hannun
226a1d24e0 Debug cuda conv (#2662)
* use t4

* use t4
2025-10-10 16:12:47 -07:00
Awni Hannun
630350ad3e Precise sigmoid (#2659)
* bump patch

* Sigmoid matches PyTorch and is more precise on tails
2025-10-10 10:05:23 -07:00
Awni Hannun
380aeb58ae enable admm low-precision cpu (#2661) 2025-10-10 09:50:54 -07:00
Awni Hannun
f37389d100 bump patch (#2658) 2025-10-10 08:36:41 -07:00
Awni Hannun
e89e8b4272 Export with callback (#2612)
* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
2025-10-08 19:24:33 -07:00
AN Long
85a8824a8c Fix cumulative operations when axis=None (#2653) 2025-10-08 15:25:38 -07:00
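A minimal sketch of the semantics this fix targets, assuming the NumPy-style convention that `axis=None` runs the cumulative op over the flattened array (the exact failing case is in #2653):

```
import mlx.core as mx

a = mx.array([[1, 2], [3, 4]])
# With no axis given, the cumulative sum is taken over the flattened array,
# as in NumPy: [1, 3, 6, 10]
print(mx.cumsum(a))
```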
Awni Hannun
f5d4397e5c Fix fast synch when fence is waited before a command buffer is created (#2657) 2025-10-08 11:23:46 -07:00
Awni Hannun
343e33b6d5 fix all_gather vjp (#2654) 2025-10-07 06:05:23 -07:00
Angelos Katharopoulos
0073096dd1 Split name into directories for cuda jit (#2656) 2025-10-07 01:52:58 -07:00
Angelos Katharopoulos
e3d004fed9 Fix and refactor row-reduce (#2650) 2025-10-07 01:51:08 -07:00
Awni Hannun
a393435d28 Speed up compile for node with many parents (#2649) 2025-10-03 19:30:36 -07:00
Awni Hannun
a7a94b29d7 Fix compile when outputs change (#2648) 2025-10-03 08:40:57 -07:00
Daniel Yeh
22a5da76c8 Faster complex matmul (#2571) 2025-10-02 23:33:15 -07:00
Andrey Portnoy
287c63a093 Configure CMake to export compile_commands.json (#2645)
This helps enable LSP for code navigation using clangd.
2025-10-02 15:40:32 -07:00
Awni Hannun
1c9ae1eaa1 cuda fix flaky test (#2646) 2025-10-02 15:40:04 -07:00
Angelos Katharopoulos
c2c3e0b0a2 [CUDA] Add a small column specialization to reduce (#2642) 2025-10-02 14:41:05 -07:00
Awni Hannun
b0cc71ae71 Faster triu, tril, where with scalar (#2644) 2025-10-02 12:21:27 -07:00
Awni Hannun
e88f2d4a8e fix cross entropy axis param (#2641)
* fix cross entropy axis param

* faster grad clipping
2025-10-01 16:49:55 -07:00
Angelos Katharopoulos
9cee557423 Fix status message (#2638) 2025-10-01 16:43:45 -07:00
Awni Hannun
bbf1423953 wait for tasks in cuda (#2636) 2025-09-30 16:08:46 -07:00
Angelos Katharopoulos
eb24267b56 Compile now can attach arbitrary data to an entry (#2634) 2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5 fix for max block dim (#2631) 2025-09-29 08:59:25 -07:00
AN Long
e76a8dd5c5 Fix incorrect path and typos (#2630) 2025-09-28 06:03:04 -07:00
Cheng
b466dea982 [CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
2025-09-27 11:27:17 +09:00
Angelos Katharopoulos
7a6adda1e6 Bump the version (#2627) 2025-09-26 15:15:28 -07:00
Angelos Katharopoulos
1a9f820af6 Compiled should not end in broadcast (#2622) 2025-09-26 13:36:09 -07:00
Awni Hannun
d4f4ff3c5e Allow None input to compiled functions (#2621)
* Allow None input to compiled functions

* Allow None input to compiled functions
2025-09-25 08:42:23 -07:00
Jagrit Digani
7c7e48dbd1 New tuning for small K gemv (#2620)
* New tuning for small K gemv
2025-09-23 12:28:35 -07:00
Daniel Yeh
fbbf3b9b3e Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:15 -07:00
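A short usage sketch of what this change enables, assuming the standard `pickle` round-trip of an `mx.array`:

```
import pickle

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0], dtype=mx.bfloat16)
b = pickle.loads(pickle.dumps(a))  # bfloat16 arrays now survive pickling
assert b.dtype == mx.bfloat16
```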
Daniel Yeh
bf01ad9367 fix (#2613)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:04 -07:00
Cheng
ae438d05fa [CUDA] Recycle CUDA events (#2604)
* Make CudaEvent a CudaHandle

* Add caching for CudaEvent

* Make sure cuda events are destroyed at last

* Fix headers

* SharedEvent => AtomicEvent

* RawCudaEvent => CudaEventHandle, CudaEventWrapper => CopyableCudaEvent

* Remove unneeded asserts
2025-09-23 10:42:03 +09:00
Awni Hannun
711a645807 avoid producing NaN in attention (#2608) 2025-09-22 13:10:43 -07:00
Josh Bleecher Snyder
aa9d44b3d4 implement Convolution::output_shape (#2601)
- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
2025-09-22 10:09:45 -07:00
Awni Hannun
ec2ab42888 Lower sorted QMM gather threshold (#2609) 2025-09-19 18:22:55 -07:00
Cheng
787c0d90cd Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
2025-09-19 09:12:14 +09:00
Oleksandr Bilous
e8b604a6a3 fix: library loading for swift dynamic frameworks (#2568) 2025-09-18 13:54:59 -07:00
Awni Hannun
50cc09887f expose depends (#2606) 2025-09-18 10:06:15 -07:00
Umberto Mignozzetti
3f730e77aa Update export function example for array input (#2598)
After changing the shape to conform (same shapes for all objects), the example works.
2025-09-16 14:38:05 -07:00
Awni Hannun
caecbe876a no copy batch rope (#2595) 2025-09-15 14:23:48 -07:00
Umberto Mignozzetti
8afb6d62f2 Fix typo in average_gradients function call (#2594) 2025-09-15 11:29:21 -07:00
Awni Hannun
6ccfa603cd fix metal scan (#2591) 2025-09-15 11:01:57 -07:00
Umberto Mignozzetti
36cad99a11 Refactor code examples to use 'gelu' (#2592)
Updated code examples to use 'gelu' directly instead of 'nn.gelu'.
2025-09-15 09:47:02 -07:00
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57 Add sdpa with sinks (#2558)
* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
Gökdeniz Gülmez
db5443e831 Adding Relu2 (#2582)
* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-09-10 07:24:30 -07:00
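For reference, ReLU² squares the rectified input. A rough functional sketch consistent with the `mx.maximum(x, 0)` note above (not the exact implementation added to `mlx.nn` in #2582):

```
import mlx.core as mx

def relu2(x: mx.array) -> mx.array:
    # relu2(x) = max(x, 0) ** 2
    y = mx.maximum(x, 0)
    return y * y
```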
Cheng
52b8384d10 Fix flaky addmm tests (#2581) 2025-09-10 14:22:22 +09:00
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
Cheng
c1e3340b23 Set ccache size before building (#2570) 2025-09-07 09:00:31 +09:00
XXXXRT666
8f163a367d typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods

* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014 Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unknown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
- Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
- Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
- Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.

### Checklist

- I have read the CONTRIBUTING document
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
- I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-03 12:52:08 -07:00
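The alias described in this PR amounts to roughly the following sketch (illustrative only; the real stubs are produced by `python setup.py generate_stubs`, and `add` also accepts a `stream` argument):

```
from typing import TypeAlias, Union

from mlx.core import array

scalar: TypeAlias = Union[int, float, bool]

def add(a: scalar | array, b: scalar | array) -> array: ...
```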
Krishi Saripalli
c5d2937aa5 chore: Update Docs With Slice Copy Example (#2559)
* chore: updated docs with slice copy example

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-02 22:07:02 -07:00
Awni Hannun
b61a65e313 fix copies in sdpa (#2563) 2025-09-02 11:00:36 -07:00
wrmsr
04cbb4191c Fix dequantize python sig (#2562) 2025-09-01 11:50:20 -07:00
Artur Antonov
c5460762e7 Fix AdamW weight_decay default value in docstring (#2557) 2025-08-31 21:29:30 -07:00
Awni Hannun
8ce49cd39e fix quantized vjp for mxfp4 (#2555) 2025-08-29 10:06:15 -07:00
Awni Hannun
9c68b50853 version bump (#2554) 2025-08-29 06:54:17 -07:00
Awni Hannun
111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
827003d568 fix METAL quantization in JIT (#2553) 2025-08-28 18:26:25 -07:00
Awni Hannun
d363a76aa4 Bump xcode in circle (#2551)
* bump xcode in circle

* bump xcode in circle

* bump xcode in circle
2025-08-28 13:13:34 -07:00
Awni Hannun
70560b6bd5 Add mode parameter for quantization (#2499)
* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
7ef8a6f2d5 [CUDA] fix sort (#2550)
* [CUDA] fix sort

* fix test
2025-08-27 19:48:43 -07:00
Cheng
31c6f6e33f [CUDA] Use ConcurrentContext in concatenate_gpu (#2549) 2025-08-28 09:30:08 +09:00
Awni Hannun
584d48458e link with nccl (#2546) 2025-08-27 10:01:07 -07:00
Cheng
5cf984ca87 Separate cpu compilation cache by versions (#2548) 2025-08-27 11:25:15 +09:00
Cheng
a9bac3d9e5 Run CPP tests for CUDA build in CI (#2544) 2025-08-27 08:06:46 +09:00
Awni Hannun
5458d43247 add load with path tests (#2543) 2025-08-26 14:24:47 -07:00
Awni Hannun
a4dba65220 Enable cuda graph toggle (#2545)
* enable cuda graph toggle

* increase cache size
2025-08-26 12:50:38 -07:00
Awni Hannun
3dcb286baf Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default

* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9 [CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives

* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Awni Hannun
2ca75bb529 Remove nccl install in release (#2542) 2025-08-25 15:20:18 -07:00
Awni Hannun
db14e29a0b allow pathlib.Path to save/load functions (#2541) 2025-08-25 14:58:49 -07:00
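A small usage sketch of the change, assuming the single-array `mx.save`/`mx.load` form:

```
from pathlib import Path

import mlx.core as mx

path = Path("weights.npy")
mx.save(path, mx.zeros((4, 4)))  # previously required str(path)
restored = mx.load(path)
```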
Awni Hannun
d2f540f4e0 Use nccl header only when nccl is not present (#2539)
* use nccl header only when nccl is not present

* larger machine for cuda build
2025-08-25 14:17:25 -07:00
Cheng
333ffea273 [CUDA] Remove thrust in arange (#2535) 2025-08-24 16:22:36 +09:00
Cheng
f55b6f1f2f Enable COMPILE_WARNING_AS_ERROR for linux builds in CI (#2534) 2025-08-24 15:33:08 +09:00
Awni Hannun
30561229c7 Fix allocation bug in NCCL (#2530) 2025-08-22 14:39:43 -07:00
Awni Hannun
068a4612e9 nccl default for backend=any (#2528)
* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
2025-08-22 12:24:27 -07:00
Andrey Portnoy
5722c147de [CUDA] Update calls to cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 (#2525)
* [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13

These functions' signatures changed in CUDA 13, so we differentiate
between CUDA 13 and preceding releases at compile time.

* Mention NVIDIA in ACKNOWLEDGMENTS.md
2025-08-21 19:57:20 -07:00
Cheng
f6819a1f26 Fix warning 186-D from nvcc (#2527) 2025-08-22 10:29:55 +09:00
Awni Hannun
f93f87c802 nccl dep + default for cuda (#2526) 2025-08-21 17:57:49 -07:00
187 changed files with 8268 additions and 2881 deletions

View File

@@ -18,13 +18,14 @@ jobs:
type: boolean
default: false
macos:
-xcode: "16.2.0"
+xcode: "26.0.0"
-resource_class: m2pro.medium
+resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
+xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -89,6 +90,7 @@ jobs:
command: |
uv venv
uv pip install cmake
+DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
@@ -118,7 +120,7 @@ jobs:
parameters:
xcode_version:
type: string
-default: "16.2.0"
+default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
@@ -126,12 +128,13 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
-resource_class: m2pro.medium
+resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
+xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
@@ -196,7 +199,7 @@ jobs:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-uv pip install -e .
+uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
@@ -227,11 +230,15 @@ jobs:
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
+- run:
+name: Set CCache size
+command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
-CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+uv pip install cmake
+DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
@@ -239,12 +246,23 @@ jobs:
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+- run:
+name: Build CPP only
+command: |
+source .venv/bin/activate
+cmake . -B build \
+-DMLX_BUILD_CUDA=ON \
+-DCMAKE_CUDA_COMPILER=`which nvcc` \
+-DCMAKE_BUILD_TYPE=DEBUG
+cmake --build build -j `nproc`
+- run:
+name: Run CPP tests
+command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
-ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
@@ -258,7 +276,7 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "16.2.0" default: "26.0.0"
build_env: build_env:
type: string type: string
default: "" default: ""
@@ -267,7 +285,7 @@ jobs:
default: "" default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: m2pro.medium resource_class: m4pro.medium
environment: environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >> MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
@@ -275,11 +293,15 @@ jobs:
- run:
name: Install dependencies
command: |
-brew install python@<< parameters.python_version >>
-brew install openmpi
-python<< parameters.python_version >> -m venv env
-source env/bin/activate
-pip install --upgrade pip
+xcodebuild -downloadComponent MetalToolchain
+mkdir -p ~/miniconda3
+curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm ~/miniconda3/miniconda.sh
+source ~/miniconda3/bin/activate
+conda init --all
+conda create -n env python=<< parameters.python_version >> -y
+conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -289,19 +311,19 @@
- run:
name: Install Python package
command: |
-source env/bin/activate
+conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
-source env/bin/activate
+conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
-source env/bin/activate
+conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
@@ -311,7 +333,7 @@ jobs:
- run:
name: Build common package
command: |
-source env/bin/activate
+conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
@@ -320,7 +342,7 @@
- run:
name: Upload package
command: |
-source env/bin/activate
+conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
@@ -393,7 +415,7 @@
default: ""
machine:
image: ubuntu-2204:current
-resource_class: large
+resource_class: xlarge
steps:
- checkout
- run:
@@ -440,7 +462,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
-macosx_deployment_target: ["13.5", "14.0"]
+macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
@@ -465,68 +487,7 @@
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
-xcode_version: ["16.2.0", "15.0.0"]
+xcode_version: ["26.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -568,7 +529,7 @@
requires: [ hold ]
matrix:
parameters:
-macosx_deployment_target: ["13.5", "14.0"]
+macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
@@ -587,53 +548,7 @@
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
-xcode_version: ["16.2.0", "15.0.0"]
+xcode_version: ["26.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
- build_linux_release:
matrix:
parameters:
@@ -652,68 +567,7 @@
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
-xcode_version: ["16.2.0", "15.0.0"]
+xcode_version: ["26.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- build_linux_release:
matrix:
parameters:

View File

@@ -19,12 +19,17 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
+# Organizations
+MLX has received contributions from the following companies:
+- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -87,22 +88,21 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
-if(MLX_BUILD_METAL)
-set(METAL_LIB "-framework Metal")
-set(FOUNDATION_LIB "-framework Foundation")
-set(QUARTZ_LIB "-framework QuartzCore")
-endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
-if(MLX_BUILD_METAL AND NOT METAL_LIB)
-message(STATUS "Metal not found. Unable to build GPU")
-set(MLX_BUILD_METAL OFF)
-set(MLX_METAL_DEBUG OFF)
-elseif(MLX_BUILD_METAL)
-message(STATUS "Building METAL sources")
+if(MLX_BUILD_METAL)
+find_library(METAL_LIB Metal)
+find_library(FOUNDATION_LIB Foundation)
+find_library(QUARTZ_LIB QuartzCore)
+if(METAL_LIB)
+message(STATUS "Metal found ${METAL_LIB}")
+else()
+message(
+FATAL_ERROR
+"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
+endif()
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +111,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
+OUTPUT_VARIABLE MACOS_SDK_VERSION
+OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -140,6 +141,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+# With newer clang/gcc versions following libs are implicitly linked, but when
+# building on old distributions they need to be explicitly listed.
+target_link_libraries(mlx PRIVATE dl pthread)
+endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
@@ -167,7 +174,7 @@ if(MLX_BUILD_CPU)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
else()
-message(STATUS "Accelerate or arm neon not found, using default backend.")
+message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()

View File

@@ -11,28 +11,28 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without transferring data.
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
-```
+```text
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},

View File

@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
-np.float32
-)
+c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
-dtypes = ("float32", "float16")
+dtypes = ("float32", "float16", "complex64")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
@@ -187,7 +185,7 @@
diff = gflops_mx / gflops_pt - 1.0
print(
-f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
+f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
for transpose in (False, True):
-for dtype in ("float32", "float16"):
+for dtype in ("float32", "float16", "complex64"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
@@ -215,7 +215,7 @@ for transpose in (False, True):
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
-results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
+results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
)
)
plt.close(fig)

View File

@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
-source=source
+source=source,
ensure_row_contiguous=False,
)

View File

@@ -27,6 +27,7 @@ simple functions.
mish
prelu
relu
+relu2
relu6
selu
sigmoid

View File

@@ -50,6 +50,7 @@ Layers
QuantizedLinear
RMSNorm
ReLU
+ReLU2
ReLU6
RNN
RoPE

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
-timeit(nn.gelu, x)
-timeit(mx.compile(nn.gelu), x)
+timeit(gelu, x)
+timeit(mx.compile(gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
-grads = mlx.nn.average_gradients(grads) # <---- This line was added
+grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss

View File

@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python
-mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
+mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
-out, = imported_abs(mx.array(-1.0))
+out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

View File

@@ -107,8 +107,20 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
-Note, unlike NumPy, updates to the same location are nondeterministic:
+Note that unlike NumPy, slicing an array creates a copy, not a view. So
+mutating it does not mutate the original array:
+.. code-block:: shell
+>>> a = mx.array([1, 2, 3])
+>>> b = a[:]
+>>> b[2] = 0
+>>> b
+array([1, 2, 0], dtype=int32)
+>>> a
+array([1, 2, 3], dtype=int32)
+Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
-return {{1}, {0}, {0}};
+return {Shape{1}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
-return {{1}, {0}, {0}, {0}};
+return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -11,6 +11,8 @@ namespace mlx::core {
enum class TernaryOpType {
ScalarScalarScalar,
VectorVectorVector,
+VectorVectorScalar,
+VectorScalarVector,
General,
};
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
+} else if (
+b.data_size() == 1 && a.flags().row_contiguous &&
+c.flags().row_contiguous) {
+topt = TernaryOpType::VectorScalarVector;
+} else if (
+c.data_size() == 1 && a.flags().row_contiguous &&
+b.flags().row_contiguous) {
+topt = TernaryOpType::VectorVectorScalar;
} else {
topt = TernaryOpType::General;
}
@@ -59,6 +69,8 @@
b.flags());
}
break;
+case TernaryOpType::VectorVectorScalar:
+case TernaryOpType::VectorScalarVector:
case TernaryOpType::General:
// Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) ||

View File

@@ -15,6 +15,7 @@
#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/graph_utils.h" #include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core { namespace mlx::core {
@@ -94,7 +95,11 @@ void* compile(
kernel_file_name = kernel_name; kernel_file_name = kernel_name;
} }
auto output_dir = std::filesystem::temp_directory_path(); auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so"; std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string(); auto shared_lib_path = (output_dir / shared_lib_name).string();

View File

@@ -88,4 +88,47 @@ void matmul<double>(
}
}
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core

View File

@@ -108,6 +108,9 @@ void matmul_general(
} else if (out.dtype() == float64) {
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
+} else if (out.dtype() == complex64) {
+matmul_dispatch<complex64_t>(
+a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
@@ -128,10 +131,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-if (out.dtype() != float32) {
-throw std::runtime_error(
-"[AddMM::eval_cpu] Currently only supports float32.");
-}
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;

View File

@@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.
-#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -13,6 +11,35 @@
namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
@@ -407,6 +434,231 @@ void _qmm_dispatch(
}
}
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@@ -513,41 +765,106 @@ void _bs_qmm_dispatch(
}
}
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
-auto& biases_pre = inputs[3];
-std::vector<array> temps;
+auto& encoder = cpu::get_command_encoder(stream());
-auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
+auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
-temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
+auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
-copy_cpu(arr, temps.back(), CopyType::General, s);
+copy_cpu(arr, arr_cpy, CopyType::General, s);
-return temps.back();
+encoder.add_temporary(arr_cpy);
+return arr_cpy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
-auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
-auto& encoder = cpu::get_command_encoder(stream());
-encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
-encoder.set_input_array(biases);
encoder.set_output_array(out);
+if (mode_ == QuantizationMode::Affine) {
+auto biases = ensure_row_contiguous(inputs[3]);
+encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
@@ -558,48 +875,54 @@
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
-auto& biases_pre = inputs[3];
-auto& lhs_indices = inputs[4];
-auto& rhs_indices = inputs[5];
+auto& lhs_indices = inputs[inputs.size() - 2];
+auto& rhs_indices = inputs[inputs.size() - 1];
-std::vector<array> temps;
+auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous_last_dims = [s = stream(),
-&temps](const array& arr) {
+&encoder](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
-temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
+auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
-copy_cpu(arr, temps.back(), CopyType::General, s);
+copy_cpu(arr, arr_cpy, CopyType::General, s);
-return temps.back();
+encoder.add_temporary(arr_cpy);
+return arr_cpy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
-auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
-auto& encoder = cpu::get_command_encoder(stream());
-encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
-encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
+if (mode_ == QuantizationMode::Affine) {
+auto biases = ensure_row_contiguous_last_dims(inputs[3]);
+encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
@@ -622,6 +945,18 @@
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
} }
template <typename T, typename U> template <typename T, typename U>
@@ -705,7 +1040,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
} }
void fast::AffineQuantize::eval_cpu( void fast::Quantize::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) { auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -764,7 +1099,7 @@ void fast::AffineQuantize::eval_cpu(
} }
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); "[fast::Quantize::eval_cpu] Only supports floating point inputs");
} }
}); });
} }


@@ -9,7 +9,7 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in sims/base.h // There seems to be a bug in simd/base_simd.h
// __XROS_2_0 is not defined, the expression evaluates // __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library // to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15 // higher than it should be even on macOS < 15


@@ -15,6 +15,18 @@ namespace mlx::core {
namespace { namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed); std::stable_sort(st, ed, nan_aware_less<T>);
src_it.step(); src_it.step();
} }
} }
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed); std::nth_element(st, md, ed, nan_aware_less<T>);
} }
} }
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
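
The comparator added above mirrors NumPy/JAX ordering: NaNs compare as never-less-than anything, so a stable sort sinks them to the end. A self-contained host-side illustration (not MLX code):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>
    #include <vector>

    // NaN-aware comparator: a NaN is never "less", so it sorts to the end.
    bool nan_aware_less(float a, float b) {
      if (std::isnan(a)) return false;
      if (std::isnan(b)) return true;
      return a < b;
    }

    int main() {
      float nan = std::numeric_limits<float>::quiet_NaN();
      std::vector<float> v{3.0f, nan, 1.0f, 2.0f, nan, 0.0f};
      std::stable_sort(v.begin(), v.end(), nan_aware_less);
      for (float x : v) std::printf("%g ", x);  // prints: 0 1 2 3 nan nan
      std::printf("\n");
    }
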


@@ -77,7 +77,8 @@ struct Real {
struct Sigmoid { struct Sigmoid {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x)); auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
} }
SINGLE() SINGLE()
}; };
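
This SIMD rewrite and the matching CUDA Sigmoid further below use the same reflection of the logistic function; written out (notation mine):

\[
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
y = \frac{1}{1 + e^{|x|}} = \sigma(-|x|), \qquad
\sigma(x) = \begin{cases} y, & x < 0, \\ 1 - y, & x \ge 0. \end{cases}
\]

Computing the small branch \(y = \sigma(-|x|)\) directly keeps its relative accuracy on the tails, instead of recovering it as one minus a value that has already rounded to nearly 1.
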


@@ -170,6 +170,10 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)
# Suppress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT. # Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda


@@ -30,8 +30,20 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_; next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_; auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) { for (size_t i = 1; i < num_blocks; ++i) {
@@ -79,7 +91,7 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device. // TODO: Set memory limit for multi-device.
size_t free, total; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8; memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
} }


@@ -6,23 +6,33 @@
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
template <typename T> namespace cg = cooperative_groups;
struct Arange {
const T start;
const T step;
__device__ T operator()(uint32_t i) const { template <typename T, typename IdxT, int N_WRITES>
return start + i * step; __global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
} }
}; } else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
}
}
} // namespace cu } // namespace cu
@@ -36,19 +46,23 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>; using OutType = cuda_type_t<CTYPE>;
CTYPE step = constexpr int N_WRITES = 16 / sizeof(OutType);
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_); dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
thrust::transform( using IdxT = std::conditional_t<large(), int64_t, int32_t>;
cu::thrust_policy(encoder.stream()), auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
thrust::counting_iterator<uint32_t>(0), encoder.add_kernel_node(
thrust::counting_iterator<uint32_t>(out.data_size()), cu::arange<OutType, IdxT, N_WRITES>,
thrust::device_pointer_cast(out.data<OutType>()), num_blocks,
cu::Arange<OutType>{ block_dims,
static_cast<OutType>(start_), static_cast<OutType>(step)}); 0,
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
}); });
} }
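
The new arange kernel writes N_WRITES elements per thread and keeps a scalar path for the ragged tail. The same control flow on the CPU, as a rough sketch in plain C++ (not the kernel itself):

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <vector>

    // One loop iteration plays the role of one GPU thread: full chunks are
    // staged in a small fixed-size block (the AlignedVector/store_vector role),
    // the ragged tail falls back to scalar writes.
    template <typename T, int N>
    void arange_chunked(std::vector<T>& out, T start, T step) {
      int64_t size = static_cast<int64_t>(out.size());
      int64_t num_threads = (size + N - 1) / N;
      for (int64_t index = 0; index < num_threads; ++index) {
        if ((index + 1) * N > size) {
          for (int64_t i = index * N; i < size; ++i) {   // ragged tail
            out[i] = start + static_cast<T>(i) * step;
          }
        } else {
          std::array<T, N> block;                        // staged "vector" write
          for (int i = 0; i < N; ++i) {
            block[i] = start + static_cast<T>(index * N + i) * step;
          }
          std::copy(block.begin(), block.end(), out.begin() + index * N);
        }
      }
    }
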


@@ -332,9 +332,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
} }
auto kernel = mod.get_kernel(kernel_name); auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread); get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }


@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair< std::pair<
cudnnBackendDescriptorType_t, cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>> std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128); cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache; return cache;
} }
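
The cache above now takes its capacity from MLX_CUDA_CONV_CACHE_SIZE with a default of 128. A minimal sketch of reading such a capacity, using a hypothetical helper name rather than MLX's env utilities:

    #include <cstdlib>
    #include <string>

    // Illustrative helper: read a capacity from the environment, falling back
    // to a default when the variable is unset or not a number.
    int capacity_from_env(const char* name, int default_capacity) {
      const char* value = std::getenv(name);
      if (value == nullptr) {
        return default_capacity;
      }
      try {
        return std::stoi(value);
      } catch (...) {
        return default_capacity;
      }
    }

    // e.g. capacity_from_env("MLX_CUDA_CONV_CACHE_SIZE", 128)
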
@@ -382,15 +382,13 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
} }
if (op_graph) { if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it. // Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph( auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) { if (plan) {
throw std::runtime_error("[conv] Unable to find an execution plan."); // Setup inputs and outputs.
} register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out); auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace( conv_cache().emplace(
@@ -398,6 +396,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
return; return;
} }
} }
}
// Use fallback kernel for settings not supported by cuDNN. // Use fallback kernel for settings not supported by cuDNN.
gemm_conv( gemm_conv(


@@ -15,8 +15,8 @@ void copy_gpu_inplace(
int64_t offset_out, int64_t offset_out,
CopyType ctype, CopyType ctype,
const Stream& s, const Stream& s,
const std::optional<array>& dynamic_offset_in, std::optional<array> dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) { std::optional<array> dynamic_offset_out) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
@@ -44,6 +44,16 @@ void copy_gpu_inplace(
strides_vec[0]); strides_vec[0]);
} else { } else {
if (dynamic_offset_in || dynamic_offset_out) { if (dynamic_offset_in || dynamic_offset_out) {
if (!dynamic_offset_in) {
dynamic_offset_in = array(0, int64);
encoder.add_temporary(*dynamic_offset_in);
}
if (!dynamic_offset_out) {
dynamic_offset_out = array(0, int64);
encoder.add_temporary(*dynamic_offset_out);
}
encoder.set_input_array(*dynamic_offset_in);
encoder.set_input_array(*dynamic_offset_out);
copy_general_dynamic( copy_general_dynamic(
encoder, encoder,
ctype, ctype,
@@ -54,8 +64,8 @@ void copy_gpu_inplace(
shape_collapsed, shape_collapsed,
strides_vec[0], strides_vec[0],
strides_vec[1], strides_vec[1],
dynamic_offset_in ? *dynamic_offset_in : array(0, int64), *dynamic_offset_in,
dynamic_offset_out ? *dynamic_offset_out : array(0, int64)); *dynamic_offset_out);
} else { } else {
copy_general( copy_general(
encoder, encoder,


@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
Dtype dtype, Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) { cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph); auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph); return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
} }


@@ -14,10 +14,6 @@ namespace mlx::core::cu {
namespace { namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) { void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -27,11 +23,11 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
} }
} }
int cuda_graph_cache_size() { bool use_cuda_graphs() {
static int cache_size = []() { static bool use_graphs = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}(); }();
return cache_size; return use_graphs;
} }
} // namespace } // namespace
@@ -68,8 +64,8 @@ Device::~Device() {
void Device::make_current() { void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce // We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host. // actual calls of CUDA APIs.
static int current = 0; static thread_local int current = 0;
if (current != device_) { if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_)); CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_; current = device_;
@@ -86,11 +82,19 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current(); enc.device().make_current();
if (!use_cuda_graphs()) {
return;
}
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
} }
CommandEncoder::CaptureContext::~CaptureContext() { CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
enc.node_count_++;
return;
}
graph.end_capture(enc.stream()); graph.end_capture(enc.stream());
if (discard) { if (discard) {
return; return;
@@ -105,6 +109,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
CommandEncoder::ConcurrentContext::~ConcurrentContext() { CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false; enc.in_concurrent_ = false;
if (!use_cuda_graphs()) {
return;
}
// Use an empty graph node for synchronization // Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -186,35 +193,43 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d), : device_(d),
stream_(d), stream_(d),
graph_(d), graph_(d),
graph_cache_(cuda_graph_cache_size()) {} worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) { void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task)); worker_.add_task(std::move(task));
} }
void CommandEncoder::set_input_array(const array& arr) { void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr()); auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id); active_deps_.push_back(id);
} }
void CommandEncoder::set_output_array(const array& arr) { void CommandEncoder::set_output_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr()); auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id); active_deps_.push_back(id);
active_outputs_.push_back(id); active_outputs_.push_back(id);
} }
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}
void CommandEncoder::add_kernel_node( void CommandEncoder::add_kernel_node(
void* func, void* func,
dim3 grid_dim, dim3 grid_dim,
dim3 block_dim, dim3 block_dim,
uint32_t smem_bytes, uint32_t smem_bytes,
void** params) { void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
}
cudaKernelNodeParams kernel_params = {0}; cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func; kernel_params.func = func;
kernel_params.gridDim = grid_dim; kernel_params.gridDim = grid_dim;
@@ -230,6 +245,23 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim, dim3 block_dim,
uint32_t smem_bytes, uint32_t smem_bytes,
void** params) { void** params) {
if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
smem_bytes,
stream(),
params,
nullptr));
return;
}
CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func; kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x; kernel_params.gridDimX = grid_dim.x;
@@ -256,20 +288,38 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
} }
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
return;
}
cudaGraphNode_t node; cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'}); insert_graph_dependencies(GraphNode{node, 'G'});
} }
int CommandEncoder::get_num_ops() {
return node_count_;
}
void CommandEncoder::commit() { void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit"); nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) { if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {}); add_completed_handler([temporaries = std::move(temporaries_)]() {});
} }
if (node_count_ > 0) { if (use_cuda_graphs() && node_count_ > 0) {
if (!from_nodes_.empty()) { if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies( CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); graph_,
from_nodes_.data(),
to_nodes_.data(),
#if CUDART_VERSION >= 13000
nullptr, // edgeData
#endif // CUDART_VERSION >= 13000
from_nodes_.size()));
} }
graph_key_ += "."; graph_key_ += ".";
@@ -303,7 +353,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state // Reset state
node_count_ = 0;
graph_node_count_ = 0; graph_node_count_ = 0;
empty_node_count_ = 0; empty_node_count_ = 0;
from_nodes_.clear(); from_nodes_.clear();
@@ -315,6 +364,7 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.commit(stream_); worker_.commit(stream_);
node_count_ = 0;
} }
void CommandEncoder::synchronize() { void CommandEncoder::synchronize() {


@@ -76,9 +76,6 @@ class CommandEncoder {
uint32_t smem_bytes, uint32_t smem_bytes,
void** params); void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child); void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) { void add_temporary(const array& arr) {
@@ -86,7 +83,7 @@ class CommandEncoder {
} }
void add_completed_handler(std::function<void()> task); void add_completed_handler(std::function<void()> task);
void maybe_commit(); int get_num_ops();
void commit(); void commit();
Device& device() { Device& device() {
@@ -101,6 +98,9 @@ class CommandEncoder {
void synchronize(); void synchronize();
private: private:
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
struct GraphNode { struct GraphNode {
cudaGraphNode_t node; cudaGraphNode_t node;
// K = kernel // K = kernel
@@ -140,7 +140,7 @@ class Device {
Device(const Device&) = delete; Device(const Device&) = delete;
Device& operator=(const Device&) = delete; Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls. // Make this device the current cuda device, this method is thread-safe.
void make_current(); void make_current();
CommandEncoder& get_command_encoder(Stream s); CommandEncoder& get_command_encoder(Stream s);


@@ -205,9 +205,11 @@ struct Power {
if constexpr (cuda::std::is_integral_v<T>) { if constexpr (cuda::std::is_integral_v<T>) {
T res = 1; T res = 1;
// Raising an integer to a negative power is undefined // Raising an integer to a negative power is undefined
if constexpr (cuda::std::is_signed_v<T>) {
if (exp < 0) { if (exp < 0) {
return 0; return 0;
} }
}
while (exp) { while (exp) {
if (exp & 1) { if (exp & 1) {
res *= base; res *= base;
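
For context, the surrounding loop is binary exponentiation; with the new signed-only guard the whole routine looks roughly like this stand-alone sketch (not the kernel itself):

    #include <cstdint>

    // Integer exponentiation by squaring. A negative exponent on a signed base
    // returns 0, matching the guard added above (the true result would be a
    // fraction, and integer pow is otherwise undefined for negative exponents).
    int64_t ipow(int64_t base, int64_t exp) {
      if (exp < 0) {
        return 0;
      }
      int64_t res = 1;
      while (exp) {
        if (exp & 1) {
          res *= base;
        }
        exp >>= 1;
        base *= base;
      }
      return res;
    }

    // ipow(3, 5) == 243, ipow(2, -1) == 0
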


@@ -6,7 +6,6 @@
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x); return CastOp<SrcT, DstT>{}(x);
} }
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu } // namespace mlx::core::cu


@@ -257,8 +257,8 @@ struct Round {
struct Sigmoid { struct Sigmoid {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x))); T y = 1 / (1 + exp(abs(x)));
return (x < 0) ? 1 - y : y; return (x < 0) ? y : 1 - y;
} }
}; };


@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilies that work under both // This file must not include any host-only code, utilities that work under both
// host and device can be put here. // host and device can be put here.
// //
// See more about the requirements at: // See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
} }
}; };
// CUDA 11 does not have host side arithmatic operators for half types. // CUDA 11 does not have host side arithmetic operators for half types.
template <typename T> template <typename T>
struct Limits< struct Limits<
T, T,


@@ -2,30 +2,36 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <cassert> #include <cassert>
namespace mlx::core { namespace mlx::core::distributed {
namespace distributed {
void AllReduce::eval_gpu( void AllReduce::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(outputs.size() == 1); assert(outputs.size() == 1);
auto& input = inputs[0]; auto set_input_output =
auto& output = outputs[0]; [s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
if (input.is_donatable()) {
output.copy_shared_buffer(input);
} else {
output.set_data(allocator::malloc(output.nbytes()));
}
encoder.set_input_array(input); encoder.set_input_array(input);
encoder.set_output_array(output); encoder.set_output_array(output);
@@ -47,5 +53,4 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported."); "Only all reduce sum, max, and min are supported.");
} }
} }
} // namespace distributed } // namespace mlx::core::distributed
} // namespace mlx::core


@@ -5,18 +5,24 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu { namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() { bool is_available() {
return true; return true;
} }
void new_stream(Stream s) { void new_stream(Stream s) {
// Force initalization of cuda, so cuda runtime get destroyed at last. // Force initalization of CUDA, so CUDA runtime get destroyed at last.
cudaFree(nullptr); cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
} }
@@ -34,7 +40,8 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs); arr.primitive().eval_gpu(arr.inputs(), outputs);
} }
auto& encoder = cu::get_command_encoder(arr.primitive().stream()); auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running. // Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
// Except for the donated one. // Except for the donated one.
@@ -45,7 +52,14 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) { for (auto& s : arr.siblings()) {
encoder.add_temporary(s); encoder.add_temporary(s);
} }
encoder.maybe_commit();
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
} }
void finalize(Stream s) { void finalize(Stream s) {
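
The eval change above replaces maybe_commit with an explicit check: once the encoder has queued enough ops (tunable via MLX_MAX_OPS_PER_BUFFER), notify the scheduler and commit. The batching idea in isolation, with illustrative names that are not MLX code:

    #include <functional>

    // Count queued ops and flush once a tunable threshold is reached.
    struct OpBatcher {
      int pending = 0;
      int max_ops = 20;             // plays the role of the default threshold
      std::function<void()> flush;  // commits the underlying command buffer

      void record_op() {
        if (++pending >= max_ops) {
          flush();
          pending = 0;
        }
      }
    };
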


@@ -3,10 +3,12 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h" #include "mlx/event.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include <map>
#include <vector>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
@@ -17,104 +19,180 @@ namespace cu {
// CudaEvent implementations // CudaEvent implementations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII. namespace {
class CudaEventHandle {
// Manage cached cudaEvent_t objects.
class CudaEventPool {
public: public:
CudaEventHandle() { CudaEventHandle create(Device& d, int flags) {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags( if (!on_creation_thread()) {
&event_, cudaEventDisableTiming | cudaEventBlockingSync)); return CudaEventHandle(d, flags);
}
auto& cache = cache_for(d, flags);
if (cache.empty()) {
return CudaEventHandle(d, flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
return ret;
}
} }
~CudaEventHandle() { void release(CudaEventHandle event) {
CHECK_CUDA_ERROR(cudaEventDestroy(event_)); if (!on_creation_thread()) {
// Event will be destroyed directly instead of getting moved to cache.
return;
} }
cache_for(event.device, event.flags).push_back(std::move(event));
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
} }
private: private:
cudaEvent_t event_; std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
return cache_[d.cuda_device()][flags];
}
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in CPU stream), we don't want to make
// the cache thread-safe as it adds overhead, so we just skip cache when
// using events in worker threads.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
}; };
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {} CudaEventPool& cuda_event_pool() {
static CudaEventPool pool;
return pool;
}
} // namespace
CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() {
cuda_event_pool().release(std::move(event_));
}
void CudaEvent::wait() { void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait"); nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) { event_.device.make_current();
throw std::runtime_error("Should not wait on a CudaEvent before record."); cudaEventSynchronize(event_);
}
cudaEventSynchronize(*event_);
} }
void CudaEvent::wait(cudaStream_t stream) { void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) { event_.device.make_current();
throw std::runtime_error("Should not wait on a CudaEvent before record."); cudaStreamWaitEvent(stream, event_);
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
}
} }
void CudaEvent::record(cudaStream_t stream) { void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream); event_.device.make_current();
recorded_ = true; cudaEventRecord(event_, stream);
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
}
} }
bool CudaEvent::completed() const { bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess; // Note: cudaEventQuery can be safely called from any device.
return cudaEventQuery(event_) == cudaSuccess;
} }
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
event_->wait();
}
void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
check_recorded();
event_->wait();
});
} else {
check_recorded();
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->wait(encoder.stream());
}
}
void record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
} else {
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->record(encoder.stream());
recorded_ = true;
}
}
bool is_signaled() const {
return recorded_ && event_->completed();
}
private:
void check_recorded() const {
if (!recorded_) {
throw std::runtime_error(
"Should not wait on a CudaEvent before recording.");
}
}
std::shared_ptr<CudaEvent> event_;
bool recorded_{false};
};
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations // AtomicEvent implementations
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) { __host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
uint64_t current; uint64_t current;
while ((current = ac->load()) < value) { while ((current = ac->load()) < value) {
ac->wait(current); ac->wait(current);
} }
} }
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) { __host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
ac->store(value); ac->store(value);
ac->notify_all(); ac->notify_all();
} }
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) { __global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value); event_wait(ac, value);
} }
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) { __global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) { AtomicEvent::AtomicEvent() {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() {
buf_ = std::shared_ptr<Buffer>( buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) { new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr); allocator().free(*ptr);
@@ -123,17 +201,17 @@ SharedEvent::SharedEvent() {
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0; *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
} }
void SharedEvent::wait(uint64_t value) { void AtomicEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait"); nvtx3::scoped_range r("cu::AtomicEvent::wait");
event_wait(to_atomic(buf_), value); event_wait(atomic(), value);
} }
void SharedEvent::wait(cudaStream_t stream, uint64_t value) { void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value); event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
} }
void SharedEvent::wait(Stream s, uint64_t value) { void AtomicEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)"); nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else { } else {
@@ -144,17 +222,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
} }
} }
void SharedEvent::signal(uint64_t value) { void AtomicEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal"); nvtx3::scoped_range r("cu::AtomicEvent::signal");
event_signal(to_atomic(buf_), value); event_signal(atomic(), value);
} }
void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value); event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
} }
void SharedEvent::signal(Stream s, uint64_t value) { void AtomicEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) { if (s.device == mlx::core::Device::cpu) {
// Signal through a GPU stream so the atomic is updated in GPU - updating // Signal through a GPU stream so the atomic is updated in GPU - updating
// the atomic in CPU sometimes does not get GPU notified. // the atomic in CPU sometimes does not get GPU notified.
@@ -168,14 +246,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
} }
} }
bool SharedEvent::is_signaled(uint64_t value) const { bool AtomicEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
return to_atomic(buf_)->load() >= value; return atomic()->load() >= value;
} }
uint64_t SharedEvent::value() const { uint64_t AtomicEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value"); nvtx3::scoped_range r("cu::AtomicEvent::value");
return to_atomic(buf_)->load(); return atomic()->load();
} }
} // namespace cu } // namespace cu
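
The event_wait/event_signal helpers above are a load/wait loop paired with store/notify_all on a cuda::atomic in managed memory. The same protocol host-side with std::atomic (C++20), as a sketch:

    #include <atomic>
    #include <cstdint>
    #include <thread>

    // cuda::atomic exposes the equivalent load/wait/store/notify_all calls.
    std::atomic<uint64_t> counter{0};

    void wait_for(uint64_t value) {
      uint64_t current;
      while ((current = counter.load()) < value) {
        counter.wait(current);  // block until the stored value changes
      }
    }

    void signal(uint64_t value) {
      counter.store(value);
      counter.notify_all();
    }

    int main() {
      std::thread waiter([] { wait_for(1); });
      signal(1);
      waiter.join();
    }
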
@@ -188,14 +266,14 @@ namespace {
struct EventImpl { struct EventImpl {
// CudaEvent is preferred when possible because it is fast, however we have // CudaEvent is preferred when possible because it is fast, however we have
// to fallback to SharedEvent in following cases: // to fallback to AtomicEvent in following cases:
// 1. the event is used to wait/signal a cpu stream; // 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified. // 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda; std::unique_ptr<cu::CopyableCudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared; std::unique_ptr<cu::AtomicEvent> atomic;
bool is_created() const { bool is_created() const {
return cuda || shared; return cuda || atomic;
} }
void ensure_created(Stream s, uint64_t signal_value) { void ensure_created(Stream s, uint64_t signal_value) {
@@ -203,10 +281,10 @@ struct EventImpl {
return; return;
} }
if (s.device == mlx::core::Device::cpu || signal_value > 1) { if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent"); nvtx3::mark("Using slow AtomicEvent");
shared = std::make_unique<cu::SharedEvent>(); atomic = std::make_unique<cu::AtomicEvent>();
} else { } else {
cuda = std::make_unique<cu::CudaEvent>(); cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
} }
} }
}; };
@@ -225,7 +303,7 @@ void Event::wait() {
assert(value() == 1); assert(value() == 1);
event->cuda->wait(); event->cuda->wait();
} else { } else {
event->shared->wait(value()); event->atomic->wait(value());
} }
} }
@@ -236,7 +314,7 @@ void Event::wait(Stream s) {
assert(value() == 1); assert(value() == 1);
event->cuda->wait(s); event->cuda->wait(s);
} else { } else {
event->shared->wait(s, value()); event->atomic->wait(s, value());
} }
} }
@@ -247,7 +325,7 @@ void Event::signal(Stream s) {
assert(value() == 1); assert(value() == 1);
event->cuda->record(s); event->cuda->record(s);
} else { } else {
event->shared->signal(s, value()); event->atomic->signal(s, value());
} }
} }
@@ -258,9 +336,9 @@ bool Event::is_signaled() const {
} }
if (event->cuda) { if (event->cuda) {
assert(value() == 1); assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed(); return event->cuda->is_signaled();
} else { } else {
return event->shared->is_signaled(value()); return event->atomic->is_signaled(value());
} }
} }


@@ -3,49 +3,60 @@
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <memory>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda/atomic> #include <cuda/atomic>
#include <memory>
namespace mlx::core::cu { namespace mlx::core::cu {
class CudaEventHandle; class Device;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait // Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream. // on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent { class CudaEvent {
public: public:
CudaEvent(); CudaEvent(Device& d, int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
CudaEvent& operator=(CudaEvent&&) = default;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
void wait(); void wait();
void wait(cudaStream_t stream); void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream); void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method // Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called. // returns true if record() has not been called.
bool completed() const; bool completed() const;
bool recorded() const { // Internal: make sure event pool is initialized.
return recorded_; static void init_pool();
}
private: private:
bool recorded_{false}; CudaEventHandle event_;
std::shared_ptr<CudaEventHandle> event_;
}; };
// Event that can synchronize between CPU and GPU. It is much slower than // Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible. // CudaEvent so the latter should always be preferred when possible.
class SharedEvent { class AtomicEvent {
public: public:
using Atomic = cuda::atomic<uint64_t>; using Atomic = cuda::atomic<uint64_t>;
SharedEvent(); AtomicEvent();
void wait(uint64_t value); void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value); void wait(cudaStream_t stream, uint64_t value);
@@ -57,7 +68,11 @@ class SharedEvent {
uint64_t value() const; uint64_t value() const;
private: private:
std::shared_ptr<mlx::core::allocator::Buffer> buf_; Atomic* atomic() const {
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
}
std::shared_ptr<allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu


@@ -7,7 +7,7 @@ namespace mlx::core {
struct FenceImpl { struct FenceImpl {
uint32_t count; uint32_t count;
cu::SharedEvent event; cu::AtomicEvent event;
}; };
Fence::Fence(Stream s) { Fence::Fence(Stream s) {


@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32 return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F; : CUBLAS_COMPUTE_32F;
case float64: case float64:
case complex64:
return CUBLAS_COMPUTE_64F; return CUBLAS_COMPUTE_64F;
case complex64:
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
: CUBLAS_COMPUTE_32F;
default: default:
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -85,10 +87,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count, int32_t batch_count,
int64_t batch_stride) { int64_t batch_stride) {
cublasLtMatrixLayout_t desc; cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) { if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, desc,
@@ -126,37 +128,47 @@ CublasGemm::CublasGemm(
N_(b_cols) { N_(b_cols) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto scale_type = dtype_to_cublas_type(dtype); scale_type_ = dtype_to_cublas_type(dtype);
if (dtype == bfloat16 || dtype == float16) { if (dtype == bfloat16 || dtype == float16) {
scale_type = CUDA_R_32F; scale_type_ = CUDA_R_32F;
} }
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), scale_type)); &matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE, CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode, &pointer_mode,
sizeof(int32_t))); sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA, CUBLASLT_MATMUL_DESC_TRANSA,
&op, &a_op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_, matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB, CUBLASLT_MATMUL_DESC_TRANSB,
&op, &b_op,
sizeof(cublasOperation_t))); sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout( a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout( b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
} }
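
The comment in this constructor compresses a small derivation; spelled out: an M x N row-major buffer and the N x M column-major view of the same memory are each other's transpose, so a column-major library can produce the row-major product by computing

\[
C = A B \iff C^{\top} = (A B)^{\top} = B^{\top} A^{\top},
\]

i.e. hand the library B as the left operand and A as the right one; the column-major \(C^{\top}\) it writes is bit-for-bit the row-major C. That is why a_desc_ is now built from B's dimensions, b_desc_ from A's, and the data pointers are swapped again at the cublasLtMatmul call ("a and b are swapped").
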
CublasGemm::CublasGemm( CublasGemm::CublasGemm(
@@ -191,7 +203,7 @@ CublasGemm::CublasGemm(
b_batch_stride) { b_batch_stride) {
auto type = dtype_to_cublas_type(dtype); auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout( c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
} }
CublasGemm::~CublasGemm() { CublasGemm::~CublasGemm() {
@@ -213,14 +225,30 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout( out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype), dtype_to_cublas_type(dtype),
rows,
cols, cols,
rows,
transposed, transposed,
ld, ld,
batch_count, batch_count,
batch_stride); batch_stride);
} }
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
encoder.set_input_array(bias);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr,
sizeof(bias_ptr)));
}
void CublasGemm::run( void CublasGemm::run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
@@ -228,11 +256,19 @@ void CublasGemm::run(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) { if (batch_count / batch_shape.back() > 1) {
run_batched( run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return; return;
} }
@@ -240,7 +276,13 @@ void CublasGemm::run(
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr); execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
} }
void CublasGemm::run( void CublasGemm::run(
@@ -313,6 +355,16 @@ void CublasGemm::execute(
} }
} }
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
if (scale_type_ == CUDA_C_32F) {
alpha_c = complex64_t{alpha, 0.0f};
beta_c = complex64_t{beta, 0.0f};
alpha_ptr = &alpha_c;
beta_ptr = &beta_c;
}
void* workspace_ptr = nullptr; void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) { if (heuristic_.workspaceSize > 0) {
// Ensure workspace is 256-byte aligned // Ensure workspace is 256-byte aligned
@@ -329,12 +381,12 @@ void CublasGemm::execute(
CHECK_CUBLAS_ERROR(cublasLtMatmul( CHECK_CUBLAS_ERROR(cublasLtMatmul(
handle_, handle_,
matmul_desc_, matmul_desc_,
&alpha, alpha_ptr,
a, b, // a and b are swapped
a_desc_, a_desc_,
b, a,
b_desc_, b_desc_,
&beta, beta_ptr,
c ? c : out, c ? c : out,
c ? c_desc_ : out_desc_, c ? c_desc_ : out_desc_,
out, out,


@@ -55,6 +55,8 @@ class CublasGemm {
int32_t batch_count, int32_t batch_count,
int64_t batch_stride); int64_t batch_stride);
void set_bias(cu::CommandEncoder& encoder, const array& bias);
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
array& out, array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides); const Strides& b_batch_strides,
float alpha = 1.0f);
void run( void run(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides); const Strides& b_batch_strides,
float alpha);
void run_batched( void run_batched(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
@@ -111,6 +115,7 @@ class CublasGemm {
uint64_t M_; uint64_t M_;
uint64_t N_; uint64_t N_;
cudaDataType_t scale_type_;
cublasLtMatmulPreference_t pref_{nullptr}; cublasLtMatmulPreference_t pref_{nullptr};
cublasLtHandle_t handle_{nullptr}; cublasLtHandle_t handle_{nullptr};
cublasLtMatmulDesc_t matmul_desc_{nullptr}; cublasLtMatmulDesc_t matmul_desc_{nullptr};


@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_, out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc, a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc, b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr); nullptr,
alpha);
a_it.step(); a_it.step();
b_it.step(); b_it.step();
} }


@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b, const array& b,
const Shape& batch_shape, const Shape& batch_shape,
const Strides& a_batch_strides, const Strides& a_batch_strides,
const Strides& b_batch_strides) { const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count); set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count); set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers), reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers), reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers), reinterpret_cast<void*>(b_pointers),
nullptr); nullptr,
alpha);
} }
void CublasGemm::run_batched( void CublasGemm::run_batched(

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8; static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
template <typename T, int rows_per_block, int n_per_thread> template <typename T, int rows_per_block, int n_per_thread>
__device__ void __device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) { gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
int row = g_idx.x * rows_per_block + t_idx.y; int row = g_idx.x * rows_per_block + t_idx.y;
if (row < rows) { if (row < rows) {
float sum = 0.0f; using Acc = typename GemvAccType<T>::type;
Acc sum = Acc(0);
for (int col = n_per_thread * warp.thread_rank(); col < cols; for (int col = n_per_thread * warp.thread_rank(); col < cols;
col += (WARP_SIZE * n_per_thread)) { col += (WARP_SIZE * n_per_thread)) {
auto local_mat = auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0); auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll #pragma unroll
for (int j = 0; j < n_per_thread; ++j) { for (int j = 0; j < n_per_thread; ++j) {
sum += sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
} }
} }
sum = cg::reduce(warp, sum, cg::plus<float>{}); sum = cg::reduce(warp, sum, cg::plus<Acc>{});
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
out[row] = static_cast<T>(sum); out[row] = static_cast<T>(sum);
} }
@@ -107,7 +138,7 @@ void gemv(
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) { dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dim3 block_dims{WARP_SIZE, rows_per_block}; dim3 block_dims{WARP_SIZE, rows_per_block};
const DataType* mat; const DataType* mat;
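
The `GemvAccType` trait added above picks the accumulator type per element type: half and bfloat16 widen to float, while double and complex keep their own type (the previous code always summed in `float`). A host-side analogue using only standard types, to show why the trait matters; in the CUDA kernel the `__half` and `__nv_bfloat16` specializations map to `float` the same way.

```cpp
#include <cstdio>

// Accumulator selection: each element type names the type its partial sums use.
template <typename T> struct AccType { using type = T; };
// In device code, __half and __nv_bfloat16 would specialize this to float.

template <typename T>
typename AccType<T>::type dot(const T* a, const T* b, int n) {
  using Acc = typename AccType<T>::type;
  Acc sum = Acc(0);
  for (int i = 0; i < n; ++i) {
    sum += static_cast<Acc>(a[i]) * static_cast<Acc>(b[i]);
  }
  return sum;
}

int main() {
  double a[2] = {1e15, 1.0}, b[2] = {1.0, 1.0};
  // A float accumulator cannot represent 1e15 + 1, so the +1 would be lost;
  // the double accumulator keeps it.
  std::printf("%.1f\n", dot(a, b, 2) - 1e15);  // prints 1.0
}
```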

View File

@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim()); args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_); args.append_ndim(slice_sizes_);
args.append(slice_size); args.append(slice_size);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end())); args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape()); args.append_ndim(out.shape());
args.append_ndim(out.strides()); args.append_ndim(out.strides());
args.append<int32_t>(out.ndim()); args.append<int32_t>(out.ndim());
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end())); args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim); append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(

View File

@@ -67,10 +67,12 @@ const std::string& cccl_dir() {
return path.string(); return path.string();
} }
// Finally check the environment variable. // Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR"); if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
path = env;
if (!path.empty() && std::filesystem::exists(path)) { if (!path.empty() && std::filesystem::exists(path)) {
return path.string(); return path.string();
} }
}
return std::string(); return std::string();
}(); }();
return dir; return dir;
@@ -97,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() {
return cache; return cache;
} }
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
bool read_cached_ptx( bool read_cached_ptx(
const std::filesystem::path& cache_dir, const std::filesystem::path& cache_dir,
@@ -107,7 +133,7 @@ bool read_cached_ptx(
return false; return false;
} }
auto ptx_path = cache_dir / (module_name + ".ptx"); auto ptx_path = get_ptx_path(cache_dir, module_name);
std::error_code error; std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error); auto ptx_size = std::filesystem::file_size(ptx_path, error);
if (error) { if (error) {
@@ -120,7 +146,7 @@ bool read_cached_ptx(
ptx.resize(ptx_size); ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size); ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
std::string line; std::string line;
while (std::getline(txt_file, line)) { while (std::getline(txt_file, line)) {
auto tab = line.find('\t'); auto tab = line.find('\t');
@@ -142,16 +168,26 @@ void write_cached_ptx(
return; return;
} }
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); auto ptx_path = get_ptx_path(cache_dir, module_name);
// Ensure that the directory exists
auto parent = ptx_path.parent_path();
if (parent != cache_dir) {
std::filesystem::create_directories(parent);
}
// Write the compiled code and mangled names
std::ofstream ptx_file(ptx_path, std::ios::binary);
if (!ptx.empty()) { if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size()); ptx_file.write(&ptx.front(), ptx.size());
} }
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary); std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
txt_file << name << "\t" << mangled << std::endl; txt_file << name << "\t" << mangled << std::endl;
} }
std::ofstream source_file(cache_dir / (module_name + ".cu")); // Write the generated code
std::ofstream source_file(ptx_path.replace_extension(".cu"));
source_file << source_code; source_file << source_code;
} }
@@ -295,7 +331,8 @@ void load_module(
const std::string& ptx, const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels, const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_, CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) { std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
kernels) {
// Load module. // Load module.
char jit_log[4089] = {}; char jit_log[4089] = {};
CUjit_option options[] = { CUjit_option options[] = {
@@ -312,7 +349,7 @@ void load_module(
for (const auto& [name, mangled] : ptx_kernels) { for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel; CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str())); CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels[name] = std::make_pair(kernel, false); kernels[name] = std::make_tuple(kernel, false, 0);
} }
} }
@@ -356,7 +393,7 @@ JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_)); CHECK_CUDA_ERROR(cuModuleUnload(module_));
} }
CUfunction JitModule::get_kernel( std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
const std::string& kernel_name, const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) { std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name); auto it = kernels_.find(kernel_name);
@@ -367,14 +404,22 @@ CUfunction JitModule::get_kernel(
// If it is the first time we run this kernel then configure it. Do it only // If it is the first time we run this kernel then configure it. Do it only
// once! // once!
if (!it->second.second) { auto kernel = std::get<0>(it->second);
if (!std::get<1>(it->second)) {
if (configure_kernel) { if (configure_kernel) {
configure_kernel(it->second.first); configure_kernel(kernel);
} }
it->second.second = true; std::get<1>(it->second) = true;
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
} }
return it->second.first; return {kernel, std::get<2>(it->second)};
}
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
} }
std::unordered_map<std::string, JitModule>& get_jit_module_cache() { std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
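
The new `get_ptx_path` above sidesteps filesystem limits on file-name length by splitting very long JIT module names into nested directories, and the `.txt` / `.cu` companions are derived from the same path via `replace_extension`, with `create_directories` making the parents on write. A standalone sketch of just the splitting logic (the 140/245 limits and `.ptx` suffix are taken from the hunk; the function name is illustrative):

```cpp
#include <filesystem>
#include <iostream>
#include <string>

namespace fs = std::filesystem;

// Split |module_name| into path components no longer than the platform's
// file-name limit, then append ".ptx" to the final component.
fs::path ptx_path_for(const fs::path& cache_dir, const std::string& module_name) {
#ifdef _WIN32
  constexpr size_t max_len = 140;
#else
  constexpr size_t max_len = 245;
#endif
  if (module_name.size() <= max_len) {
    return cache_dir / (module_name + ".ptx");
  }
  fs::path p = cache_dir;
  size_t offset = 0;
  while (module_name.size() - offset > max_len) {
    p /= module_name.substr(offset, max_len);
    offset += max_len;
  }
  return p / (module_name.substr(offset) + ".ptx");
}

int main() {
  // A 300-character module name becomes cache/<245 chars>/<55 chars>.ptx
  std::cout << ptx_path_for("cache", std::string(300, 'k')) << "\n";
}
```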

View File

@@ -46,6 +46,11 @@ struct KernelArgs {
append_ptr(std::get<SmallVector<T>>(storage_.back()).data()); append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
} }
template <typename T>
void append(const std::vector<T>& vec) {
append(SmallVector<T>(vec.begin(), vec.end()));
}
// Make sure the arg is copied to an array with size of NDIM. // Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T> template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) { void append_ndim(SmallVector<T> vec) {
@@ -94,10 +99,13 @@ class JitModule {
CUfunction get_kernel( CUfunction get_kernel(
const std::string& kernel_name, const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr); std::function<void(CUfunction)> configure_kernel = nullptr);
std::pair<CUfunction, uint> get_kernel_and_dims(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private: private:
CUmodule module_{nullptr}; CUmodule module_{nullptr};
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_; std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
}; };
std::unordered_map<std::string, JitModule>& get_jit_module_cache(); std::unordered_map<std::string, JitModule>& get_jit_module_cache();
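
The `append(const std::vector<T>&)` overload above lets callers such as Gather/Scatter pass `axes_` directly; it copies into the existing `SmallVector` path, which owns the storage and records the data pointer for the launch. A reduced sketch of that pattern, assuming driver-API-style launches that take the address of every argument slot (names here are illustrative, and `std::vector` stands in for `SmallVector`):

```cpp
#include <cstdint>
#include <deque>
#include <variant>
#include <vector>

// Minimal argument packer: containers are moved into |storage_| so their
// buffers outlive the launch, and |args_| collects the address of each
// argument slot, which is what cuLaunchKernel-style APIs expect.
class ArgPack {
 public:
  template <typename T>
  void append(std::vector<T> vec) {
    storage_.emplace_back(std::move(vec));
    append_ptr(std::get<std::vector<T>>(storage_.back()).data());
  }
  void append(int32_t v) {
    scalars_.push_back(v);
    args_.push_back(&scalars_.back());
  }
  void append_ptr(const void* p) {
    pointers_.push_back(const_cast<void*>(p));
    args_.push_back(&pointers_.back());
  }
  void** args() { return args_.data(); }

 private:
  std::vector<std::variant<std::vector<int32_t>, std::vector<int64_t>>> storage_;
  std::deque<void*> pointers_;  // deque: element addresses stay stable as it grows
  std::deque<int32_t> scalars_;
  std::vector<void*> args_;
};
```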

View File

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
const Shape& shape, const Shape& shape,
const Strides& strides, const Strides& strides,
bool large, bool large,
int work_per_thread) { int work_per_thread /* = 1 */,
uint max_block_dim /* = 1024 */) {
size_t nthreads = cuda::ceil_div(size, work_per_thread); size_t nthreads = cuda::ceil_div(size, work_per_thread);
uint block_dim = 1024; uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
if (block_dim > nthreads) {
block_dim = nthreads;
}
dim3 num_blocks; dim3 num_blocks;
if (large) { if (large) {
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread); num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file includes host-only utilies for writing CUDA kernels, the difference // This file includes host-only utilities for writing CUDA kernels, the
// from backend/cuda/device/utils.cuh is that the latter file only include // difference from backend/cuda/device/utils.cuh is that the latter file only
// device-only code. // include device-only code.
#pragma once #pragma once
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
size_t divisor); size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2); std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims that maximize occupancy for |kernel|, // Get the num_blocks and block_dims assuming each thread handles
// assuming each thread handles |work_per_thread| elements of |arr|. // |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args( std::tuple<dim3, uint> get_launch_args(
size_t size, size_t size,
const Shape& shape, const Shape& shape,
const Strides& strides, const Strides& strides,
bool large, bool large,
int work_per_thread = 1); int work_per_thread = 1,
uint max_block_dim = 1024);
inline std::tuple<dim3, uint> inline std::tuple<dim3, uint> get_launch_args(
get_launch_args(const array& arr, bool large, int work_per_thread = 1) { const array& arr,
bool large,
int work_per_thread = 1,
uint max_block_dim = 1024) {
return get_launch_args( return get_launch_args(
arr.size(), arr.shape(), arr.strides(), large, work_per_thread); arr.size(),
arr.shape(),
arr.strides(),
large,
work_per_thread,
max_block_dim);
} }
} // namespace mlx::core } // namespace mlx::core
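
`get_launch_args` now accepts a `max_block_dim` cap (for example, the occupancy-derived limit that `get_kernel_and_dims` caches via `max_occupancy_block_dim`) instead of always assuming 1024 threads. The arithmetic is a ceil-divide plus a clamp; a host sketch assuming a 1-D grid for simplicity:

```cpp
#include <cstdio>
#include <utility>

constexpr size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

// Returns {num_blocks, block_dim} for |size| elements, where each thread
// handles |work_per_thread| of them and a block never exceeds |max_block_dim|.
std::pair<size_t, unsigned> launch_args_1d(
    size_t size, int work_per_thread = 1, unsigned max_block_dim = 1024) {
  size_t nthreads = ceil_div(size, work_per_thread);
  unsigned block_dim =
      max_block_dim < nthreads ? max_block_dim : static_cast<unsigned>(nthreads);
  return {ceil_div(nthreads, block_dim), block_dim};
}

int main() {
  auto [blocks, threads] = launch_args_1d(1 << 20, /*work_per_thread=*/4, /*max_block_dim=*/256);
  std::printf("%zu blocks x %u threads\n", blocks, threads);  // 1024 blocks x 256 threads
}
```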

View File

@@ -2,11 +2,15 @@
#pragma once #pragma once
#include "mlx/utils.h"
#include <cstring> #include <cstring>
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <fmt/format.h>
namespace mlx::core { namespace mlx::core {
template < template <
@@ -27,6 +31,14 @@ class LRUCache {
} }
} }
// Initialize with capacity read from |env_name|.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
env_name_ = env_name;
}
}
size_t size() const { size_t size() const {
return map_.size(); return map_.size();
} }
@@ -76,6 +88,14 @@ class LRUCache {
return {it->second, false}; return {it->second, false};
} }
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value)); vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin(); map_[key] = vlist_.begin();
@@ -106,6 +126,9 @@ class LRUCache {
} }
} }
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_; list_type vlist_;
map_type map_; map_type map_;
size_t capacity_; size_t capacity_;
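
The thrashing guard above counts cache misses and raises once they exceed twice the configured capacity, naming the environment variable that controls the size. A self-contained sketch of an LRU map with the same check (simplified: no env-var plumbing, string keys, int values):

```cpp
#include <list>
#include <stdexcept>
#include <string>
#include <unordered_map>

class LRU {
 public:
  LRU(size_t capacity, std::string env_hint)
      : capacity_(capacity), env_hint_(std::move(env_hint)) {}

  // Returns the cached value for |key|, inserting |fallback| on a miss.
  int get_or_insert(const std::string& key, int fallback) {
    if (auto it = map_.find(key); it != map_.end()) {
      items_.splice(items_.begin(), items_, it->second);  // mark as most recent
      return it->second->second;
    }
    if (++misses_ > 2 * capacity_) {
      throw std::runtime_error(
          "Cache thrashing: raise " + env_hint_ + " above " +
          std::to_string(capacity_));
    }
    items_.emplace_front(key, fallback);
    map_[key] = items_.begin();
    if (map_.size() > capacity_) {  // evict the least recently used entry
      map_.erase(items_.back().first);
      items_.pop_back();
    }
    return fallback;
  }

 private:
  size_t capacity_;
  size_t misses_{0};
  std::string env_hint_;
  std::list<std::pair<std::string, int>> items_;
  std::unordered_map<std::string,
                     std::list<std::pair<std::string, int>>::iterator> map_;
};
```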

View File

@@ -11,6 +11,7 @@
#include <numeric> #include <numeric>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
std::tuple<bool, int64_t, array> std::tuple<bool, int64_t, array>
@@ -28,6 +29,80 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} }
} }
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
const std::optional<array>& bias = std::nullopt,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
if (a.dtype() == complex64) {
throw std::runtime_error(
"[gemm_and_bias] complex64 bias epilogue isnt supported in cublasLtMatmul.");
}
gemm.set_bias(encoder, *bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace } // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +123,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2); int M = a_pre.shape(-2);
int N = b_pre.shape(-1); int N = b_pre.shape(-1);
int K = a_pre.shape(-1); int K = a_pre.shape(-1);
@@ -60,58 +132,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
///////////////////////////////////////////////////////////////////////////// gemm_and_bias(
// Check and collapse batch dimensions encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +158,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c,
alpha_);
return;
}
int64_t ldc; int64_t ldc;
{ {
auto stx = c.strides()[c.ndim() - 2]; auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +222,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt // Invoke cublasLt with AddMM settings
CublasGemm gemm( CublasGemm gemm(
cu::device(s.device), cu::device(s.device),

View File

@@ -24,8 +24,6 @@ namespace mlx::core {
} }
NO_GPU(BlockMaskedMM) NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT) NO_GPU(FFT)
NO_GPU(GatherMM) NO_GPU(GatherMM)
NO_GPU(GatherQMM) NO_GPU(GatherQMM)

View File

@@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
} // namespace } // namespace
void fast::AffineQuantize::eval_gpu( void fast::Quantize::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu"); nvtx3::scoped_range r("Quantize::eval_gpu");
auto& s = stream(); auto& s = stream();
auto& d = cu::device(s.device); auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s); auto& enc = d.get_command_encoder(s);

View File

@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
} }
} }
template <typename T, typename U, typename Op, int N_READS = 4>
__global__ void col_reduce_small(
const T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
size_t total) {
Op op;
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
const auto idx = grid.thread_rank() * N_READS;
const auto before_axis = idx / args.reduction_stride;
const auto after_axis = idx % args.reduction_stride;
const auto offset =
before_axis * args.reduction_stride * args.reduction_size + after_axis;
if (idx >= total) {
return;
}
in += offset;
out += idx;
AlignedVector<U, N_READS> accumulator;
for (int i = 0; i < N_READS; i++) {
accumulator[i] = ReduceInit<Op, T>::value();
}
for (int i = 0; i < args.reduction_size; i++) {
auto values = load_vector<N_READS>(in, 0);
for (int j = 0; j < N_READS; j++) {
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
}
in += args.reduction_stride;
}
store_vector(out, 0, accumulator);
}
} // namespace cu } // namespace cu
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
@@ -206,7 +247,7 @@ void col_reduce_looped(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::ColReduceArgs args) { const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -230,12 +271,55 @@ void col_reduce_looped(
auto kernel = auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>; cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, blocks, 0, indata, out.data<U>(), args); kernel,
grid,
blocks,
0,
indata,
out.data<U>(),
static_cast<cu::ColReduceArgs>(args));
}); });
}); });
}); });
} }
void col_reduce_small(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 16 / sizeof(T);
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
block,
0,
in.data<T>(),
out.data<U>(),
static_cast<cu::ColReduceArgs>(args),
out.size());
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -258,6 +342,13 @@ void col_reduce(
// Make the args struct to help route to the best kernel // Make the args struct to help route to the best kernel
cu::ColReduceArgs args(in, plan, axes); cu::ColReduceArgs args(in, plan, axes);
// Small col reduce with a single or contiguous reduction axis
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
args.reduction_stride % (16 / in.itemsize()) == 0) {
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }
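
Routing to the new `col_reduce_small` kernel happens only when the reduction is short (at most 32 elements), there are no extra non-column reduction axes, and the reduction stride keeps the 16-byte vectorized loads aligned. The predicate, extracted as a host-side sketch:

```cpp
#include <cstdint>

// True when the small column-reduce kernel applies: one contiguous reduction
// axis, at most 32 elements per output, and a stride that keeps N_READS-wide
// (16-byte) loads aligned.
bool use_small_col_reduce(
    int non_col_reductions,
    int reduction_size,
    int64_t reduction_stride,
    int itemsize) {
  int n_reads = 16 / itemsize;
  return non_col_reductions == 1 && reduction_size <= 32 &&
      reduction_stride % n_reads == 0;
}
```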

View File

@@ -7,8 +7,6 @@
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
namespace mlx::core { namespace mlx::core {
@@ -83,7 +81,8 @@ struct RowReduceArgs {
}; };
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1> template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) { __global__ void
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
const U init = cu::ReduceInit<ReduceOp, T>::value(); const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op; ReduceOp op;
T vals[M][N]; AlignedVector<T, N> vals[M];
U accs[M]; AlignedVector<U, M> accs;
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
accs[i] = init; accs[i] = init;
} }
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M)); min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
const size_t full_blocks = size / (block.size() * N); const size_t full_blocks = size / (block.size() * N);
const size_t final_offset = full_blocks * (block.size() * N); const size_t final_offset = full_blocks * (block.size() * N);
in += start_row * size; in += start_row * size + block.thread_rank() * N;
out += start_row; out += start_row;
if (size % N == 0) {
for (size_t r = 0; r < full_blocks; r++) { for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) { for (int k = 0; k < M; k++) {
cub::LoadDirectBlockedVectorized<T, N>( vals[k] = load_vector<N>(in + k * size, 0);
block.thread_rank(), }
in + k * size + r * (block.size() * N), for (int k = 0; k < M; k++) {
vals[k]);
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
} }
}
} else { in += block.size() * N;
for (size_t r = 0; r < full_blocks; r++) {
for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked(
block.thread_rank(),
in + k * size + r * (block.size() * N),
vals[k]);
for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
}
}
}
} }
if (final_offset < size) { if (final_offset < size) {
for (int k = 0; k < M; k++) { for (int k = 0; k < M; k++) {
cub::LoadDirectBlocked( for (int i = 0; i < N; i++) {
block.thread_rank(), vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
in + k * size + final_offset, ? in[k * size + i]
vals[k], : cast_to<T>(init);
size, }
cast_to<T>(init)); }
for (int k = 0; k < M; k++) {
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {
accs[k] = op(accs[k], cast_to<U>(vals[k][j])); accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
} }
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
} }
__shared__ U shared_accumulators[32 * M]; __shared__ U shared_accumulators[32 * M];
block_reduce(block, warp, accs, shared_accumulators, op, init); block_reduce(block, warp, accs.val, shared_accumulators, op, init);
if (block.thread_rank() == 0) { if (block.thread_rank() == 0) {
if (grid.block_rank() * M + M <= n_rows) { if (grid.block_rank() * M + M <= n_rows) {
for (int i = 0; i < M; i++) { store_vector(out, 0, accs);
out[i] = accs[i];
}
} else { } else {
short offset = grid.block_rank() * M + M - n_rows; short offset = grid.block_rank() * M + M - n_rows;
for (int i = offset; i < M; i++) { for (int i = offset; i < M; i++) {
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
} }
} }
template < template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
typename T,
typename U,
typename Op,
int NDIM,
int BLOCK_DIM,
int N_READS = 4>
__global__ void row_reduce_looped( __global__ void row_reduce_looped(
T* in, const T* in,
U* out, U* out,
size_t out_size,
const __grid_constant__ RowReduceArgs args) { const __grid_constant__ RowReduceArgs args) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
@@ -185,37 +163,61 @@ __global__ void row_reduce_looped(
U init = ReduceInit<Op, T>::value(); U init = ReduceInit<Op, T>::value();
total[0] = init; total[0] = init;
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim); LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS); const size_t full_blocks = args.row_size / (block.size() * N_READS);
size_t final_offset = full_blocks * BLOCK_DIM * N_READS; const size_t final_offset = full_blocks * (block.size() * N_READS);
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
in += block.thread_rank() * N_READS;
// Unaligned reduce
if (final_offset < args.row_size) {
bool mask[N_READS];
for (int i = 0; i < N_READS; i++) {
mask[i] =
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
}
for (size_t n = 0; n < args.non_row_reductions; n++) { for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) { for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
{
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized<T, N_READS>( for (int i = 0; i < N_READS; i++) {
block.thread_rank(), vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
in + loop.location() + r * BLOCK_DIM * N_READS, }
vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i])); total[0] = op(total[0], cast_to<U>(vals[i]));
} }
} }
if (final_offset < args.row_size) {
T vals[N_READS];
cub::LoadDirectBlocked(
block.thread_rank(),
in + loop.location() + final_offset,
vals,
args.row_size - final_offset,
cast_to<T>(init));
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
}
// TODO: Maybe block.sync() here?
loop.next(args.reduce_shape.data(), args.reduce_strides.data()); loop.next(args.reduce_shape.data(), args.reduce_strides.data());
} }
}
// Aligned case
else {
for (size_t n = 0; n < args.non_row_reductions; n++) {
const T* inlocal = in + loop.location();
for (size_t r = 0; r < full_blocks; r++) {
auto vals = load_vector<N_READS>(inlocal, 0);
for (int i = 0; i < N_READS; i++) {
total[0] = op(total[0], cast_to<U>(vals[i]));
}
inlocal += block.size() * N_READS;
}
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
}
}
__shared__ U shared_accumulators[32]; __shared__ U shared_accumulators[32];
block_reduce(block, warp, total, shared_accumulators, op, init); block_reduce(block, warp, total, shared_accumulators, op, init);
@@ -234,8 +236,6 @@ void row_reduce_simple(
Reduce::ReduceType reduce_type, Reduce::ReduceType reduce_type,
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan) { const ReductionPlan& plan) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to avoid elem_to_loc in the // Allocate data for the output using in's layout to avoid elem_to_loc in the
// kernel. // kernel.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -250,14 +250,15 @@ void row_reduce_simple(
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) constexpr int N_READS = 16 / sizeof(T);
T* indata = const_cast<T*>(in.data<T>());
// Calculate the grid and block dims // Calculate the grid and block dims
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS; size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
int threads = std::min(1024UL, reductions); int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
@@ -267,6 +268,7 @@ void row_reduce_simple(
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>; kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
} }
T* indata = const_cast<T*>(in.data<T>());
int size = plan.shape.back(); int size = plan.shape.back();
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size); kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
@@ -282,8 +284,6 @@ void row_reduce_looped(
const std::vector<int>& axes, const std::vector<int>& axes,
const ReductionPlan& plan, const ReductionPlan& plan,
cu::RowReduceArgs args) { cu::RowReduceArgs args) {
constexpr int N_READS = 8;
// Allocate data for the output using in's layout to access them as // Allocate data for the output using in's layout to access them as
// contiguously as possible. // contiguously as possible.
allocate_same_layout(out, in, axes); allocate_same_layout(out, in, axes);
@@ -295,34 +295,27 @@ void row_reduce_looped(
using OP = MLX_GET_TYPE(reduce_type_tag); using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>()); constexpr int N_READS = 16 / sizeof(T);
// Calculate the grid and block dims // Calculate the grid and block dims
args.sort_access_pattern(in, axes); args.sort_access_pattern(in, axes);
dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
size_t reductions = (args.row_size + N_READS - 1) / N_READS; size_t reductions = (args.row_size + N_READS - 1) / N_READS;
int threads = std::min(1024UL, reductions); int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; warps /= 4;
warps = std::max(std::min(warps, 32), 1);
int threads = warps * WARP_SIZE;
dim3 block(threads, 1, 1); dim3 block(threads, 1, 1);
// Pick the kernel // Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>; auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
dispatch_block_dim(threads, [&](auto threads_constant) { kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
kernel = cu::row_reduce_looped<
T,
U,
OP,
reduce_ndim.value,
threads_constant.value,
N_READS>;
block.x = threads_constant.value;
});
}); });
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args); kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
}); });
}); });
} }
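
The refactored row reduce drops the `cub::LoadDirectBlocked*` paths (and the templated `BLOCK_DIM`) in favor of `load_vector` plus an explicitly masked tail, and sizes the block from the row length at runtime: each thread reads 16 bytes per pass, and the warp count is a quarter of the warps needed to cover one row, clamped to [1, 32]. The heuristic as arithmetic (a sketch with `WARP_SIZE = 32`):

```cpp
#include <algorithm>
#include <cstdio>

constexpr int WARP_SIZE = 32;

// Block-size heuristic: 16 bytes read per thread per pass, roughly four passes
// per thread over the row, never fewer than 1 warp or more than 32 warps.
int row_reduce_threads(size_t row_size, int itemsize) {
  int n_reads = 16 / itemsize;
  size_t reductions = (row_size + n_reads - 1) / n_reads;
  int warps = static_cast<int>((reductions + WARP_SIZE - 1) / WARP_SIZE);
  warps /= 4;
  warps = std::max(std::min(warps, 32), 1);
  return warps * WARP_SIZE;
}

int main() {
  // 4096 floats -> 1024 vector reads -> 32 warps / 4 -> 8 warps -> 256 threads
  std::printf("%d\n", row_reduce_threads(4096, sizeof(float)));
}
```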

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl( __device__ void rope_impl(
const T* in, const T* in,
T* out, T* out,
int offset, const int* offset,
float inv_freq, float inv_freq,
float scale, float scale,
const cuda::std::array<int64_t, 3> strides, const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides, const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 pos, uint3 pos,
uint3 dims) { uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset); auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta // Compute costheta, sintheta
float theta = L * inv_freq; float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2; size_t out_index_1, out_index_2;
if (traditional) { if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1; out_index_2 = out_index_1 + 1;
in_index_1 = in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2]; in_index_2 = in_index_1 + strides[2];
} else { } else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0]; mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2]; out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 = in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2]; in_index_2 = in_index_1 + dims.x * strides[2];
} }
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims) { uint3 dims) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x, blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base, float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides, const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides, const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch, int64_t offset_stride,
int n_head,
uint3 dims, uint3 dims,
int64_t freq_stride) { int64_t freq_stride) {
uint3 pos = make_uint3( uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>( rope_impl<T, traditional, forward>(
in, in,
out, out,
*offset, offset,
inv_freq, inv_freq,
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
dims); dims);
} }
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1]; auto& offset = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides; cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides; cuda::std::array<int64_t, 3> out_strides;
bool donated = false; bool donated = false;
int ndim = in.ndim(); int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--; dispatch_ndim--;
} }
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less that the whole vector so copy to output and then // We apply rope to less that the whole vector so copy to output and then
// apply in-place. // apply in-place.
if (dims_ < in.shape(-1)) { if (dims_ < D) {
donated = true; donated = true;
auto ctype = auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1]; out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below // Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1)); bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3; bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) { if (single && !with_freqs) {
auto kernel = auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>; cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) { } else if (single) {
auto kernel = auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>; cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) { } else if (with_freqs) {
auto kernel = auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>; cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims, dims,
inputs[2].strides(0)); inputs[2].strides(0));
} else { } else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>; auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims = int n_per_thread = 4;
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
dims.z = (dims.z + 3) / 4; uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
grid, grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_), std::log2(base_),
strides, strides,
out_strides, out_strides,
in.size() / mat_size, offset_stride,
N,
dims); dims);
} }
}); });
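
RoPE now takes the position `offset` as a per-batch array read through `offset_stride` (instead of a single scalar), and the z grid dimension packs heads in groups of `N` with the head count rounded up to a multiple of `N` so one thread never straddles two batches; the per-element offset is then `offset[batch_idx * offset_stride]`. The index math, isolated as a host sketch:

```cpp
#include <cstdio>

// Recover (batch, head) from the packed z coordinate used by the rope kernels,
// where each thread handles N consecutive heads and n_head is rounded up to a
// multiple of N.
struct RopeIdx {
  int batch_idx;
  int head_idx;
};

RopeIdx rope_indices(int pos_z, int n_head, int N) {
  int n_head_up = N * ((n_head + N - 1) / N);
  RopeIdx r;
  r.head_idx = (pos_z * N) % n_head_up;
  r.batch_idx = (pos_z * N) / n_head_up;
  return r;
}

int main() {
  // 10 heads, N = 4 -> heads padded to 12 per batch; pos_z = 4 lands in batch 1.
  auto r = rope_indices(/*pos_z=*/4, /*n_head=*/10, /*N=*/4);
  std::printf("batch %d head %d\n", r.batch_idx, r.head_idx);  // batch 1 head 4
}
```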

View File

@@ -4,7 +4,6 @@
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
@@ -46,6 +45,7 @@ __global__ void kernel_sdpav_1pass(
const T* K, const T* K,
const T* V, const T* V,
T* O, T* O,
const T* sinks,
__grid_constant__ const AttnParams params) { __grid_constant__ const AttnParams params) {
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BD = 32; constexpr int BD = 32;
@@ -65,7 +65,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN]; __shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN]; __shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f; const U scale_log2 = params.scale * M_LOG2E;
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block); auto warp = cg::tiled_partition<32>(block);
@@ -108,8 +108,12 @@ __global__ void kernel_sdpav_1pass(
o[i] = 0.f; o[i] = 0.f;
} }
U max_score = -INFINITY; U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f; U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key // For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) { for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -167,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
U factor = exp2f(max_score - new_max); U factor = exp2f(max_score - new_max);
sum_exp_score = sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>()); cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score); sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs // Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL PRAGMA_LOOP_UNROLL
@@ -193,6 +197,7 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q, const T* Q,
const T* K, const T* K,
const T* V, const T* V,
const T* sinks,
float* partials, float* partials,
float* sums, float* sums,
float* maxs, float* maxs,
@@ -268,8 +273,12 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f; o[i] = 0.f;
} }
U max_score = -1e9; U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f; U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key // For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) { for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -410,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
U new_max = cg::reduce(warp, max_score, cg::greater<U>()); U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max); U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>()); U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score); sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
@@ -463,10 +472,14 @@ void sdpa_vector_1pass_fallback(
const array& v, const array& v,
const float scale, const float scale,
array& o, array& o,
bool do_causal_ = false) { bool do_causal,
const std::optional<array>& sinks) {
encoder.set_input_array(q); encoder.set_input_array(q);
encoder.set_input_array(k); encoder.set_input_array(k);
encoder.set_input_array(v); encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o); encoder.set_output_array(o);
cu::AttnParams params{ cu::AttnParams params{
@@ -489,7 +502,7 @@ void sdpa_vector_1pass_fallback(
dim3 block_dim(1024, 1, 1); dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) { dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) { dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) { dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -504,6 +517,7 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(), k.data<DataType>(),
v.data<DataType>(), v.data<DataType>(),
o.data<DataType>(), o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params); params);
}); });
}); });
@@ -518,7 +532,8 @@ void sdpa_vector_2pass_fallback(
const array& v, const array& v,
const float scale, const float scale,
array& o, array& o,
bool do_causal_ = false) { bool do_causal,
const std::optional<array>& sinks) {
cu::AttnParams params{ cu::AttnParams params{
/* int B = */ q.shape(0), /* int B = */ q.shape(0),
/* int H = */ q.shape(1), /* int H = */ q.shape(1),
@@ -559,7 +574,7 @@ void sdpa_vector_2pass_fallback(
encoder.add_temporary(maxs); encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) { dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) { dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) { dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -570,6 +585,10 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q); encoder.set_input_array(q);
encoder.set_input_array(k); encoder.set_input_array(k);
encoder.set_input_array(v); encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.set_output_array(sums); encoder.set_output_array(sums);
encoder.set_output_array(maxs); encoder.set_output_array(maxs);
@@ -585,6 +604,7 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(), q.data<DataType>(),
k.data<DataType>(), k.data<DataType>(),
v.data<DataType>(), v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(), intermediate.data<float>(),
sums.data<float>(), sums.data<float>(),
maxs.data<float>(), maxs.data<float>(),
@@ -627,15 +647,16 @@ void sdpa_vector_fallback(
const array& v, const array& v,
const float scale, const float scale,
array& o, array& o,
bool do_causal_ = false) { bool do_causal,
const std::optional<array>& sinks) {
int kL = k.shape(2); int kL = k.shape(2);
if (kL > 1024) { if (kL > 1024) {
return sdpa_vector_2pass_fallback( return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_); s, encoder, q, k, v, scale, o, do_causal, sinks);
} else { } else {
return sdpa_vector_1pass_fallback( return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_); s, encoder, q, k, v, scale, o, do_causal, sinks);
} }
} }
@@ -691,7 +712,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as // Define some copy functions to ensure the layout of the inputs is as
// expected. // expected.
copies.reserve(3); copies.reserve(inputs.size());
auto copy_unless = [&copies, &s]( auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& { auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) { if (!predicate(arr)) {
@@ -703,6 +724,16 @@ void ScaledDotProductAttention::eval_gpu(
} }
}; };
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) < 4) { if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) { auto q_copy_unless = [](const array& arr) {
@@ -740,10 +771,6 @@ void ScaledDotProductAttention::eval_gpu(
const auto& k = copy_unless(kv_copy_unless, k_pre); const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre); const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible // Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q); o.copy_shared_buffer(q);
@@ -752,22 +779,26 @@ void ScaledDotProductAttention::eval_gpu(
int64_t str_oH = o.shape(3); int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH; int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL; int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{ array::Flags flags{
/* bool contiguous = */ 1, /* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1, /* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0, /* bool col_contiguous = */ o.size() == o.shape(3),
}; };
o.set_data( o.set_data(
allocator::malloc(o.nbytes()), allocator::malloc(o.nbytes()),
data_size, o.size(),
{str_oB, str_oH, str_oL, str_oD}, {str_oB, str_oH, str_oL, str_oD},
flags); flags);
} }
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_); for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
} }
// Full attention mode should never reach here // Full attention mode should never reach here
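
Attention sinks are folded into the running log-sum-exp before any key is visited: one warp (and only the first block in the two-pass kernel) seeds the running maximum with `log2(e) * sink` and the running sum with 1, so the sink acts like one extra logit that contributes no value; the `sum == 0 ? 0 : __frcp_rn(sum)` guard then avoids a division by zero when a row has no contributions. A scalar host sketch of that base-2 online-softmax update (illustrative, not the kernel):

```cpp
#include <cmath>
#include <vector>

constexpr float kLog2E = 1.44269504089f;  // log2(e)

// Online softmax denominator in base 2, optionally seeded with a sink logit.
// |scores_log2| are assumed to already be multiplied by scale * log2(e).
float softmax_denominator(const std::vector<float>& scores_log2, const float* sink) {
  float max_score = -INFINITY;
  float sum_exp = 0.f;
  if (sink) {
    max_score = kLog2E * (*sink);
    sum_exp = 1.f;  // exp2(sink_log2 - max_score) == 1 for the sink itself
  }
  for (float s : scores_log2) {
    float new_max = std::max(max_score, s);
    sum_exp = sum_exp * std::exp2(max_score - new_max) + std::exp2(s - new_max);
    max_score = new_max;
  }
  return sum_exp;  // weight of score i is exp2(score_i - max_score) / sum_exp
}
```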

View File

@@ -1,8 +1,11 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/slicing.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h" #include "mlx/backend/gpu/slicing.h"
#include "mlx/dtype_utils.h"
#include <numeric> #include <numeric>
@@ -27,8 +30,7 @@ void concatenate_gpu(
flags.row_contiguous = false; flags.row_contiguous = false;
flags.col_contiguous = false; flags.col_contiguous = false;
flags.contiguous = false; flags.contiguous = false;
// TODO: Handle concurrent outputs: auto concurrent = cu::get_command_encoder(s).concurrent_context();
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i]; size_t data_offset = strides[axis] * sizes[i];
@@ -38,4 +40,71 @@ void concatenate_gpu(
} }
} }
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s) {
Dtype dtype = indices.dtype();
int nidx = axes.size();
std::string module_name =
fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
std::string kernel_name = fmt::format(
"mlx::core::cu::compute_dynamic_offset<{}, {}>",
dtype_to_cuda_type(dtype),
nidx);
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::string source = R"(
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
template <typename T, int NIDX>
__global__ void compute_dynamic_offset(
const T* indices,
int64_t* offset,
const __grid_constant__ Strides strides,
const __grid_constant__ cuda::std::array<int, NIDX> axes) {
int64_t acc = 0;
#pragma unroll
for (int i = 0; i < NIDX; ++i) {
acc += indices[i] * strides[axes[i]];
}
*offset = acc;
}
} // namespace mlx::core::cu
)";
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
cu::KernelArgs args;
args.append(indices);
args.append(offset);
args.append_ndim(strides);
args.append(axes);
auto kernel = mod.get_kernel(kernel_name);
encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
return offset;
}
} // namespace mlx::core } // namespace mlx::core
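The new compute_dynamic_offset builds a one-off JIT kernel that reduces the start indices and the source strides into a single int64 element offset on the device, so dynamic slicing never has to synchronize with the host. The arithmetic it performs is the usual index-to-offset dot product; a host-side reference (illustrative names, not the MLX API):

// offset = sum_i indices[i] * strides[axes[i]], evaluated on the host.
#include <cstdint>
#include <vector>

int64_t dynamic_offset_reference(
    const std::vector<int64_t>& indices,  // one start index per sliced axis
    const std::vector<int64_t>& strides,  // element strides of the source array
    const std::vector<int>& axes) {       // axis that each index applies to
  int64_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += indices[i] * strides[axes[i]];
  }
  return offset;
}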


@@ -9,7 +9,7 @@
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh> #include <cub/device/device_segmented_radix_sort.cuh>
#include <cassert> #include <cassert>
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.add_temporary(discard); encoder.add_temporary(discard);
size_t size; size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, nullptr,
size, size,
in.data<Type>(), in.data<Type>(),
@@ -90,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -104,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::device_pointer_cast(indices.data<uint32_t>()), thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)}); ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
temp.data<void>(), temp.data<void>(),
size, size,
in.data<Type>(), in.data<Type>(),
@@ -115,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
} else { } else {
size_t size; size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr, nullptr,
size, size,
in.data<Type>(), in.data<Type>(),
@@ -127,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8); array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -134,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
// Start capturing after allocations // Start capturing after allocations
auto capture = encoder.capture_context(); auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys( CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
temp.data<void>(), temp.data<void>(),
size, size,
in.data<Type>(), in.data<Type>(),
@@ -143,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort, in.data_size() / nsort,
offsets, offsets,
offsets + 1, offsets + 1,
0,
sizeof(Type) * 8,
stream)); stream));
} }
} else { } else {
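The sort path above switches from cub::DeviceSegmentedSort to cub::DeviceSegmentedRadixSort, which takes explicit begin_bit/end_bit arguments (0 and sizeof(Type) * 8 for a full-width key sort). A minimal, self-contained sketch of the same two-phase CUB pattern, where the first call only sizes the temporary buffer and the second performs the sort (error checking omitted; buffer handling is illustrative):

#include <cub/device/device_segmented_radix_sort.cuh>
#include <cuda_runtime.h>

// Sort each segment [offsets[i], offsets[i + 1]) of keys_in into keys_out,
// ascending, using all key bits.
void segmented_radix_sort_keys(
    const float* keys_in,
    float* keys_out,
    int num_items,
    int num_segments,
    const int* offsets,  // device array of num_segments + 1 offsets
    cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: null temp storage, only computes temp_bytes.
  cub::DeviceSegmentedRadixSort::SortKeys(
      nullptr, temp_bytes, keys_in, keys_out, num_items, num_segments,
      offsets, offsets + 1, /* begin_bit */ 0, /* end_bit */ sizeof(float) * 8,
      stream);
  void* temp = nullptr;
  cudaMallocAsync(&temp, temp_bytes, stream);  // stream-ordered allocation
  // Phase 2: the actual segmented sort.
  cub::DeviceSegmentedRadixSort::SortKeys(
      temp, temp_bytes, keys_in, keys_out, num_items, num_segments,
      offsets, offsets + 1, 0, sizeof(float) * 8, stream);
  cudaFreeAsync(temp, stream);
}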


@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::VectorVectorVector ||
topt == TernaryOpType::ScalarScalarScalar) {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} else {
dispatch_bool( dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
ndim); ndim);
} }
}); });
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(DType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large(), N_READS);
encoder.add_kernel_node(
cu::ternary_v<Op, DType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size());
});
} }
}); });
} }


@@ -1,284 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_v(const In* in, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(in[i]);
}
} else {
auto in_vec = load_vector<N_READS>(in, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void unary_g(
const In* in,
Out* out,
IdxT size_rest,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides strides,
int ndim) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
IdxT index_rest =
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
if (index_rest >= size_rest) {
return;
}
auto shape_x = shape[ndim - 1];
auto stride_x = strides[ndim - 1];
IdxT index_x =
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
auto idx =
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
auto in_vec =
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec[i] = Op{}(in_vec[i]);
}
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}
template <typename Op, typename In, typename Out>
constexpr bool supports_unary_op() {
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
std::is_same_v<Op, Sigmoid>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseInvert>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, Conjugate>) {
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
}
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
std::is_same_v<Op, Tanh>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
}
if (std::is_same_v<Op, LogicalNot>) {
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void unary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
auto& in = inputs[0];
if (in.size() == 0) {
return;
}
bool contig = in.flags().contiguous;
bool large;
if (!contig) {
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
dispatch_bool(large, [&](auto large) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
if (contig) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
constexpr int N_READS = 16 / sizeof(OutType);
auto [num_blocks, block_dims] = get_launch_args(
out.data_size(), out.shape(), out.strides(), large, N_READS);
encoder.add_kernel_node(
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
num_blocks,
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
out.data_size());
} else {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [shape, strides] = collapse_contiguous_dims(in);
auto ndim = shape.size();
int work_per_thread = 1;
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
work_per_thread = 4;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
encoder.add_kernel_node(
kernel,
{num_blocks_x, num_blocks_y},
block_dims,
0,
in.data<InType>(),
out.data<OutType>(),
rest,
const_param(shape),
const_param(strides),
ndim);
}
});
} else {
throw std::runtime_error(fmt::format(
"Can not do unary op {} on input of {} with output of {}.",
op,
dtype_to_string(in.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void unary_op_gpu(
const std::vector<array>& inputs,
array& out,
const char* op,
const Stream& s) {
set_unary_output_data(inputs[0], out);
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define UNARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
unary_op_gpu<cu::func>(inputs, out, name(), s); \
}
UNARY_GPU(Abs)
UNARY_GPU(ArcCos)
UNARY_GPU(ArcCosh)
UNARY_GPU(ArcSin)
UNARY_GPU(ArcSinh)
UNARY_GPU(ArcTan)
UNARY_GPU(ArcTanh)
UNARY_GPU(BitwiseInvert)
UNARY_GPU(Ceil)
UNARY_GPU(Conjugate)
UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Floor)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)
UNARY_GPU(Sinh)
UNARY_GPU(Square)
UNARY_GPU(Tan)
UNARY_GPU(Tanh)
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Log::eval_gpu");
auto& s = out.primitive().stream();
switch (base_) {
case Base::e:
unary_op_gpu<cu::Log>(inputs, out, name(), s);
break;
case Base::two:
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
break;
case Base::ten:
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
break;
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Round::eval_gpu");
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto& s = out.primitive().stream();
if (issubdtype(in.dtype(), inexact)) {
unary_op_gpu<cu::Round>(inputs, out, name(), s);
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Sort::eval_gpu");
auto& s = out.primitive().stream();
if (recip_) {
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
} else {
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
}
}
} // namespace mlx::core


@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file include utilies that are used by C++ code (i.e. .cpp files). // This file include utilities that are used by C++ code (i.e. .cpp files).
#pragma once #pragma once
@@ -12,6 +12,7 @@ namespace mlx::core {
namespace cu { namespace cu {
class Device; class Device;
} }
struct Dtype; struct Dtype;
@@ -86,4 +87,17 @@ class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
explicit CudaStream(cu::Device& device); explicit CudaStream(cu::Device& device);
}; };
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
int _, block_dim;
if constexpr (std::is_same_v<T, CUfunction>) {
CHECK_CUDA_ERROR(
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
} else {
CHECK_CUDA_ERROR(
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
}
return block_dim;
}
} // namespace mlx::core } // namespace mlx::core
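max_occupancy_block_dim above wraps the driver and runtime occupancy queries behind one helper. A hedged usage sketch showing the typical pattern (the kernel and wrapper below are illustrative, not MLX code):

#include <cuda_runtime.h>

__global__ void scale_kernel(float* data, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= alpha;
  }
}

void launch_scale(float* data, float alpha, int n, cudaStream_t stream) {
  int min_grid = 0;
  int block_dim = 0;
  // Ask the runtime for the block size that maximizes occupancy for this
  // kernel; the helper does the same for both CUfunction and runtime kernels.
  cudaOccupancyMaxPotentialBlockSize(&min_grid, &block_dim, scale_kernel);
  int grid = (n + block_dim - 1) / block_dim;
  scale_kernel<<<grid, block_dim, 0, stream>>>(data, alpha, n);
}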


@@ -5,8 +5,9 @@
namespace mlx::core::cu { namespace mlx::core::cu {
Worker::Worker() Worker::Worker(Device& d)
: signal_stream_(device(mlx::core::Device::gpu)), : signal_stream_(d),
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {} worker_(&Worker::thread_fn, this) {}
Worker::~Worker() { Worker::~Worker() {


@@ -3,7 +3,6 @@
#pragma once #pragma once
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <condition_variable> #include <condition_variable>
#include <functional> #include <functional>
@@ -16,7 +15,7 @@ namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream. // Run tasks in worker thread, synchronized with cuda stream.
class Worker { class Worker {
public: public:
Worker(); explicit Worker(Device& d);
~Worker(); ~Worker();
Worker(const Worker&) = delete; Worker(const Worker&) = delete;


@@ -20,8 +20,8 @@ void copy_gpu_inplace(
int64_t o_offset, int64_t o_offset,
CopyType ctype, CopyType ctype,
const Stream& s, const Stream& s,
const std::optional<array>& dynamic_i_offset = std::nullopt, std::optional<array> dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt); std::optional<array> dynamic_o_offset = std::nullopt);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype); void copy_gpu(const array& src, array& out, CopyType ctype);


@@ -80,6 +80,74 @@ void Depends::eval_gpu(
eval(inputs, outputs); eval(inputs, outputs);
} }
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* std::optional<array> dynamic_i_offset = */ std::move(in_offset),
/* std::optional<array> dynamic_o_offset = */ std::nullopt);
}
void DynamicSliceUpdate::eval_gpu(
const std::vector<array>& inputs,
array& out) {
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
auto& start_indices = inputs[2];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Copy or donate input to output
auto s = stream();
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
auto out_offset =
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* std::optional<array> dynamic_i_offset = */ std::nullopt,
/* std::optional<array> dynamic_o_offset = */ std::move(out_offset));
}
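DynamicSliceUpdate first copies (or donates) the input into the output, then issues a general copy of the update at an offset that is itself computed on the device. The 1-D effect reduces to the following sketch (the start-index clamping is an assumption of the sketch, and it assumes the update fits inside the input):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<float> dynamic_slice_update_1d(
    const std::vector<float>& in,
    const std::vector<float>& upd,
    int64_t start) {
  std::vector<float> out = in;  // copy-or-donate step
  // Keep the update window inside the output (assumed clamping behavior).
  start = std::clamp<int64_t>(
      start,
      0,
      static_cast<int64_t>(in.size()) - static_cast<int64_t>(upd.size()));
  std::copy(upd.begin(), upd.end(), out.begin() + start);  // offset write
  return out;
}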
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) { void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out); eval(inputs, out);


@@ -27,4 +27,10 @@ void pad_gpu(
const Shape& low_pad_size, const Shape& low_pad_size,
const Stream& s); const Stream& s);
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s);
} // namespace mlx::core } // namespace mlx::core


@@ -33,10 +33,11 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h) make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(gather kernels/indexing.h) make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(gather_axis) make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(scatter_axis) make_jit_source(indexing/gather_axis)
make_jit_source(indexing/scatter_axis)
make_jit_source(hadamard) make_jit_source(hadamard)
if(MLX_METAL_JIT) if(MLX_METAL_JIT)
@@ -77,7 +78,10 @@ if(MLX_METAL_JIT)
make_jit_source(steel/conv/kernels/steel_conv) make_jit_source(steel/conv/kernels/steel_conv)
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h) kernels/steel/conv/loaders/loader_general.h)
make_jit_source(quantized)
make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h)
make_jit_source(gemv_masked) make_jit_source(gemv_masked)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)


@@ -2,7 +2,6 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::string kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel // Prepare unfolding kernel
std::ostringstream kname; std::string kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_" kname.reserve(32);
<< N; concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log}; /* const int swizzle_log = */ swizzle_log};
// Determine kernel // Determine kernel
std::ostringstream kname; std::string kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" kname.reserve(64);
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_" concatenate(
<< (n_channel_specialization ? std::to_string(n_channel_specialization) kname,
: "l") "implicit_gemm_conv_2d_",
<< "_filter_" << (small_filter ? 's' : 'l'); type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel( auto kernel = get_steel_conv_kernel(
d, d,
kname.str(), kname,
out, out,
bm, bm,
bn, bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{ {
int bc = 32; int bc = 32;
int bo = 4; int bo = 4;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0); compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32; int bc = 32;
int wm = 2; int wm = 2;
int wn = 2; int wn = 2;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32; int bc = 32;
int wm = 2; int wm = 2;
int wn = 2; int wn = 2;
std::ostringstream kname; std::string kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo" kname.reserve(32);
<< bc; concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt, const array& wt,
array out, array out,
const MLXConvParams<2>& conv_params) { const MLXConvParams<2>& conv_params) {
std::ostringstream kname; std::string base_name;
kname << "depthwise_conv_2d_" << type_to_name(out); base_name.reserve(32);
std::string base_name = kname.str(); concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
const int N = conv_params.N; const int N = conv_params.N;
const int ker_h = conv_params.wS[0]; const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
}; };
// clang-format off // clang-format off
kname << "_ker_h_" << ker_h std::string hash_name;
<< "_ker_w_" << ker_w hash_name.reserve(64);
<< "_str_h_" << str_h concatenate(
<< "_str_w_" << str_w hash_name,
<< "_tgp_h_" << th base_name,
<< "_tgp_w_" << tw "_ker_h_", ker_h,
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on "_ker_w_", ker_w,
"_str_h_", str_h,
std::string hash_name = kname.str(); "_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts); auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
} }
} }
void depthwise_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
array wt,
array out) {
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
std::string base_name;
base_name.reserve(32);
concatenate(
base_name,
"depthwise_conv_1d_",
large ? "_large" : "",
type_to_name(out));
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
d.add_temporary(wt, s.index);
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);
auto B = in.shape(0);
auto Tout = out.shape(1);
auto D = in.shape(2);
auto K = wt.shape(1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
if (large) {
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
compute_encoder.set_bytes(strides, 3, 3);
} else {
int strides[3] = {
static_cast<int>(in.strides(0)),
static_cast<int>(in.strides(1)),
static_cast<int>(in.strides(2))};
compute_encoder.set_bytes(strides, 3, 3);
}
compute_encoder.set_bytes(K, 4);
auto group_dims = get_block_dims(D, Tout, B);
MTL::Size grid_dims = MTL::Size(D, Tout, B);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
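depthwise_conv_1D_gpu dispatches one thread per (channel, output step, batch) element; each channel is convolved with its own length-K filter, which is what the new fully separable fast path in conv_1D_gpu below checks for (groups == C == O, stride 1, no padding, no dilation, no flip). A host reference of the same computation, assuming an NLC input layout and [C, K] weights:

#include <vector>

// in is [B][T_in][C] with T_in = T_out + K - 1 (valid convolution, stride 1),
// w is [C][K], out is [B][T_out][C]; all flattened row-major.
void depthwise_conv1d_reference(
    const std::vector<float>& in,
    const std::vector<float>& w,
    std::vector<float>& out,
    int B, int T_out, int C, int K) {
  const int T_in = T_out + K - 1;
  for (int b = 0; b < B; ++b) {
    for (int t = 0; t < T_out; ++t) {
      for (int c = 0; c < C; ++c) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) {
          acc += in[(b * T_in + (t + k)) * C + c] * w[c * K + k];
        }
        out[(b * T_out + t) * C + c] = acc;
      }
    }
  }
}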
void conv_1D_gpu( void conv_1D_gpu(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1; bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2); int C = in.shape(2);
int O = wt.shape(0); int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups; // Fast path for fully separable 1D convolution
const int O_per_group = wt.shape(0) / groups; if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}
const int C_per_group = C / groups;
const int O_per_group = O / groups;
// Direct to implicit gemm conv // Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&


@@ -20,8 +20,8 @@ void copy_gpu_inplace(
int64_t out_offset, int64_t out_offset,
CopyType ctype, CopyType ctype,
const Stream& s, const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */, std::optional<array> dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) { std::optional<array> dynamic_o_offset /* = std::nullopt */) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }


@@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
} }
MTL::Library* load_default_library(MTL::Device* device) { MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4]; NS::Error* error[5];
MTL::Library* lib; MTL::Library* lib;
// First try the colocated mlx.metallib // First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
@@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
return lib; return lib;
} }
// Try to load resources from Framework resources if SwiftPM wrapped as a
// dynamic framework.
std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path // Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
if (!lib) { if (!lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "Failed to load the default metallib. "; msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) { for (int i = 0; i < 5; i++) {
if (error[i] != nullptr) { if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " "; msg << error[i]->localizedDescription()->utf8String() << " ";
} }
@@ -464,6 +471,10 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) { CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index); auto& stream = get_stream_(index);
if (stream.encoder == nullptr) { if (stream.encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream); stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence()); stream.fence = std::make_shared<Fence>(device_->newFence());
} }


@@ -52,8 +52,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t slice_size = 1;
size_t ndim = src.ndim(); for (auto s : slice_sizes_) {
slice_size *= s;
}
bool large_index = nidx && inputs[1].size() > INT32_MAX; bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX; bool large_src = src.size() > INT32_MAX;
@@ -61,6 +63,55 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
bool large = large_index || large_src || large_out; bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
auto& indices = inputs[1];
std::string kernel_name = fmt::format(
"gather_front{0}_{1}_{2}_{3}",
type_to_name(out),
idx_type_name,
large ? "int64_t" : "int",
work_per_thread);
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather_front();
kernel_source += get_template_definition(
kernel_name,
"gather_front",
get_type_string(out.dtype()),
get_type_string(indices.dtype()),
large ? "int64_t" : "int",
work_per_thread);
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
size_t dim_y = indices.size();
auto group_dims = get_block_dims(dim_x, dim_y, 1);
MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
compute_encoder.set_input_array(src, 0);
compute_encoder.set_input_array(indices, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(slice_size, 3);
compute_encoder.set_bytes(src.shape(0), 4);
compute_encoder.dispatch_threads(grid_dims, group_dims);
return;
}
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::string kernel_name = fmt::format( std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}", "gather{0}{1}_{2}_{3}_{4}",
type_to_name(out), type_to_name(out),
@@ -96,11 +147,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
// Launch 3D grid of threads // Launch 3D grid of threads
// First two dimensions for the indices, the last one for the slice // First two dimensions for the indices, the last one for the slice
size_t dim0 = 1; size_t dim0 = 1;
@@ -332,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
if (upd_ndim == 0) { if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain // Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
int64_t stride_ = 0; int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3); compute_encoder.set_bytes(shape_, 3);
@@ -347,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set output info // Set output info
size_t out_ndim = out.ndim(); size_t out_ndim = out.ndim();
if (out_ndim == 0) { if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain // Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
int64_t stride_ = 0; int64_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7); compute_encoder.set_bytes(shape_, 7);


@@ -19,9 +19,12 @@ const char* binary_two();
const char* copy(); const char* copy();
const char* fft(); const char* fft();
const char* gather_axis(); const char* gather_axis();
const char* gather_front();
const char* hadamard(); const char* hadamard();
const char* logsumexp(); const char* logsumexp();
const char* quantized_utils();
const char* quantized(); const char* quantized();
const char* fp4_quantized();
const char* ternary(); const char* ternary();
const char* scan(); const char* scan();
const char* scatter_axis(); const char* scatter_axis();


@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
auto t_str = get_type_string(type); auto t_str = get_type_string(type);
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary()); concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
{"v2", "ternary_v2"},
{"g1large", "ternary_g_nd1"}, {"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"}, {"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"}, {"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
kernel_source += kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op); get_template_definition(name + "_" + lib_name, func, t_str, op);
} }
kernel_source += get_template_definition(
"v2_" + lib_name, "ternary_v2", t_str, op, false, false);
kernel_source += get_template_definition(
"sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
kernel_source += get_template_definition(
"vs2_" + lib_name, "ternary_v2", t_str, op, false, true);
if (get_work_per_thread(type) > 1) { if (get_work_per_thread(type) > 1) {
kernel_source += kernel_source += get_template_definition(
get_template_definition("vn_" + lib_name, "ternary_v", t_str, op); "vn_" + lib_name, "ternary_v", t_str, op, false, false);
kernel_source += get_template_definition(
"svn_" + lib_name, "ternary_v", t_str, op, true, false);
kernel_source += get_template_definition(
"vsn_" + lib_name, "ternary_v", t_str, op, false, true);
} }
kernel_source += kernel_source += get_template_definition(
get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1); "v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
kernel_source += get_template_definition(
"sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
kernel_source += get_template_definition(
"vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
kernel_source += get_template_definition( kernel_source += get_template_definition(
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition( kernel_source += get_template_definition(
@@ -804,13 +819,19 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel( MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const std::string& template_def) { const std::string& template_def,
const std::string& mode) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized() concatenate(
<< template_def; kernel_source,
return kernel_source.str(); metal::utils(),
metal::gemm(),
metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
template_def);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
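The kernel-name and JIT-source strings above are now assembled with a variadic concatenate helper instead of std::ostringstream, which avoids a stream allocation per kernel lookup. A minimal sketch of what such a helper can look like (illustrative only; the real MLX utility may differ):

#include <string>
#include <type_traits>

template <typename T>
void append_one(std::string& dst, const T& v) {
  if constexpr (std::is_arithmetic_v<T> && !std::is_same_v<T, char>) {
    dst += std::to_string(v);  // ints, block sizes, etc.
  } else {
    dst += v;  // std::string, string literals, single chars
  }
}

template <typename... Args>
void concat_into(std::string& dst, const Args&... args) {
  (append_one(dst, args), ...);  // C++17 fold over the arguments
}

// Usage:
//   std::string kname;
//   kname.reserve(64);
//   concat_into(kname, "implicit_gemm_conv_2d_", "float32", "_bm", 32, "_bn", 32);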
@@ -823,6 +844,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const array& x, const array& x,
int group_size, int group_size,
int bits, int bits,
const std::string& mode,
int bm, int bm,
int bn, int bn,
int bk, int bk,
@@ -832,14 +854,15 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source; std::string kernel_source;
concatenate(
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") {
concatenate( concatenate(
kernel_source, kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(), metal::quantized(),
get_template_definition( get_template_definition(
lib_name, lib_name,
"gather_qmm_rhs", mode + "_gather_qmm_rhs",
get_type_string(x.dtype()), get_type_string(x.dtype()),
group_size, group_size,
bits, bits,
@@ -849,6 +872,23 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
wm, wm,
wn, wn,
transpose)); transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
return kernel_source; return kernel_source;
}); });
return d.get_kernel(kernel_name, lib, hash_name, func_consts); return d.get_kernel(kernel_name, lib, hash_name, func_consts);


@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel( MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const std::string& template_def); const std::string& template_def,
const std::string& mode);
MTL::ComputePipelineState* get_gather_qmm_kernel( MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d, metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const array& x, const array& x,
int group_size, int group_size,
int bits, int bits,
const std::string& mode,
int bm, int bm,
int bn, int bn,
int bk, int bk,


@@ -108,7 +108,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_all.h reduction/reduce_all.h
reduction/reduce_col.h reduction/reduce_col.h
reduction/reduce_row.h) reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(scan scan.h)
build_kernel(softmax softmax.h) build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h) build_kernel(logsumexp logsumexp.h)


@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) { constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag}; return {a.real + b.real, a.imag + b.imag};
} }
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr threadgroup complex64_t& operator+=(
threadgroup complex64_t& a,
complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
a.real += b.real;
a.imag += b.imag;
return a;
}
constexpr complex64_t operator+(float a, complex64_t b) { constexpr complex64_t operator+(float a, complex64_t b) {
return {a + b.real, b.imag}; return {a + b.real, b.imag};
} }


@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half); instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t); instantiate_depthconv2d(bfloat16, bfloat16_t);
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;
float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels /// Winograd kernels
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large.


@@ -0,0 +1,127 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
batched)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
uint8_t, \
aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
aligned, \
batched)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
D, \
batched)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
uint8_t, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on


@@ -15,6 +15,15 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
template <typename U>
struct DefaultAccT {
using type = float;
};
template <>
struct DefaultAccT<complex64_t> {
using type = complex64_t;
};
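The DefaultAccT trait above lets the GEMV kernels pick their accumulator type: real inputs keep accumulating in float, while complex64 inputs accumulate in complex64 (which also changes the threadgroup scratch type further down). The same idea in portable C++, with std::complex standing in for the Metal complex64_t type:

#include <complex>

template <typename T>
struct AccumulatorOf {
  using type = float;  // default: accumulate reduced-precision reals in float
};

template <>
struct AccumulatorOf<std::complex<float>> {
  using type = std::complex<float>;  // complex values keep a complex accumulator
};

template <typename T, int N>
typename AccumulatorOf<T>::type dot(const T (&a)[N], const T (&b)[N]) {
  using Acc = typename AccumulatorOf<T>::type;
  Acc acc{};
  for (int i = 0; i < N; ++i) {
    acc += Acc(a[i]) * Acc(b[i]);
  }
  return acc;
}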
template < template <
typename T, typename T,
const int BM, /* Threadgroup rows (in simdgroups) */ const int BM, /* Threadgroup rows (in simdgroups) */
@@ -24,8 +33,10 @@ template <
const int TM, /* Thread rows (in elements) */ const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */ const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float> typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel { struct GEMVKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN; MLX_MTL_CONST int threadsN = BN * SN;
@@ -35,8 +46,8 @@ struct GEMVKernel {
static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert( static_assert(
SN == 8 || SN == 16 || SN == 32, SN == 4 || SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32"); "gemv block must have a width of 4, 8, 16, or 32");
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups // into blocks of (blockM, blockN) divided among threadgroups
@@ -246,8 +257,10 @@ template <
const int TM, /* Thread rows (in elements) */ const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */ const int TN, /* Thread cols (in elements) */
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
typename AccT = float> typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel { struct GEMVTKernel {
using acc_type = AccT;
MLX_MTL_CONST int threadsM = BM * SM; MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN; MLX_MTL_CONST int threadsN = BN * SN;
@@ -453,7 +466,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>; using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets // Update batch offsets
@@ -511,21 +524,26 @@ template <
axpby) axpby)
// clang-format off // clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ #define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \ instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \ instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \ instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
// clang-format off // clang-format off
#define instantiate_gemv_blocks(name, itype) \ #define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \ instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \ instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
instantiate_gemv_blocks(float32, float); instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half); instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t); instantiate_gemv_blocks(bfloat16, bfloat16_t);
instantiate_gemv_blocks(complex64, complex64_t);
template < template <
typename T, typename T,
@@ -561,7 +579,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>; using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
threadgroup float tgp_memory threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
uint32_t indx_vec; uint32_t indx_vec;
@@ -632,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float); instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half); instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
instantiate_gemv_bs_blocks(complex64, complex64_t);
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication /// Vector matrix multiplication
@@ -668,7 +687,7 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>; using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
threadgroup float tgp_memory threadgroup typename gemv_kernel::acc_type tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
// Update batch offsets // Update batch offsets
@@ -734,7 +753,8 @@ template <
// clang-format off // clang-format off
instantiate_gemv_t_blocks(float32, float); instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half); instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
template < template <
typename T, typename T,
@@ -769,8 +789,8 @@ template <
 uint3 lid [[thread_position_in_threadgroup]],
 uint simd_gid [[simdgroup_index_in_threadgroup]],
 uint simd_lid [[thread_index_in_simdgroup]]) {
-using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
-threadgroup float tgp_memory
+using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
+threadgroup typename gemv_kernel::acc_type tgp_memory
 [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
 uint32_t indx_vec;
@@ -844,4 +864,5 @@ template <
 // clang-format off
 instantiate_gemv_t_bs_blocks(float32, float);
 instantiate_gemv_t_bs_blocks(float16, half);
-instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on
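The gemv changes above replace the hard-coded float threadgroup scratch with the kernel's own acc_type and add complex64 instantiations. The acc_type alias itself lives inside GEMVKernel/GEMVTKernel and is not part of this diff; the snippet below is only a minimal C++ sketch of the kind of per-element-type accumulator selection this implies (AccumType and acc_t are hypothetical names, not MLX's implementation).

#include <complex>
#include <type_traits>

// Hypothetical accumulator selection: real element types accumulate in float,
// complex64 accumulates in a complex<float>.
template <typename T>
struct AccumType {
  using type = float;
};

template <>
struct AccumType<std::complex<float>> {
  using type = std::complex<float>;
};

template <typename T>
using acc_t = typename AccumType<T>::type;

static_assert(std::is_same_v<acc_t<float>, float>);
static_assert(std::is_same_v<acc_t<std::complex<float>>, std::complex<float>>);

int main() { return 0; }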


@@ -2,7 +2,7 @@
 #pragma once
-#include "mlx/backend/metal/kernels/indexing.h"
+#include "mlx/backend/metal/kernels/indexing/indexing.h"
 template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
 METAL_FUNC void gather_impl(


@@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing/indexing.h"
template <typename T, typename IdxT, typename LocT, int N>
[[kernel]] void gather_front(
const device T* src,
const device IdxT* indices,
device T* out,
const constant int64_t& stride,
const constant int& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto idx = offset_neg_idx(indices[index.y], size);
LocT src_idx = static_cast<LocT>(stride) * idx;
LocT out_idx = static_cast<LocT>(stride) * index.y;
int s_idx = N * index.x;
for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
out[out_idx + s_idx] = src[src_idx + s_idx];
}
}
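The new gather_front kernel above copies whole leading-axis rows selected by an index array, with each thread writing a chunk of N contiguous elements of one row. A serial C++ reference of the same computation, written only to illustrate the indexing, including the negative-index wrap performed by offset_neg_idx (gather_front_ref is a hypothetical helper):

#include <cstdint>
#include <vector>

// Serial reference for gather_front: out[row, :] = src[indices[row], :],
// where `stride` is the row length and `size` is the number of source rows.
template <typename T, typename IdxT>
void gather_front_ref(
    const std::vector<T>& src,
    const std::vector<IdxT>& indices,
    std::vector<T>& out,
    int64_t stride,
    int size) {
  for (int64_t row = 0; row < static_cast<int64_t>(indices.size()); ++row) {
    int64_t idx = static_cast<int64_t>(indices[row]);
    if (idx < 0) {
      idx += size; // same wrap as offset_neg_idx in the kernel
    }
    for (int64_t j = 0; j < stride; ++j) {
      out[row * stride + j] = src[idx * stride + j];
    }
  }
}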


@@ -2,7 +2,7 @@
 #pragma once
-#include "mlx/backend/metal/kernels/indexing.h"
+#include "mlx/backend/metal/kernels/indexing/indexing.h"
 template <
 typename T,


@@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets(
 }
 template <typename T, int group_size, int bits, int D, bool batched>
-[[kernel]] void qmv_quad(
+[[kernel]] void affine_qmv_quad(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1486,7 +1486,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
 }
 template <typename T, int group_size, int bits, bool batched>
-[[kernel]] void qmv_fast(
+[[kernel]] void affine_qmv_fast(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1538,7 +1538,7 @@ template <typename T, int group_size, int bits, bool batched>
 }
 template <typename T, const int group_size, const int bits, bool batched>
-[[kernel]] void qmv(
+[[kernel]] void affine_qmv(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1590,7 +1590,7 @@ template <typename T, const int group_size, const int bits, bool batched>
 }
 template <typename T, const int group_size, const int bits, bool batched>
-[[kernel]] void qvm(
+[[kernel]] void affine_qvm(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1642,7 +1642,7 @@ template <typename T, const int group_size, const int bits, bool batched>
 }
 template <typename T, const int group_size, const int bits, int split_k = 32>
-[[kernel]] void qvm_split_k(
+[[kernel]] void affine_qvm_split_k(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1706,7 +1706,7 @@ template <
 const int BM = 32,
 const int BK = 32,
 const int BN = 32>
-[[kernel]] void qmm_t(
+[[kernel]] void affine_qmm_t(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1764,7 +1764,7 @@ template <
 const int BM = 32,
 const int BK = 32,
 const int BN = 32>
-[[kernel]] void qmm_n(
+[[kernel]] void affine_qmm_n(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1817,7 +1817,7 @@ template <
 }
 template <typename T, int group_size, int bits>
-[[kernel]] void gather_qmv_fast(
+[[kernel]] void affine_gather_qmv_fast(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1879,7 +1879,7 @@ template <typename T, int group_size, int bits>
 }
 template <typename T, int group_size, int bits>
-[[kernel]] void gather_qmv(
+[[kernel]] void affine_gather_qmv(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -1941,7 +1941,7 @@ template <typename T, int group_size, int bits>
 }
 template <typename T, int group_size, int bits>
-[[kernel]] void gather_qvm(
+[[kernel]] void affine_gather_qvm(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -2010,7 +2010,7 @@ template <
 const int BM = 32,
 const int BK = 32,
 const int BN = 32>
-[[kernel]] void gather_qmm_t(
+[[kernel]] void affine_gather_qmm_t(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -2077,7 +2077,7 @@ template <
 const int BM = 32,
 const int BK = 32,
 const int BN = 32>
-[[kernel]] void gather_qmm_n(
+[[kernel]] void affine_gather_qmm_n(
 const device uint32_t* w [[buffer(0)]],
 const device T* scales [[buffer(1)]],
 const device T* biases [[buffer(2)]],
@@ -2138,92 +2138,6 @@ template <
 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 }
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
 template <
 typename T,
 int group_size,
@@ -2234,7 +2148,7 @@ template <
 int WM,
 int WN,
 bool transpose>
-[[kernel]] void gather_qmm_rhs(
+[[kernel]] void affine_gather_qmm_rhs(
 const device T* x [[buffer(0)]],
 const device uint32_t* w [[buffer(1)]],
 const device T* scales [[buffer(2)]],


@@ -3,6 +3,7 @@
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/quantized_utils.h"
 #include "mlx/backend/metal/kernels/quantized.h"
 #define instantiate_quantized(name, type, group_size, bits) \
@@ -79,40 +80,40 @@
 instantiate_quantized_batched(name, type, group_size, bits, 0)
 #define instantiate_quantized_all_batched(type, group_size, bits) \
-instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
-instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
-instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
-instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
+instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \
+instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \
+instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \
+instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)
 #define instantiate_quantized_all_single(type, group_size, bits) \
 instantiate_quantized(affine_quantize, type, group_size, bits) \
 instantiate_quantized(affine_dequantize, type, group_size, bits) \
-instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
-instantiate_quantized(gather_qmv, type, group_size, bits) \
-instantiate_quantized(gather_qvm, type, group_size, bits) \
-instantiate_quantized(gather_qmm_n, type, group_size, bits)
+instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \
+instantiate_quantized(affine_gather_qmv, type, group_size, bits) \
+instantiate_quantized(affine_gather_qvm, type, group_size, bits) \
+instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)
 #define instantiate_quantized_all_aligned(type, group_size, bits) \
-instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
-instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
-instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
-instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
-instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
-instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
+instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \
+instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \
+instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \
+instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \
+instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \
+instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)
 #define instantiate_quantized_all_quad(type, group_size, bits) \
-instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
-instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
-instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
-instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
+instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \
+instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \
+instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \
+instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)
 #define instantiate_quantized_all_splitk(type, group_size, bits) \
-instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
-instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
+instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
+instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
 #define instantiate_quantized_all_rhs(type, group_size, bits) \
-instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
-instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
+instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
+instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
 #define instantiate_quantized_funcs(type, group_size, bits) \
 instantiate_quantized_all_single(type, group_size, bits) \


@@ -0,0 +1,90 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
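These gemm_loop_aligned / gemm_loop_unaligned / gemm_loop_finalize helpers, previously embedded in quantized.h (removed above), are now shared from this header. The split exists so that tiles fully inside the output can use unchecked loads while edge tiles fall back to bounds-checked loads. Below is a rough host-side sketch of how a caller might choose between the variants for a BM x BN threadgroup tile over an M x N output; choose_loop, TileInfo, and LoopKind are hypothetical names used only for illustration, not the MLX dispatch.

// Hypothetical dispatch: pick the GEMM inner loop based on whether the
// threadgroup tile lies fully inside the output matrix (aligned) or is
// clipped at the right/bottom edge (unaligned).
struct TileInfo {
  int M, N;   // output shape
  int BM, BN; // threadgroup tile shape
};

enum class LoopKind { Aligned, UnalignedRows, UnalignedCols, UnalignedBoth };

inline LoopKind choose_loop(const TileInfo& t, int tile_m, int tile_n) {
  bool rows_aligned = (tile_m + 1) * t.BM <= t.M; // full BM rows available
  bool cols_aligned = (tile_n + 1) * t.BN <= t.N; // full BN cols available
  if (rows_aligned && cols_aligned) return LoopKind::Aligned; // gemm_loop_aligned
  if (rows_aligned) return LoopKind::UnalignedCols;           // load_safe on B only
  if (cols_aligned) return LoopKind::UnalignedRows;           // load_safe on A only
  return LoopKind::UnalignedBoth;                             // load_safe on both
}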


@@ -3,14 +3,19 @@
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
constant bool forward [[function_constant(1)]];
constant bool traditional [[function_constant(2)]];
constant bool hs_transpose [[function_constant(3)]];
template <typename T>
void rope_single_impl( void rope_single_impl(
const device T* in, const device T* in,
device T* out, device T* out,
constant const int& offset, constant const int& offset,
const float inv_freq, const float inv_freq,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
uint2 pos, uint2 pos,
uint2 grid) { uint2 grid) {
float L = scale * static_cast<float>(offset); float L = scale * static_cast<float>(offset);
@@ -46,76 +51,85 @@ void rope_single_impl(
out[index_2] = static_cast<T>(rx2); out[index_2] = static_cast<T>(rx2);
} }
template <typename T, bool traditional, bool forward> template <typename T>
[[kernel]] void rope_single( [[kernel]] void rope_single(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x); float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base); float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>( rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
in, out, offset, inv_freq, scale, stride, pos, grid);
} }
template <typename T, bool traditional, bool forward> template <typename T>
[[kernel]] void rope_single_freqs( [[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, constant const int& offset,
constant const float& scale, constant const float& scale,
constant const size_t& stride, constant const int64_t& stride,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]], uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) { uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>( rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
in, out, offset, inv_freq, scale, stride, pos, grid);
} }
-template <typename T, bool traditional, bool forward, int N = 4>
+template <typename T, typename IdxT, int N = 4>
 void rope_impl(
 const device T* in,
 device T* out,
-constant const int& offset,
+const device int* offset,
 const float inv_freq,
 constant const float& scale,
-constant const size_t strides[3],
-constant const size_t out_strides[3],
-constant const size_t& n_batch,
+constant const int64_t strides[3],
+constant const int64_t out_strides[3],
+constant const int64_t& offset_stride,
+constant const int& n_head,
 uint3 pos,
 uint3 grid) {
-float L = scale * static_cast<float>(pos.y + offset);
+auto n_head_up = N * ((n_head + N - 1) / N);
+auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
+auto batch_idx = (pos.z * N) / n_head_up;
+auto batch_offset = offset[batch_idx * offset_stride];
+float L = scale * static_cast<float>(pos.y + batch_offset);
+auto mat_idx = batch_idx * n_head + head_idx;
 // Compute costheta, sintheta
 float theta = L * inv_freq;
 float costheta = metal::fast::cos(theta);
 float sintheta = metal::fast::sin(theta);
 // Compute the input and output indices
-size_t in_index_1, in_index_2;
-size_t out_index_1, out_index_2;
-if (traditional) {
-out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
-N * pos.z * out_strides[0];
-out_index_2 = out_index_1 + 1;
+IdxT in_index_1;
+if (hs_transpose) {
+IdxT batch_stride = grid.y * IdxT(strides[1]);
 in_index_1 =
-2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
-in_index_2 = in_index_1 + strides[2];
+batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
 } else {
-out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
-N * pos.z * out_strides[0];
-out_index_2 = out_index_1 + grid.x * out_strides[2];
-in_index_1 =
-pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
-in_index_2 = in_index_1 + grid.x * strides[2];
+in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
 }
-for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
+IdxT in_index_2;
+IdxT out_index_1 =
+pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
+IdxT out_index_2;
+if (traditional) {
+out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
+out_index_2 = out_index_1 + 1;
+in_index_1 += 2 * pos.x * IdxT(strides[2]);
+in_index_2 = in_index_1 + IdxT(strides[2]);
+} else {
+out_index_1 += pos.x * IdxT(out_strides[2]);
+out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
+in_index_1 += pos.x * IdxT(strides[2]);
+in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
+}
+for (int i = 0; i < N && head_idx + i < n_head; ++i) {
 // Read and write the output
 float x1 = static_cast<float>(in[in_index_1]);
 float x2 = static_cast<float>(in[in_index_2]);
@@ -130,28 +144,29 @@ void rope_impl(
} }
out[out_index_1] = static_cast<T>(rx1); out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2); out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0]; in_index_1 += IdxT(strides[0]);
in_index_2 += strides[0]; in_index_2 += IdxT(strides[0]);
out_index_1 += out_strides[0]; out_index_1 += IdxT(out_strides[0]);
out_index_2 += out_strides[0]; out_index_2 += IdxT(out_strides[0]);
} }
} }
template <typename T, bool traditional, bool forward, int N = 4> template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope( [[kernel]] void rope(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]], constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x); float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base); float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>( rope_impl<T, IdxT, N>(
in, in,
out, out,
offset, offset,
@@ -159,26 +174,28 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
template <typename T, bool traditional, bool forward, int N = 4> template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope_freqs( [[kernel]] void rope_freqs(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const int& offset, const device int* offset,
constant const float& scale, constant const float& scale,
constant const size_t strides[3], constant const int64_t strides[3],
constant const size_t out_strides[3], constant const int64_t out_strides[3],
constant const size_t& n_batch, constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]], const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]], constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]], uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) { uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>( rope_impl<T, IdxT, N>(
in, in,
out, out,
offset, offset,
@@ -186,75 +203,27 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale, scale,
strides, strides,
out_strides, out_strides,
n_batch, offset_stride,
n_head,
pos, pos,
grid); grid);
} }
// clang-format off // clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \ #define instantiate_rope_g(name, type) \
template [[host_name("rope_" #name)]] [[kernel]] void \ instantiate_kernel("rope_" #name, rope, type, int32_t) \
rope<type, traditional, forward>( \ instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
const device type* in [[buffer(0)]], \ instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
device type* out [[buffer(1)]], \ instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \ #define instantiate_rope_s(name, type) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \ instantiate_kernel("rope_single_" #name, rope_single, type) \
rope_single<type, traditional, forward>( \ instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \ #define instantiate_rope(name, type) \
instantiate_rope_s(name, type, traditional, forward) \ instantiate_rope_s(name, type) \
instantiate_rope_g(name, type, traditional, forward) instantiate_rope_g(name, type)
instantiate_rope(traditional_float16, half, true, true) instantiate_rope(float16, half)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true) instantiate_rope(bfloat16, bfloat16_t)
instantiate_rope(traditional_float32, float, true, true) instantiate_rope(float32, float) // clang-format on
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on
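The reworked rope_impl above drops the flat n_batch loop in favor of explicit batch/head bookkeeping: pos.z now walks heads in groups of N (rounded up to n_head_up), and the sequence offset is read per batch element from an offset array with stride offset_stride. A small standalone C++ sketch of just that index decomposition, with names mirroring the kernel (illustrative only):

#include <cstdint>
#include <cstdio>

// Decompose the z grid coordinate into (batch, head) the way the new
// rope_impl does. N is the number of heads each thread handles.
void rope_indices(int pos_z, int n_head, int N, const int* offset, int64_t offset_stride) {
  int n_head_up = N * ((n_head + N - 1) / N);           // heads rounded up to a multiple of N
  int head_idx = (pos_z * N) % n_head_up;               // first head this thread touches
  int batch_idx = (pos_z * N) / n_head_up;              // which batch element
  int batch_offset = offset[batch_idx * offset_stride]; // per-batch sequence offset
  std::printf("z=%d -> batch=%d head=%d offset=%d\n", pos_z, batch_idx, head_idx, batch_offset);
}

int main() {
  int offsets[2] = {0, 7}; // two sequences starting at different positions
  for (int z = 0; z < 4; ++z) {
    rope_indices(z, /*n_head=*/5, /*N=*/4, offsets, /*offset_stride=*/1);
  }
  return 0;
}

With 5 heads and N = 4, for example, pos.z = 1 maps to head 4 of batch 0 and pos.z = 2 wraps around to head 0 of batch 1.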


@@ -306,6 +306,7 @@ template <
 U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
 // Write simdgroup_sums to SM
+threadgroup_barrier(mem_flags::mem_threadgroup);
 if (simd_lane_id == simd_size - 1) {
 simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
 }
@@ -440,6 +441,7 @@ template <
 }
 // Read in SM
+threadgroup_barrier(mem_flags::mem_threadgroup);
 if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
 for (int i = 0; i < N_READS; i++) {
 read_into[i] = in[index_y * stride + i];


@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]]; constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]]; constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]]; constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
template <typename T, int D, int V = D> template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
[[buffer(14), function_constant(has_mask)]], [[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]], [[buffer(15), function_constant(has_mask)]],
const device T* sinks [[buffer(16), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(17), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN]; threadgroup U sum_exp_scores[BN];
// Adjust positions // Adjust positions
const int head_idx = tid.x; const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y; const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = q_batch_head_idx / gqa_factor;
const int o_offset = head_idx * tpg.y + q_seq_idx; const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset = const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread; queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread; simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread; simd_lid * v_per_thread;
if (bool_mask) { if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + bmask += q_batch_head_idx * mask_head_stride +
q_seq_idx * mask_q_seq_stride; simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
} }
if (float_mask) { if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + fmask += q_batch_head_idx * mask_head_stride +
q_seq_idx * mask_q_seq_stride; simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
} }
out += o_offset * V + simd_gid * v_per_thread; out += o_offset * V + simd_gid * v_per_thread;
@@ -83,8 +87,12 @@ template <typename T, int D, int V = D>
o[i] = 0; o[i] = 0;
} }
U max_score = -INFINITY; U max_score = Limits<U>::finite_min;
U sum_exp_score = 0; U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) {
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
sum_exp_score = 1;
}
// For each key // For each key
for (int i = simd_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) { } else if (bool_mask) {
use_key = bmask[0]; use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
} }
if (use_key) { if (use_key) {
// Read the key // Read the key
@@ -107,7 +117,7 @@ template <typename T, int D, int V = D>
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) { if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0])); score += static_cast<U>(fmask[0]);
} }
// Update the accumulators // Update the accumulators
@@ -152,7 +162,8 @@ template <typename T, int D, int V = D>
for (int i = 0; i < v_per_thread; i++) { for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i]; outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
[[buffer(16), function_constant(has_mask)]], [[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]], [[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(19), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]], uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
// Adjust positions // Adjust positions
const int block_idx = tid.z; const int block_idx = tid.z;
const int head_idx = tid.x; const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y; const int q_seq_idx = tid.y;
const int o_offset = head_idx * tpg.y + q_seq_idx; const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset = const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = q_batch_head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread; queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) { if (bool_mask) {
bmask += head_idx * mask_head_stride + bmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
if (float_mask) { if (float_mask) {
fmask += head_idx * mask_head_stride + fmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride; q_seq_idx * mask_q_seq_stride;
} }
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
o[i] = 0; o[i] = 0;
} }
U max_score = -1e9; U max_score = Limits<U>::finite_min;
U sum_exp_score = 0; U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1;
}
// For each key // For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) { } else if (bool_mask) {
use_key = bmask[0]; use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
} }
if (use_key) { if (use_key) {
// Read the key // Read the key
@@ -268,6 +289,7 @@ template <typename T, int D, int V = D>
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = simd_sum(score); score = simd_sum(score);
if (float_mask) { if (float_mask) {
score += fmask[0]; score += fmask[0];
} }
@@ -379,7 +401,8 @@ template <typename T, int D>
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i]; outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
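The sdpa_vector changes above do three related things: the running maximum now starts from Limits<U>::finite_min instead of -INFINITY or -1e9, an optional attention sink seeds the accumulator (max_score = sink, sum_exp_score = 1) so the sink joins the normalization without contributing to the output, and the final division is guarded so a fully masked row no longer divides by zero. A scalar C++ sketch of that accumulation scheme (illustrative only; it ignores the kernel's simdgroup layout):

#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Online softmax-weighted sum of values[i] with logits scores[i], seeded with
// an optional attention sink. The sink behaves like an extra key whose value
// is zero: it only inflates the normalizer.
float sdpa_row(
    const std::vector<float>& scores,
    const std::vector<float>& values,
    bool has_sink,
    float sink) {
  float max_score = has_sink ? sink : std::numeric_limits<float>::lowest();
  float sum_exp = has_sink ? 1.0f : 0.0f;
  float out = 0.0f;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    float new_max = std::fmax(max_score, scores[i]);
    float factor = std::exp(max_score - new_max); // rescale previous state
    float e = std::exp(scores[i] - new_max);
    out = out * factor + e * values[i];
    sum_exp = sum_exp * factor + e;
    max_score = new_max;
  }
  return sum_exp == 0.0f ? out : out / sum_exp; // guard fully masked rows
}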


@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
 b = w;
 }
+template <typename T, typename = void>
+struct Init {
+static constexpr constant T v = Limits<T>::max;
+};
+template <typename T>
+struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
+static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
+};
 template <typename T>
 struct LessThan {
-static constexpr constant T init = Limits<T>::max;
-METAL_FUNC bool operator()(T a, T b) {
+static constexpr constant T init = Init<T>::v;
+METAL_FUNC bool operator()(T a, T b) const {
+if constexpr (
+metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
+bool an = isnan(a);
+bool bn = isnan(b);
+if (an | bn) {
+return (!an) & bn;
+}
+}
 return a < b;
 }
 };
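The Init/LessThan changes above make floating-point sorts pad with NaN and order NaN after every real value, matching the NumPy/JAX behavior referenced in the "Modified sort behavior" commit. The comparator's NaN rule, restated in plain C++ so it can be checked against std::sort (the complex64 branch is omitted here):

#include <algorithm>
#include <cmath>
#include <vector>

// NaN-last less-than: every real number compares less than NaN and two NaNs
// compare equal, mirroring the updated Metal LessThan operator.
struct NanLastLess {
  bool operator()(float a, float b) const {
    bool an = std::isnan(a);
    bool bn = std::isnan(b);
    if (an || bn) {
      return !an && bn;
    }
    return a < b;
  }
};

int main() {
  std::vector<float> v = {2.0f, NAN, -1.0f, 0.5f};
  std::sort(v.begin(), v.end(), NanLastLess{});
  // v is now {-1.0, 0.5, 2.0, nan}: NaNs sort to the end, as in NumPy.
  return 0;
}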


@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]]; constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]]; constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
template <typename T> template <typename T>
struct TransformScale { struct TransformScale {
@@ -82,6 +83,7 @@ template <
const constant AttnParams* params [[buffer(4)]], const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]], const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
VBlockLoader loader_v( VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089)); TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
// Prepare MMA tiles // Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
max_score[i] = Limits<AccumType>::finite_min; max_score[i] = Limits<AccumType>::finite_min;
} }
if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}
int kb_lim = params->NK; int kb_lim = params->NK;
if (do_causal) { if (do_causal) {
@@ -350,7 +360,7 @@ template <
Stile.frag_at(i, j)[jj] = Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else { } else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
} }
} }
} }
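In the steel attention kernel above, the scale is folded together with M_LOG2E_F (replacing the 1.44269504089 literal) because the softmax is evaluated with exp2 rather than exp, and the new sink seed is therefore written into the running maximum in that same log2 domain. The identity being relied on:

\[
e^{s} \;=\; 2^{\,s \log_2 e},
\qquad
\operatorname{softmax}_i(s) \;=\; \frac{2^{(s_i - m)\log_2 e}}{\sum_j 2^{(s_j - m)\log_2 e}},
\quad m = \max_j s_j .
\]

Seeding max_score with M_LOG2E_F * sinks[tidl.y] and sum_score with 1 is then equivalent to appending one extra key whose raw logit is the sink value and whose value vector is zero.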


@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"


@@ -3,9 +3,8 @@
 #include <metal_stdlib>
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"


@@ -23,10 +23,12 @@
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
 instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
-instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
+instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)
 instantiate_gemm_shapes_helper(float16, half, float16, half);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
 instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
 // clang-format on


@@ -1,8 +1,8 @@
 // Copyright © 2024 Apple Inc.
 // clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
 #define instantiate_gemm( \

Some files were not shown because too many files have changed in this diff.