Compare commits

...

67 Commits

Author SHA1 Message Date
Angelos Katharopoulos
7a6adda1e6 Bump the version (#2627) 2025-09-26 15:15:28 -07:00
Angelos Katharopoulos
1a9f820af6 Compiled should not end in broadcast (#2622) 2025-09-26 13:36:09 -07:00
Awni Hannun
d4f4ff3c5e Allow None input to compiled functions (#2621)
* Allow None input to compiled functions

* Allow None input to compiled functions
2025-09-25 08:42:23 -07:00
Jagrit Digani
7c7e48dbd1 New tuning for small K gemv (#2620)
* New tuning for small K gemv
2025-09-23 12:28:35 -07:00
Daniel Yeh
fbbf3b9b3e Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:15 -07:00
Daniel Yeh
bf01ad9367 fix (#2613)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:04 -07:00
Cheng
ae438d05fa [CUDA] Recycle CUDA events (#2604)
* Make CudaEvent a CudaHandle

* Add caching for CudaEvent

* Make sure cuda events are destroyed at last

* Fix headers

* SharedEvent => AtomicEvent

* RawCudaEvent => CudaEventHandle, CudaEventWrapper => CopyableCudaEvent

* Remove unneeded asserts
2025-09-23 10:42:03 +09:00
Awni Hannun
711a645807 avoid producing NaN in attention (#2608) 2025-09-22 13:10:43 -07:00
Josh Bleecher Snyder
aa9d44b3d4 implement Convolution::output_shape (#2601)
- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
2025-09-22 10:09:45 -07:00
Awni Hannun
ec2ab42888 Lower sorted QMM gather threshold (#2609) 2025-09-19 18:22:55 -07:00
Cheng
787c0d90cd Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
2025-09-19 09:12:14 +09:00
Oleksandr Bilous
e8b604a6a3 fix: library loading for swift dynamic frameworks (#2568) 2025-09-18 13:54:59 -07:00
Awni Hannun
50cc09887f expose depends (#2606) 2025-09-18 10:06:15 -07:00
Umberto Mignozzetti
3f730e77aa Update export function example for array input (#2598)
After changing the shape to conform (same shapes for all objects), the example works.
2025-09-16 14:38:05 -07:00
Awni Hannun
caecbe876a no copy batch rope (#2595) 2025-09-15 14:23:48 -07:00
Umberto Mignozzetti
8afb6d62f2 Fix typo in average_gradients function call (#2594) 2025-09-15 11:29:21 -07:00
Awni Hannun
6ccfa603cd fix metal scan (#2591) 2025-09-15 11:01:57 -07:00
Umberto Mignozzetti
36cad99a11 Refactor code examples to use 'gelu' (#2592)
Updated code examples to use 'gelu' directly instead of 'nn.gelu'.
2025-09-15 09:47:02 -07:00
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57 Add sdpa with sinks (#2558)
* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
Gökdeniz Gülmez
db5443e831 Adding Relu2 (#2582)
* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-09-10 07:24:30 -07:00
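For readers unfamiliar with the activation added in the commit above, here is a minimal sketch of a squared-ReLU along those lines, assuming `relu2(x)` denotes `maximum(x, 0)` squared; the exact MLX implementation may differ:

```
import mlx.core as mx

def relu2(x):
    # Squared ReLU (assumed semantics): zero out negatives, then square.
    y = mx.maximum(x, 0)
    return y * y

print(relu2(mx.array([-1.0, 0.5, 2.0])))  # [0.0, 0.25, 4.0]
```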
Cheng
52b8384d10 Fix flaky addmm tests (#2581) 2025-09-10 14:22:22 +09:00
Cheng
44cc5da4bc [CUDA] Fix alpha not respected when using bias epilogue (#2578) 2025-09-10 09:08:01 +09:00
Cheng
dde3682b69 [CUDA] Use GEMM with epilogue instead of AddMM (#2569) 2025-09-09 13:18:49 +09:00
Awni Hannun
17310d91a6 Add batch offsets for mx.fast.rope (#2564)
* implement batch rope for Metal

* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
Cheng
b194d65a6a Some tweaks in cmake files (#2574)
* Do proper check of Metal lib

* Update doctest to get rid of cmake version hack
2025-09-09 08:27:18 +09:00
Cheng
a44b27f5f8 Fix a few ccache cache miss (#2573)
* Fix ccache cache miss

* Do not define _VERSION_ in python bindings
2025-09-09 07:41:05 +09:00
Awni Hannun
e5a33f2223 faster depthwise 1D conv (#2567) 2025-09-08 11:37:23 -07:00
Cheng
c1e3340b23 Set ccache size before building (#2570) 2025-09-07 09:00:31 +09:00
XXXXRT666
8f163a367d typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods

* Missing list_or_scalar
2025-09-04 09:08:11 -07:00
Manuel Villanueva
89a3df9014 Fixed several type annotations in the MLX stubs which degraded to Unknown/Any (#2560)
* Added scalar to stubs to fix Unknown Type Hint

### Proposed changes

Issue #2478 reports that several type annotations in the MLX stubs degrade to Unknown/Any in editors like VS Code with Pylance, due to missing imports (Union, Optional, Tuple) and an undefined scalar type alias.

This PR updates the stub generation patterns to:
- Add missing typing imports in mlx.core.__prefix__ so that Union, Optional, Tuple, etc. are always available.
- Define and export scalar: TypeAlias = Union[int, float, bool] in mlx.core.__suffix__ so that functions typed with Union[scalar, array] resolve correctly instead of falling back to Any.
- Update submodule stub prefixes (distributed, fast, linalg, metal, random) to import scalar alongside array, Device, and Stream, ensuring type checkers resolve the union consistently across modules.

With these changes, functions like mlx.add now display rich type signatures such as:

```
def add(
    a: scalar | array,
    b: scalar | array,
    stream: Stream | Device | None = None
) -> array
```

instead of degrading to Any.
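For illustration, a minimal sketch of the alias described above, written as regular Python rather than the exact generated stub text (typing.TypeAlias requires Python 3.10+; older interpreters would use typing_extensions):

```
from typing import TypeAlias, Union

# Assumed shape of the alias exported from the generated stubs.
scalar: TypeAlias = Union[int, float, bool]
```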

### Checklist

- I have read the CONTRIBUTING document
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
- I have added tests that prove my fix is effective or that my feature works (n/a — stub generation only)
- I have updated the necessary documentation (if needed)

* add bool to patterns

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-03 12:52:08 -07:00
Krishi Saripalli
c5d2937aa5 chore: Update Docs With Slice Copy Example (#2559)
* chore: updated docs with slice copy example

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-09-02 22:07:02 -07:00
Awni Hannun
b61a65e313 fix copies in sdpa (#2563) 2025-09-02 11:00:36 -07:00
wrmsr
04cbb4191c Fix dequantize python sig (#2562) 2025-09-01 11:50:20 -07:00
Artur Antonov
c5460762e7 Fix AdamW weight_decay default value in docstring (#2557) 2025-08-31 21:29:30 -07:00
Awni Hannun
8ce49cd39e fix quantized vjp for mxfp4 (#2555) 2025-08-29 10:06:15 -07:00
Awni Hannun
9c68b50853 version bump (#2554) 2025-08-29 06:54:17 -07:00
Awni Hannun
111f1e71af Faster contiguous gather for indices in the first axis (#2552)
* faster contiguous gather for indices in the first axis

* work per thread > 1

* angelos suggestion for scales / biases
2025-08-28 21:26:30 -07:00
Awni Hannun
827003d568 fix METAL quantization in JIT (#2553) 2025-08-28 18:26:25 -07:00
Awni Hannun
d363a76aa4 Bump xcode in circle (#2551)
* bump xcode in circle

* bump xcode in circle

* bump xcode in circle
2025-08-28 13:13:34 -07:00
Awni Hannun
70560b6bd5 Add mode parameter for quantization (#2499)
* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
2025-08-28 06:45:26 -07:00
Awni Hannun
7ef8a6f2d5 [CUDA] fix sort (#2550)
* [CUDA] fix sort

* fix test
2025-08-27 19:48:43 -07:00
Cheng
31c6f6e33f [CUDA] Use ConcurrentContext in concatenate_gpu (#2549) 2025-08-28 09:30:08 +09:00
Awni Hannun
584d48458e link with nccl (#2546) 2025-08-27 10:01:07 -07:00
Cheng
5cf984ca87 Separate cpu compilation cache by versions (#2548) 2025-08-27 11:25:15 +09:00
Cheng
a9bac3d9e5 Run CPP tests for CUDA build in CI (#2544) 2025-08-27 08:06:46 +09:00
Awni Hannun
5458d43247 add load with path tests (#2543) 2025-08-26 14:24:47 -07:00
Awni Hannun
a4dba65220 Enable cuda graph toggle (#2545)
* enable cuda graph toggle

* increase cache size
2025-08-26 12:50:38 -07:00
Awni Hannun
3dcb286baf Remove stream from average grads so it uses default (#2532)
* Remove stream from average grads so it uses default

* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9 [CUDA] Implement DynamicSlice/DynamicSliceUpdate (#2533)
* Move DynamicSlice to gpu/primitives

* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Awni Hannun
2ca75bb529 Remove nccl install in release (#2542) 2025-08-25 15:20:18 -07:00
Awni Hannun
db14e29a0b allow pathlib.Path to save/load functions (#2541) 2025-08-25 14:58:49 -07:00
Awni Hannun
d2f540f4e0 Use nccl header only when nccl is not present (#2539)
* use nccl header only when nccl is not present

* larger machine for cuda build
2025-08-25 14:17:25 -07:00
Cheng
333ffea273 [CUDA] Remove thrust in arange (#2535) 2025-08-24 16:22:36 +09:00
Cheng
f55b6f1f2f Enable COMPILE_WARNING_AS_ERROR for linux builds in CI (#2534) 2025-08-24 15:33:08 +09:00
Awni Hannun
30561229c7 Fix allocation bug in NCCL (#2530) 2025-08-22 14:39:43 -07:00
Awni Hannun
068a4612e9 nccl default for backend=any (#2528)
* nccl default for backend=any

* check num gpus + ensure row contiguous for all reduce

* comment
2025-08-22 12:24:27 -07:00
Andrey Portnoy
5722c147de [CUDA] Update calls to cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 (#2525)
* [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13

These functions' signatures changed in CUDA 13, so we differentiate
between CUDA 13 and preceding releases at compile time.

* Mention NVIDIA in ACKNOWLEDGMENTS.md
2025-08-21 19:57:20 -07:00
Cheng
f6819a1f26 Fix warning 186-D from nvcc (#2527) 2025-08-22 10:29:55 +09:00
Awni Hannun
f93f87c802 nccl dep + default for cuda (#2526) 2025-08-21 17:57:49 -07:00
Anastasiia Filippova
9392fc3f88 NCCL backend (#2476) 2025-08-21 11:56:15 -07:00
Awni Hannun
e843c4d8d5 fix power (#2523) 2025-08-21 06:46:01 -07:00
Angelos Katharopoulos
0c5fc63a36 Fix docs omission (#2524) 2025-08-20 17:56:06 -07:00
Angelos Katharopoulos
e397177f6e Custom cuda kernel (#2517) 2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe [CUDA] Fix stride of singleton dims before passing to cuDNN (#2521) 2025-08-21 08:55:26 +09:00
163 changed files with 8244 additions and 2370 deletions

View File

@@ -18,13 +18,14 @@ jobs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.9
brew install doxygen
python3.9 -m venv env
@@ -89,7 +90,8 @@ jobs:
command: |
uv venv
uv pip install cmake
uv pip install -e ".[dev]" -v
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
@@ -118,7 +120,7 @@ jobs:
parameters:
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
@@ -126,12 +128,13 @@ jobs:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
@@ -196,7 +199,7 @@ jobs:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e .
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
@@ -222,15 +225,20 @@ jobs:
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
@@ -238,12 +246,23 @@ jobs:
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --max-size 400MB
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
@@ -257,7 +276,7 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
default: "26.0.0"
build_env:
type: string
default: ""
@@ -266,7 +285,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
@@ -274,11 +293,15 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
@@ -288,19 +311,19 @@ jobs:
- run:
name: Install Python package
command: |
source env/bin/activate
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
@@ -310,7 +333,7 @@ jobs:
- run:
name: Build common package
command: |
source env/bin/activate
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
@@ -319,7 +342,7 @@ jobs:
- run:
name: Upload package
command: |
source env/bin/activate
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
@@ -392,7 +415,7 @@ jobs:
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
resource_class: xlarge
steps:
- checkout
- run:
@@ -439,7 +462,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
@@ -464,68 +487,7 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
@@ -567,7 +529,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
@@ -586,53 +548,7 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
@@ -651,68 +567,7 @@ workflows:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:

View File

@@ -19,12 +19,17 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Organizations
MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates
# Third-Party Software
MLX leverages several third-party software, listed here together with

View File

@@ -87,22 +87,21 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(METAL_LIB)
message(STATUS "Metal found ${METAL_LIB}")
else()
message(
FATAL_ERROR
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
endif()
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
@@ -111,7 +110,8 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -140,6 +140,12 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# With newer clang/gcc versions following libs are implicitly linked, but when
# building on old distributions they need to be explicitly listed.
target_link_libraries(mlx PRIVATE dl pthread)
endif()
if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.

cmake/FindNCCL.cmake Normal file
View File

@@ -0,0 +1,54 @@
# FindNCCL.cmake
# This module finds the NVIDIA NCCL library and its include directories.
set(NCCL_ROOT_DIR
$ENV{NCCL_ROOT_DIR}
CACHE PATH "Folder contains NVIDIA NCCL")
find_path(
NCCL_INCLUDE_DIRS
NAMES nccl.h
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
${CUDA_TOOLKIT_ROOT_DIR}/include)
if($ENV{USE_STATIC_NCCL})
message(
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
set(NCCL_LIBNAME "libnccl_static.a")
else()
set(NCCL_LIBNAME "nccl")
endif()
find_library(
NCCL_LIBRARIES
NAMES ${NCCL_LIBNAME}
HINTS ${NCCL_LIB_DIR}
${NCCL_ROOT_DIR}
${NCCL_ROOT_DIR}/lib
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
${NCCL_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
NCCL_LIBRARIES)
if(NCCL_FOUND)
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
message(
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
file(
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
LIMIT_COUNT 1)
if(NCCL_MAJOR_VERSION_DEFINED)
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
endif()
message(
STATUS
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()

View File

@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
source=source,
ensure_row_contiguous=False,
)

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/cuda
python/memory_management
python/nn
python/optimizers

View File

@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
When building either the Python or C++ APIs make sure to pass the cmake flag

docs/src/python/cuda.rst Normal file
View File

@@ -0,0 +1,9 @@
CUDA
=====
.. currentmodule:: mlx.core.cuda
.. autosummary::
:toctree: _autosummary
is_available

View File

@@ -13,3 +13,4 @@ Fast
rope
scaled_dot_product_attention
metal_kernel
cuda_kernel

View File

@@ -27,6 +27,7 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,6 +50,7 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
timeit(gelu, x)
timeit(mx.compile(gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
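For context, a self-contained sketch of the benchmark this hunk refers to; the `gelu` and `timeit` definitions below are assumptions reconstructed from the surrounding compile documentation, not part of the diff:

```
import math
import time

import mlx.core as mx

def gelu(x):
    # erf-based GELU (assumed definition from the compile docs)
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

def timeit(fun, x):
    # Warm up, then report average wall-clock time per evaluated call (assumed helper).
    for _ in range(5):
        mx.eval(fun(x))
    tic = time.perf_counter()
    for _ in range(10):
        mx.eval(fun(x))
    print(f"{1e3 * (time.perf_counter() - tic) / 10:.2f} ms per call")

x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(gelu, x)              # eager
timeit(mx.compile(gelu), x)  # compiled
```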

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added
grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss

View File

@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn")
# Ok
out, = imported_abs(mx.array(-1.0))
out, = imported_abs(mx.array([-1.0]))
# Also ok
out, = imported_abs(mx.array([-1.0, -2.0]))

View File

@@ -107,8 +107,20 @@ same array:
>>> a
array([1, 2, 0], dtype=int32)
Note that unlike NumPy, slicing an array creates a copy, not a view. So
mutating it does not mutate the original array:
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a[:]
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 3], dtype=int32)
Also unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell

View File

@@ -15,6 +15,7 @@
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
#include "mlx/version.h"
namespace mlx::core {
@@ -94,7 +95,11 @@ void* compile(
kernel_file_name = kernel_name;
}
auto output_dir = std::filesystem::temp_directory_path();
auto output_dir =
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
if (!std::filesystem::exists(output_dir)) {
std::filesystem::create_directories(output_dir);
}
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();

View File

@@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -13,6 +11,35 @@ namespace mlx::core {
namespace {
const static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
@@ -407,6 +434,231 @@ void _qmm_dispatch(
}
}
template <typename T>
void mxfp4_qmm(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
for (int ng = 0; ng < packs_in_group; ng++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
}
}
result += N;
}
}
template <typename T>
void mxfp4_qmm_t(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
T gsum = 0;
for (int kw = 0; kw < packs_in_group; kw++) {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
wi >>= 4;
}
}
sum += scale * gsum;
}
*result = sum;
result++;
}
x += K;
}
}
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
if constexpr (S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & 0xf;
simd::Simd<float, S> w_out;
for (int i = 0; i < S; ++i) {
w_out[i] = MXFP4_LUT[wi[i]];
}
return w_out;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T>
void mxfp4_qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (simd::max_size<T> % 8 == 0) {
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
} else {
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
}
} else {
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
}
}
template <typename T>
void mxfp4_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
for (int i = 0; i < batch_size; i++) {
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
bool transposed_w) {
switch (x.dtype()) {
case bfloat16:
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
break;
case float16:
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
break;
case float32:
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@@ -513,115 +765,198 @@ void _bs_qmm_dispatch(
}
}
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<uint8_t>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
mxfp4_qmm_dispatch_transpose<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
M,
N,
K,
transposed_w);
}
}
void mxfp4_bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& lhs_indices,
const array& rhs_indices,
bool transposed_w) {
switch (x.dtype()) {
case float32:
mxfp4_bs_qmm_dispatch_typed<float>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case float16:
mxfp4_bs_qmm_dispatch_typed<float16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
case bfloat16:
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& lhs_indices = inputs[inputs.size() - 2];
auto& rhs_indices = inputs[inputs.size() - 1];
std::vector<array> temps;
auto& encoder = cpu::get_command_encoder(stream());
auto ensure_row_contiguous_last_dims = [s = stream(),
&temps](const array& arr) {
&encoder](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy_cpu(arr, temps.back(), CopyType::General, s);
return temps.back();
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
copy_cpu(arr, arr_cpy, CopyType::General, s);
encoder.add_temporary(arr_cpy);
return arr_cpy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
}
template <typename T, typename U>
@@ -705,7 +1040,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
void fast::Quantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
@@ -764,7 +1099,7 @@ void fast::AffineQuantize::eval_cpu(
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
}
});
}

View File

@@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
static_assert(std::is_same_v<MaskT, bool>);
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
@@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
// Raising an integer to a negative power is undefined
if (any(exp < 0)) {
return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1;
}
return res;

View File

@@ -20,7 +20,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp

View File

@@ -30,8 +30,15 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = 0;
#else
int loc = 0;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) {

View File

@@ -6,23 +6,33 @@
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
namespace mlx::core {
namespace cu {
template <typename T>
struct Arange {
const T start;
const T step;
namespace cg = cooperative_groups;
__device__ T operator()(uint32_t i) const {
return start + i * step;
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_WRITES > size) {
for (IdxT i = index * N_WRITES; i < size; ++i) {
out[i] = start + i * step;
}
} else {
AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
for (int i = 0; i < N_WRITES; ++i) {
out_vec[i] = start + (index * N_WRITES + i) * step;
}
store_vector<N_WRITES>(out, index, out_vec);
}
};
}
} // namespace cu
@@ -36,19 +46,23 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
encoder.set_output_array(out);
auto capture = encoder.capture_context();
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
thrust::transform(
cu::thrust_policy(encoder.stream()),
thrust::counting_iterator<uint32_t>(0),
thrust::counting_iterator<uint32_t>(out.data_size()),
thrust::device_pointer_cast(out.data<OutType>()),
cu::Arange<OutType>{
static_cast<OutType>(start_), static_cast<OutType>(step)});
constexpr int N_WRITES = 16 / sizeof(OutType);
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
encoder.add_kernel_node(
cu::arange<OutType, IdxT, N_WRITES>,
num_blocks,
block_dims,
0,
out.data<OutType>(),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
});
});
}

View File

@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
}
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
return std::make_tuple(
false, std::move(builder.os), std::move(kernel_names));
});
// Collapse contiguous dims to route to a faster kernel if possible. Also

View File

@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128);
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}

View File

@@ -15,8 +15,8 @@ void copy_gpu_inplace(
int64_t offset_out,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_offset_in,
const std::optional<array>& dynamic_offset_out) {
std::optional<array> dynamic_offset_in,
std::optional<array> dynamic_offset_out) {
if (out.size() == 0) {
return;
}
@@ -44,6 +44,16 @@ void copy_gpu_inplace(
strides_vec[0]);
} else {
if (dynamic_offset_in || dynamic_offset_out) {
if (!dynamic_offset_in) {
dynamic_offset_in = array(0, int64);
encoder.add_temporary(*dynamic_offset_in);
}
if (!dynamic_offset_out) {
dynamic_offset_out = array(0, int64);
encoder.add_temporary(*dynamic_offset_out);
}
encoder.set_input_array(*dynamic_offset_in);
encoder.set_input_array(*dynamic_offset_out);
copy_general_dynamic(
encoder,
ctype,
@@ -54,8 +64,8 @@ void copy_gpu_inplace(
shape_collapsed,
strides_vec[0],
strides_vec[1],
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
*dynamic_offset_in,
*dynamic_offset_out);
} else {
copy_general(
encoder,

View File

@@ -23,6 +23,24 @@ inline cudnn_frontend::Tensor build_cudnn_tensor(
.build();
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// strides[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
assert(shape.size() >= 3);
@@ -33,8 +51,9 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides));
}
auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(convert_vector<int64_t>(x.shape()), x.strides());
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
@@ -140,7 +159,7 @@ bool prepare_cudnn_plan(
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, x.strides());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
@@ -160,7 +179,8 @@ cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s = x.strides(0);
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);

View File

@@ -0,0 +1,379 @@
// Copyright © 2025 Apple Inc.
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core::fast {
namespace {
constexpr const char* default_header = R"(
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
std::string template_arguments_hash(
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
if (template_args.empty()) {
return "";
}
std::string hash;
hash.reserve(512);
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
hash += fmt::format("_{}", std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
hash += (std::get<bool>(arg)) ? "_t" : "_f";
} else if (std::holds_alternative<Dtype>(arg)) {
hash += "_";
hash += get_type_string(std::get<Dtype>(arg));
}
}
return hash;
}
std::string build_kernel(
const std::string& func_name,
const std::string& header,
const std::string& source,
const std::vector<std::string>& input_names,
const std::vector<array>& inputs,
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
kernel_source += header;
kernel_source +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
kernel_source += "__global__ void ";
kernel_source += func_name;
kernel_source += "(\n";
// Add inputs
for (int i = 0; i < inputs.size(); ++i) {
const auto& name = input_names[i];
const auto& arr = inputs[i];
kernel_source += " const ";
kernel_source += dtype_to_cuda_type(arr.dtype());
kernel_source += "* ";
kernel_source += name;
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
}
}
}
// Add outputs
for (int i = 0; i < output_names.size(); ++i) {
const auto& name = output_names[i];
const auto& dtype = output_dtypes[i];
kernel_source += " ";
kernel_source += dtype_to_cuda_type(dtype);
kernel_source += "* ";
kernel_source += name;
if (i < output_names.size() - 1) {
kernel_source += ",\n";
} else {
kernel_source += ") {\n";
}
}
// Set compile time constants
if (!template_args.empty()) {
for (const auto& [name, arg] : template_args) {
if (std::holds_alternative<int>(arg)) {
kernel_source +=
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
} else if (std::holds_alternative<bool>(arg)) {
kernel_source += fmt::format(
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
} else {
kernel_source += fmt::format(
" using {} = {};\n",
name,
dtype_to_cuda_type(std::get<Dtype>(arg)));
}
}
kernel_source += "\n";
}
kernel_source += source;
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
return kernel_source;
}
} // namespace
CustomKernelFunction cuda_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::string& source,
const std::string& header,
bool ensure_row_contiguous,
int shared_memory) {
if (output_names.empty()) {
throw std::invalid_argument(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
return [=, shape_infos = std::move(shape_infos)](
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
const std::vector<std::pair<std::string, TemplateArg>>&
template_args = {},
std::optional<float> init_value = std::nullopt,
bool verbose = false,
StreamOrDevice s_ = {}) {
if (inputs.size() != input_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `inputs` to have size "
<< input_names.size() << " but got size " << inputs.size() << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
if (output_shapes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_shapes` to have size "
<< output_names.size() << " but got size " << output_shapes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
if (output_dtypes.size() != output_names.size()) {
std::ostringstream msg;
msg << "[custom_kernel] Expected `output_dtypes` to have size "
<< output_names.size() << " but got size " << output_dtypes.size()
<< "." << std::endl;
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
if (s.device != Device::gpu) {
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
}
std::string kernel_name =
"custom_kernel_" + name + template_arguments_hash(template_args);
std::string kernel_source = build_kernel(
kernel_name,
header,
source,
input_names,
inputs,
output_names,
output_dtypes,
template_args,
shape_infos);
if (verbose) {
std::cout << "Generated source code for `" << kernel_name
<< "`:" << std::endl
<< "```" << std::endl
<< kernel_source << std::endl
<< "```" << std::endl;
}
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
std::move(kernel_name),
std::move(kernel_source),
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
std::move(inputs));
};
}
std::vector<array> precompiled_cuda_kernel(
const std::string& name,
const std::string& compiled_source,
const std::vector<array>& inputs,
const std::vector<Shape>& output_shapes,
const std::vector<Dtype>& output_dtypes,
const std::vector<ScalarArg>& scalars,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
std::make_shared<CustomKernel>(
to_stream(s),
name,
compiled_source,
grid,
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value,
scalars,
true,
shared_memory),
inputs);
}
void CustomKernel::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
std::vector<array> copies;
// Allocate and initialize the output arrays
for (auto& out : outputs) {
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
// Create the input arrays and copy if needed
auto check_input = [&copies, &s, this](const array& x) -> const array {
bool no_copy = x.flags().row_contiguous;
if (!ensure_row_contiguous_ || no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
// Compile the custom kernel
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
name_,
[&]() {
return std::make_tuple(
is_precompiled_, source_, std::vector{kernel_name});
},
false);
// Make the arguments
cu::KernelArgs args;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
args.append<int32_t>(in.ndim());
}
}
for (auto& out : outputs) {
args.append(out);
}
for (auto& s : scalar_arguments_) {
if (std::holds_alternative<bool>(s)) {
args.append(std::get<bool>(s));
} else if (std::holds_alternative<int>(s)) {
args.append(std::get<int>(s));
} else if (std::holds_alternative<float>(s)) {
args.append(std::get<float>(s));
}
}
// Make the grid
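// `grid_` holds the total number of threads per dimension (Metal-style
// dispatch). Clamp the block to the grid and convert to a CUDA block count
// with a ceiling division, e.g. grid (4096, 1, 1) with threadgroup
// (256, 1, 1) launches 16 x 1 x 1 blocks of 256 threads.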
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
for (const auto& t : copies) {
encoder.add_temporary(t);
}
auto kernel =
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
if (smem > 0 && smem > 48000) {
cuFuncSetAttribute(
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
}
});
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
}
} // namespace mlx::core::fast

View File

@@ -27,11 +27,11 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
}
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
}();
return cache_size;
return use_graphs;
}
} // namespace
@@ -86,11 +86,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
enc.device().make_current();
if (!use_cuda_graphs()) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) {
return;
}
graph.end_capture(enc.stream());
if (discard) {
return;
@@ -105,6 +112,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
enc.in_concurrent_ = false;
if (!use_cuda_graphs()) {
return;
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -186,18 +196,25 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
}
void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
void CommandEncoder::set_output_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
active_outputs_.push_back(id);
@@ -215,6 +232,11 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream()));
return;
}
cudaKernelNodeParams kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDim = grid_dim;
@@ -230,6 +252,22 @@ void CommandEncoder::add_kernel_node(
dim3 block_dim,
uint32_t smem_bytes,
void** params) {
if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cuLaunchKernel(
func,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z,
smem_bytes,
stream(),
params,
nullptr));
return;
}
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
kernel_params.func = func;
kernel_params.gridDimX = grid_dim.x;
@@ -256,6 +294,13 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) {
CudaGraphExec graph_exec;
graph_exec.instantiate(child);
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
return;
}
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
@@ -269,7 +314,13 @@ void CommandEncoder::commit() {
if (node_count_ > 0) {
if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
graph_,
from_nodes_.data(),
to_nodes_.data(),
#if CUDART_VERSION >= 13000
nullptr, // edgeData
#endif // CUDART_VERSION >= 13000
from_nodes_.size()));
}
graph_key_ += ".";

View File

@@ -76,9 +76,6 @@ class CommandEncoder {
uint32_t smem_bytes,
void** params);
// Low-level graph helpers.
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
void add_graph_node(cudaGraph_t child);
void add_temporary(const array& arr) {
@@ -101,6 +98,9 @@ class CommandEncoder {
void synchronize();
private:
void add_kernel_node(const cudaKernelNodeParams& params);
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
struct GraphNode {
cudaGraphNode_t node;
// K = kernel

View File

@@ -204,6 +204,12 @@ struct Power {
__device__ T operator()(T base, T exp) {
if constexpr (cuda::std::is_integral_v<T>) {
T res = 1;
// Raising an integer to a negative power is undefined
if constexpr (cuda::std::is_signed_v<T>) {
if (exp < 0) {
return 0;
}
}
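// Returning 0 both matches the comment above and keeps the
// exponentiation-by-squaring loop below well behaved for signed types: a
// negative exp halved with an arithmetic right shift would never reach zero.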
while (exp) {
if (exp & 1) {
res *= base;

View File

@@ -6,7 +6,6 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <thrust/iterator/transform_iterator.h>
namespace mlx::core::cu {
@@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
return CastOp<SrcT, DstT>{}(x);
}
// Return an iterator that casts the value to DstT using CastOp.
template <typename DstT, typename Iterator>
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
if constexpr (std::is_same_v<SrcT, DstT>) {
return it;
} else {
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
}
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,56 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/distributed/primitives.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core::distributed {
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto set_input_output =
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s);
return {out, out};
} else if (in.is_donatable()) {
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(allocator::malloc(out.nbytes()));
return {in, out};
}
};
auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input);
encoder.set_output_array(output);
auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), input, output, s);
break;
case Max:
distributed::detail::all_max(group(), input, output, s);
break;
case Min:
distributed::detail::all_min(group(), input, output, s);
break;
default:
throw std::runtime_error(
"Only all reduce sum, max, and min are supported.");
}
}
} // namespace mlx::core::distributed

View File

@@ -15,8 +15,9 @@ bool is_available() {
}
void new_stream(Stream s) {
// Force initialization of cuda, so the cuda runtime gets destroyed last.
cudaFree(nullptr);
// Force initialization of CUDA by creating an event, so the CUDA runtime and
// our CUDA event pool get destroyed last.
cu::CudaEvent(cudaEventDefault);
// Ensure the static stream objects get created.
cu::get_command_encoder(s);
}

View File

@@ -3,10 +3,12 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h"
#include "mlx/scheduler.h"
#include <map>
#include <vector>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
@@ -17,104 +19,141 @@ namespace cu {
// CudaEvent implementations
///////////////////////////////////////////////////////////////////////////////
// Cuda event managed with RAII.
class CudaEventHandle {
public:
CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
namespace {
// Manage cached cudaEvent_t objects.
struct CudaEventPool {
static CudaEventHandle create(int flags) {
auto& cache = cache_for(flags);
if (cache.empty()) {
return CudaEventHandle(flags);
} else {
CudaEventHandle ret = std::move(cache.back());
cache.pop_back();
return ret;
}
}
~CudaEventHandle() {
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
static void release(CudaEventHandle event) {
cache_for(event.flags).push_back(std::move(event));
}
CudaEventHandle(const CudaEventHandle&) = delete;
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
operator cudaEvent_t() const {
return event_;
static std::vector<CudaEventHandle>& cache_for(int flags) {
static std::map<int, std::vector<CudaEventHandle>> cache;
return cache[flags];
}
private:
cudaEvent_t event_;
};
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
} // namespace
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr);
}
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
CudaEvent::~CudaEvent() {
CudaEventPool::release(std::move(event_));
}
void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait");
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaEventSynchronize(*event_);
cudaEventSynchronize(event_);
}
void CudaEvent::wait(cudaStream_t stream) {
if (!recorded_) {
throw std::runtime_error("Should not wait on a CudaEvent before record.");
}
cudaStreamWaitEvent(stream, *event_);
}
void CudaEvent::wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable { wait(); });
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
wait(enc.stream());
}
cudaStreamWaitEvent(stream, event_);
}
void CudaEvent::record(cudaStream_t stream) {
cudaEventRecord(*event_, stream);
recorded_ = true;
}
void CudaEvent::record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
} else {
auto& enc = cu::get_command_encoder(s);
enc.commit();
record(enc.stream());
}
cudaEventRecord(event_, stream);
}
bool CudaEvent::completed() const {
return cudaEventQuery(*event_) == cudaSuccess;
return cudaEventQuery(event_) == cudaSuccess;
}
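// Illustrative sketch (not part of this diff): the lifecycle of a pooled
// event. The stream arguments are hypothetical; the flags mirror what
// CopyableCudaEvent passes below.
inline void example_event_roundtrip(cudaStream_t producer, cudaStream_t consumer) {
  CudaEvent ev(cudaEventDisableTiming | cudaEventBlockingSync);
  ev.record(producer); // cudaEventRecord on the cached handle
  ev.wait(consumer); // cudaStreamWaitEvent: consumer waits for the producer
  bool done = ev.completed(); // cudaEventQuery(event_) == cudaSuccess
  (void)done;
  // ~CudaEvent returns the handle to the per-flags cache instead of destroying it.
}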
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
CopyableCudaEvent()
: event_(std::make_shared<CudaEvent>(
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
event_->wait();
}
void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
check_recorded();
event_->wait();
});
} else {
check_recorded();
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->wait(encoder.stream());
}
}
void record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
} else {
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->record(encoder.stream());
recorded_ = true;
}
}
bool is_signaled() const {
return recorded_ && event_->completed();
}
private:
void check_recorded() const {
if (!recorded_) {
throw std::runtime_error(
"Should not wait on a CudaEvent before recording.");
}
}
std::shared_ptr<CudaEvent> event_;
bool recorded_{false};
};
///////////////////////////////////////////////////////////////////////////////
// SharedEvent implementations
// AtomicEvent implementations
///////////////////////////////////////////////////////////////////////////////
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
uint64_t current;
while ((current = ac->load()) < value) {
ac->wait(current);
}
}
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
ac->store(value);
ac->notify_all();
}
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
event_wait(ac, value);
}
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value);
}
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() {
AtomicEvent::AtomicEvent() {
buf_ = std::shared_ptr<Buffer>(
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
allocator().free(*ptr);
@@ -123,17 +162,17 @@ SharedEvent::SharedEvent() {
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
}
void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(to_atomic(buf_), value);
void AtomicEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait");
event_wait(atomic(), value);
}
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
}
void SharedEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
void AtomicEvent::wait(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
} else {
@@ -144,17 +183,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
}
}
void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(to_atomic(buf_), value);
void AtomicEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal");
event_signal(atomic(), value);
}
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
}
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
void AtomicEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
// Signal through a GPU stream so the atomic is updated on the GPU; updating
// the atomic from the CPU sometimes does not notify the GPU.
@@ -168,14 +207,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
}
}
bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return to_atomic(buf_)->load() >= value;
bool AtomicEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
return atomic()->load() >= value;
}
uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value");
return to_atomic(buf_)->load();
uint64_t AtomicEvent::value() const {
nvtx3::scoped_range r("cu::AtomicEvent::value");
return atomic()->load();
}
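// Illustrative sketch (not part of this diff): an AtomicEvent synchronizing a
// GPU producer with the CPU; the stream is hypothetical. The Fence
// implementation further below holds a cu::AtomicEvent for the same purpose.
inline void example_atomic_event(cudaStream_t stream) {
  AtomicEvent ev;
  ev.signal(stream, 1); // enqueues event_signal_kernel<<<1, 1, 0, stream>>>
  ev.wait(1); // host-side event_wait: spins on the cuda::atomic until >= 1
}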
} // namespace cu
@@ -188,14 +227,14 @@ namespace {
struct EventImpl {
// CudaEvent is preferred when possible because it is fast; however, we have
// to fall back to SharedEvent in the following cases:
// to fall back to AtomicEvent in the following cases:
// 1. the event is used to wait/signal a cpu stream;
// 2. signal value other than 1 has been specified.
std::unique_ptr<cu::CudaEvent> cuda;
std::unique_ptr<cu::SharedEvent> shared;
std::unique_ptr<cu::CopyableCudaEvent> cuda;
std::unique_ptr<cu::AtomicEvent> atomic;
bool is_created() const {
return cuda || shared;
return cuda || atomic;
}
void ensure_created(Stream s, uint64_t signal_value) {
@@ -203,10 +242,10 @@ struct EventImpl {
return;
}
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
nvtx3::mark("Using slow SharedEvent");
shared = std::make_unique<cu::SharedEvent>();
nvtx3::mark("Using slow AtomicEvent");
atomic = std::make_unique<cu::AtomicEvent>();
} else {
cuda = std::make_unique<cu::CudaEvent>();
cuda = std::make_unique<cu::CopyableCudaEvent>();
}
}
};
@@ -225,7 +264,7 @@ void Event::wait() {
assert(value() == 1);
event->cuda->wait();
} else {
event->shared->wait(value());
event->atomic->wait(value());
}
}
@@ -236,7 +275,7 @@ void Event::wait(Stream s) {
assert(value() == 1);
event->cuda->wait(s);
} else {
event->shared->wait(s, value());
event->atomic->wait(s, value());
}
}
@@ -247,7 +286,7 @@ void Event::signal(Stream s) {
assert(value() == 1);
event->cuda->record(s);
} else {
event->shared->signal(s, value());
event->atomic->signal(s, value());
}
}
@@ -258,9 +297,9 @@ bool Event::is_signaled() const {
}
if (event->cuda) {
assert(value() == 1);
return event->cuda->recorded() && event->cuda->completed();
return event->cuda->is_signaled();
} else {
return event->shared->is_signaled(value());
return event->atomic->is_signaled(value());
}
}

View File

@@ -3,49 +3,54 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h"
#include <memory>
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <memory>
namespace mlx::core::cu {
class CudaEventHandle;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(int flags);
int flags;
};
// Wrapper of a native CUDA event. It can synchronize between GPU streams, or
// let a CPU stream wait on a GPU stream, but it cannot wait on a CPU stream.
class CudaEvent {
public:
CudaEvent();
explicit CudaEvent(int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
CudaEvent& operator=(CudaEvent&&) = default;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
void wait();
void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called.
bool completed() const;
bool recorded() const {
return recorded_;
}
private:
bool recorded_{false};
std::shared_ptr<CudaEventHandle> event_;
CudaEventHandle event_;
};
// Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible.
class SharedEvent {
class AtomicEvent {
public:
using Atomic = cuda::atomic<uint64_t>;
SharedEvent();
AtomicEvent();
void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value);
@@ -57,7 +62,11 @@ class SharedEvent {
uint64_t value() const;
private:
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
Atomic* atomic() const {
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
}
std::shared_ptr<allocator::Buffer> buf_;
};
} // namespace mlx::core::cu

View File

@@ -7,7 +7,7 @@ namespace mlx::core {
struct FenceImpl {
uint32_t count;
cu::SharedEvent event;
cu::AtomicEvent event;
};
Fence::Fence(Stream s) {

View File

@@ -85,10 +85,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
int32_t batch_count,
int64_t batch_stride) {
cublasLtMatrixLayout_t desc;
if (transposed) {
std::swap(rows, cols);
}
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
if (batch_count > 1) {
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
@@ -138,25 +138,34 @@ CublasGemm::CublasGemm(
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));
cublasOperation_t op = CUBLAS_OP_N;
// cublasLt matrices use a column-major layout. While the CUBLASLT_ORDER_ROW
// option can switch them to row-major, the bias epilogue does not work with
// that option, so instead we swap A and B to make cublasLt return the
// row-major result, which works because:
// - the data of a matrix in row-major layout is identical to its transpose in
// column-major layout
// - C^T = (A @ B)^T = B^T @ A^T
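// Concretely, for row-major A (M x K) and B (K x N) we describe B^T (N x K)
// where cublasLt expects A, and A^T (K x M) where it expects B; the
// column-major (N x M) product it returns then has exactly the memory layout
// of the row-major M x N output we want.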
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op,
&a_op,
sizeof(cublasOperation_t)));
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSB,
&op,
&b_op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cublas_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
b_desc_ = create_matrix_layout(
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
out_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
}
CublasGemm::CublasGemm(
@@ -191,7 +200,7 @@ CublasGemm::CublasGemm(
b_batch_stride) {
auto type = dtype_to_cublas_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
}
CublasGemm::~CublasGemm() {
@@ -213,14 +222,30 @@ void CublasGemm::set_out(
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
out_desc_ = create_matrix_layout(
dtype_to_cublas_type(dtype),
rows,
cols,
rows,
transposed,
ld,
batch_count,
batch_stride);
}
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
encoder.set_input_array(bias);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
auto* bias_ptr = bias.data<void>();
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr,
sizeof(bias_ptr)));
}
void CublasGemm::run(
cu::CommandEncoder& encoder,
array& out,
@@ -228,11 +253,19 @@ void CublasGemm::run(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
if (batch_count / batch_shape.back() > 1) {
run_batched(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
encoder,
out,
a,
b,
batch_shape,
a_batch_strides,
b_batch_strides,
alpha);
return;
}
@@ -240,7 +273,13 @@ void CublasGemm::run(
encoder.set_input_array(b);
encoder.set_output_array(out);
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
execute(
encoder,
out.data<void>(),
a.data<void>(),
b.data<void>(),
nullptr,
alpha);
}
void CublasGemm::run(
@@ -330,9 +369,9 @@ void CublasGemm::execute(
handle_,
matmul_desc_,
&alpha,
a,
b, // a and b are swapped
a_desc_,
b,
a,
b_desc_,
&beta,
c ? c : out,

View File

@@ -55,6 +55,8 @@ class CublasGemm {
int32_t batch_count,
int64_t batch_stride);
void set_bias(cu::CommandEncoder& encoder, const array& bias);
void run(
cu::CommandEncoder& encoder,
array& out,
@@ -62,7 +64,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha = 1.0f);
void run(
cu::CommandEncoder& encoder,
@@ -85,7 +88,8 @@ class CublasGemm {
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides);
const Strides& b_batch_strides,
float alpha);
void run_batched(
cu::CommandEncoder& encoder,

View File

@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
a.data<int8_t>() + a.itemsize() * a_it.loc,
b.data<int8_t>() + b.itemsize() * b_it.loc,
nullptr);
nullptr,
alpha);
a_it.step();
b_it.step();
}

View File

@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
const array& b,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides) {
const Strides& b_batch_strides,
float alpha) {
int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count);
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
reinterpret_cast<void*>(out_pointers),
reinterpret_cast<void*>(a_pointers),
reinterpret_cast<void*>(b_pointers),
nullptr);
nullptr,
alpha);
}
void CublasGemm::run_batched(

View File

@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_gather, std::move(kernel_names));
return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
large ? "int64_t" : "int32_t"));
}
}
return std::make_pair(jit_source_scatter, std::move(kernel_names));
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
});
cu::KernelArgs args;
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
args.append(axes_);
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
return std::make_tuple(
false, jit_source_gather_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
}
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
return std::make_tuple(
false, jit_source_scatter_axis, std::move(kernel_names));
});
size_t idx_size_pre = 1;

View File

@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
return path.string();
}
// Finally check the environment variable.
path = std::getenv("MLX_CCCL_DIR");
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
path = env;
if (!path.empty() && std::filesystem::exists(path)) {
return path.string();
}
}
return std::string();
}();
@@ -101,8 +103,8 @@ const std::filesystem::path& ptx_cache_dir() {
bool read_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
@@ -117,15 +119,15 @@ bool read_cached_ptx(
if (!ptx_file.good()) {
return false;
}
ptx->resize(ptx_size);
ptx_file.read(ptx->data(), ptx_size);
ptx.resize(ptx_size);
ptx_file.read(ptx.data(), ptx_size);
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
std::string line;
while (std::getline(txt_file, line)) {
auto tab = line.find('\t');
if (tab != std::string::npos) {
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
}
}
return true;
@@ -135,7 +137,7 @@ bool read_cached_ptx(
void write_cached_ptx(
const std::filesystem::path& cache_dir,
const std::string& module_name,
const std::vector<char>& ptx,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
const std::string& source_code) {
if (cache_dir.empty()) {
@@ -217,85 +219,85 @@ constexpr const char* g_headers[] = {
jit_source_utils,
};
} // namespace
JitModule::JitModule(
void compile(
Device& device,
const std::string& module_name,
const KernelBuilder& builder) {
// Check cache.
std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
// Create program.
auto [source_code, kernel_names] = builder();
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
source_code.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
g_include_names));
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
&prog,
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size, 0);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
const std::string& source,
const std::vector<std::string>& kernel_names,
std::string& ptx,
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
// Create the program
nvrtcProgram prog;
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
&prog,
source.c_str(),
(module_name + ".cu").c_str(),
std::size(g_headers),
g_headers,
g_include_names));
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
&prog,
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
for (const auto& name : kernel_names) {
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
}
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {
cccl_include = fmt::format("--include-path={}", cccl_include);
args.push_back(cccl_include.c_str());
}
std::string cuda_include =
fmt::format("--include-path={}/include", cuda_home());
args.push_back(cuda_include.c_str());
nvrtcResult compile_result =
nvrtcCompileProgram(prog, args.size(), args.data());
if (compile_result != NVRTC_SUCCESS) {
size_t log_size;
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
std::vector<char> log(log_size + 1, 0);
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
throw std::runtime_error(
fmt::format("Failed to compile kernel: {}.", log.data()));
}
// Get mangled names of kernel names.
for (const auto& name : kernel_names) {
const char* mangled;
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
ptx_kernels.emplace_back(name, mangled);
}
// Get ptx data.
size_t ptx_size;
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
}
ptx.resize(ptx_size);
if (use_sass) {
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
} else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
}
}
void load_module(
const std::string& module_name,
const std::string& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
CUmodule& module_,
std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
// Load module.
char jit_log[4089] = {};
CUjit_option options[] = {
@@ -312,21 +314,69 @@ JitModule::JitModule(
for (const auto& [name, mangled] : ptx_kernels) {
CUfunction kernel;
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
kernels_[name] = kernel;
kernels[name] = std::make_pair(kernel, false);
}
}
} // namespace
JitModule::JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder,
bool use_disk_cache) {
// Will hold the actual device executable source code and kernel names
std::string ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels;
// Try to load them from the file cache
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
auto [precompiled, source_code, kernel_names] = builder();
// Get the PTX or cubin
if (precompiled) {
ptx = std::move(source_code);
for (auto& name : kernel_names) {
ptx_kernels.emplace_back(name, name);
}
} else {
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
}
// If requested, save them in the file cache for the next launch
if (use_disk_cache) {
write_cached_ptx(
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
}
}
// Load the module
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
}
JitModule::~JitModule() {
CHECK_CUDA_ERROR(cuModuleUnload(module_));
}
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
CUfunction JitModule::get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel) {
auto it = kernels_.find(kernel_name);
if (it == kernels_.end()) {
throw std::runtime_error(
fmt::format("There is no kernel named {}.", kernel_name));
}
return it->second;
// If it is the first time we run this kernel then configure it. Do it only
// once!
if (!it->second.second) {
if (configure_kernel) {
configure_kernel(it->second.first);
}
it->second.second = true;
}
return it->second.first;
}
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
@@ -337,11 +387,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder) {
const KernelBuilder& builder,
bool cache) {
auto& map = get_jit_module_cache();
auto it = map.find(name);
if (it == map.end()) {
it = map.try_emplace(name, cu::device(device), name, builder).first;
it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
}
return it->second;
}
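// Illustrative sketch (not part of this diff): the precompiled branch of the
// new three-field KernelBuilderResult. The module name, PTX/cubin string, and
// kernel name are hypothetical; use_disk_cache = false skips the on-disk PTX
// cache, as the custom-kernel path does.
inline CUfunction example_load_precompiled(
    const mlx::core::Device& device,
    const std::string& ptx) {
  auto& mod = get_jit_module(
      device,
      "my_precompiled_module",
      [&]() {
        return std::make_tuple(
            /* precompiled */ true,
            /* source */ ptx,
            /* kernel names */ std::vector<std::string>{"my_kernel"});
      },
      /* use_disk_cache */ false);
  return mod.get_kernel("my_kernel");
}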

View File

@@ -19,7 +19,8 @@ namespace mlx::core::cu {
class Device;
using KernelBuilderResult = std::pair<
using KernelBuilderResult = std::tuple<
/* precompiled */ bool,
/* source code */ std::string,
/* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;
@@ -45,6 +46,11 @@ struct KernelArgs {
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
}
template <typename T>
void append(const std::vector<T>& vec) {
append(SmallVector<T>(vec.begin(), vec.end()));
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(SmallVector<T> vec) {
@@ -63,14 +69,16 @@ struct KernelArgs {
private:
std::vector<void*> args_;
// The cuLaunchKernel API requires passing pointers to arguments so store
// temporary values until the kernel is launched.
// The cuGraphAddKernelNode API requires passing pointers to arguments so
// store temporary values until the node is created.
using Arg = std::variant<
std::monostate,
CUdeviceptr,
bool,
int32_t,
uint32_t,
int64_t,
float,
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
@@ -82,16 +90,19 @@ class JitModule {
JitModule(
Device& device,
const std::string& module_name,
const KernelBuilder& builder);
const KernelBuilder& builder,
bool cache);
~JitModule();
JitModule(const JitModule&) = delete;
JitModule& operator=(const JitModule&) = delete;
CUfunction get_kernel(const std::string& kernel_name);
CUfunction get_kernel(
const std::string& kernel_name,
std::function<void(CUfunction)> configure_kernel = nullptr);
private:
CUmodule module_{nullptr};
std::unordered_map<std::string, CUfunction> kernels_;
std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
};
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
@@ -99,6 +110,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
JitModule& get_jit_module(
const mlx::core::Device& device,
const std::string& name,
const KernelBuilder& builder);
const KernelBuilder& builder,
bool use_disk_cache = true);
} // namespace mlx::core::cu

View File

@@ -2,11 +2,15 @@
#pragma once
#include "mlx/utils.h"
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
#include <fmt/format.h>
namespace mlx::core {
template <
@@ -27,6 +31,14 @@ class LRUCache {
}
}
// Initialize with capacity read from |env_name|.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
env_name_ = env_name;
}
}
size_t size() const {
return map_.size();
}
@@ -76,6 +88,14 @@ class LRUCache {
return {it->second, false};
}
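// The 2x-capacity threshold leaves room for the initial fill plus some churn;
// sustained misses beyond that raise a hard error naming the environment
// variable to increase, instead of silently rebuilding entries on every call.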
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
@@ -106,6 +126,9 @@ class LRUCache {
}
}
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_;
map_type map_;
size_t capacity_;

View File

@@ -11,6 +11,7 @@
#include <numeric>
namespace mlx::core {
namespace {
std::tuple<bool, int64_t, array>
@@ -28,6 +29,76 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
}
}
void gemm_and_bias(
cu::CommandEncoder& encoder,
int M,
int N,
int K,
bool a_transposed,
int64_t lda,
bool b_transposed,
int64_t ldb,
array& out,
const array& a,
const array& b,
const std::optional<array>& bias = std::nullopt,
float alpha = 1.0f) {
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
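// e.g. a row-contiguous (B, M, K) A multiplied by a broadcast (K, N) B
// (b_batch_strides.back() == 0) is the same computation as a single
// (B * M, K) @ (K, N) GEMM, so the batch loop is folded into M.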
// Use gemmv when possible
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
// Invoke cublasLt
CublasGemm gemm(
encoder.device(),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
if (bias) {
gemm.set_bias(encoder, *bias);
}
gemm.run(
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
}
} // namespace
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -48,9 +119,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
int K = a_pre.shape(-1);
@@ -60,58 +128,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
auto batch_count = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
b_batch_strides.back() == 0) {
M *= batch_shape.back();
batch_count = 1;
a_batch_strides = {0};
b_batch_strides = {0};
batch_shape = {1};
}
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
cu::gemv(
a,
b,
out,
M,
N,
K,
batch_count,
batch_shape,
a_batch_strides,
b_batch_strides,
encoder);
return;
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
CublasGemm gemm(
cu::device(s.device),
a.dtype(),
a_transposed,
M,
K,
lda,
b_transposed,
K,
N,
ldb,
batch_shape.back(),
a_batch_strides.back(),
b_batch_strides.back());
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
gemm_and_bias(
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -136,6 +154,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
/////////////////////////////////////////////////////////////////////////////
// Dispatch to GEMM with epilogue or AddMM
if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
out.set_data(allocator::malloc(out.nbytes()));
gemm_and_bias(
encoder,
M,
N,
K,
a_transposed,
lda,
b_transposed,
ldb,
out,
a,
b,
c,
alpha_);
return;
}
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
@@ -177,7 +217,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
/////////////////////////////////////////////////////////////////////////////
// Invoke cublasLt
// Invoke cublasLt with AddMM settings
CublasGemm gemm(
cu::device(s.device),

View File

@@ -1,11 +1,47 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cuda.h"
#include "mlx/fast.h"
namespace mlx::core::cu {
namespace mlx::core {
namespace cu {
bool is_available() {
return false;
}
} // namespace mlx::core::cu
} // namespace cu
namespace fast {
CustomKernelFunction cuda_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool,
int) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
std::vector<array> precompiled_cuda_kernel(
const std::string&,
const std::string&,
const std::vector<array>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
const std::vector<ScalarArg>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,
int shared_memory,
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice) {
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
}
} // namespace fast
} // namespace mlx::core

View File

@@ -24,8 +24,6 @@ namespace mlx::core {
}
NO_GPU(BlockMaskedMM)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
@@ -41,12 +39,7 @@ NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace fast {
NO_GPU_MULTI(CustomKernel)
} // namespace fast
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)

View File

@@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
} // namespace
void fast::AffineQuantize::eval_gpu(
void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
nvtx3::scoped_range r("Quantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);

View File

@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
int offset,
const int* offset,
float inv_freq,
float scale,
const cuda::std::array<int64_t, 3> strides,
const cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 pos,
uint3 dims) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
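// offset may hold one position per batch: offset_stride is 0 when a scalar
// offset is given (every batch reads offset[0]) and the batch stride of the
// offsets array otherwise.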
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
@@ -123,20 +129,19 @@ __device__ void rope_impl(
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
mat_idx * out_strides[0];
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -167,7 +172,8 @@ __global__ void rope(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims) {
uint3 pos = make_uint3(
blockIdx.x * blockDim.x + threadIdx.x,
@@ -182,12 +188,13 @@ __global__ void rope(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
float base,
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
int64_t n_batch,
int64_t offset_stride,
int n_head,
uint3 dims,
int64_t freq_stride) {
uint3 pos = make_uint3(
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
rope_impl<T, traditional, forward>(
in,
out,
*offset,
offset,
inv_freq,
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
dims);
}
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
auto& offset = inputs[1];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
// We apply rope to less than the whole vector so copy to the output and then
// apply in-place.
if (dims_ < in.shape(-1)) {
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
out_strides[2] = out.strides()[ndim - 1];
// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
bool single = in.flags().row_contiguous && B == 1 && T == 1;
bool with_freqs = inputs.size() == 3;
auto& encoder = cu::get_command_encoder(s);
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
std::log2(base_),
strides,
out_strides,
in.size() / mat_size,
offset_stride,
N,
dims);
}
});

View File

@@ -4,7 +4,6 @@
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
@@ -46,6 +45,7 @@ __global__ void kernel_sdpav_1pass(
const T* K,
const T* V,
T* O,
const T* sinks,
__grid_constant__ const AttnParams params) {
constexpr int BN = 32;
constexpr int BD = 32;
@@ -65,7 +65,7 @@ __global__ void kernel_sdpav_1pass(
__shared__ U max_scores[BN];
__shared__ U sum_exp_scores[BN];
const U scale_log2 = params.scale * 1.44269504089f;
const U scale_log2 = params.scale * M_LOG2E;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
@@ -108,8 +108,12 @@ __global__ void kernel_sdpav_1pass(
o[i] = 0.f;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
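// Seeding the running max with M_LOG2E * sink and the running sum with 1 is
// equivalent to attending over one extra virtual key whose logit is the
// per-head sink value; the softmax below works in the base-2 exponent domain
// (scale_log2, exp2f), hence the M_LOG2E factor.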
// For each key
for (int i = kv_seq_idx; i < params.kL; i += BN) {
@@ -167,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
U factor = exp2f(max_score - new_max);
sum_exp_score =
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
// Now we need to aggregate all the outputs
PRAGMA_LOOP_UNROLL
@@ -193,6 +197,7 @@ __global__ void kernel_sdpav_2pass_1(
const T* Q,
const T* K,
const T* V,
const T* sinks,
float* partials,
float* sums,
float* maxs,
@@ -268,8 +273,12 @@ __global__ void kernel_sdpav_2pass_1(
o[i] = 0.f;
}
U max_score = -1e9;
U max_score = Limits<U>::finite_min();
U sum_exp_score = 0.f;
if (sinks && warp_idx == 0 && block_idx == 0) {
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
sum_exp_score = 1.f;
}
// For each key
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
@@ -410,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
U factor = exp2f(max_score - new_max);
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
sum_exp_score = __frcp_rn(sum_exp_score);
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
@@ -463,10 +472,14 @@ void sdpa_vector_1pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(o);
cu::AttnParams params{
@@ -489,7 +502,7 @@ void sdpa_vector_1pass_fallback(
dim3 block_dim(1024, 1, 1);
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -504,6 +517,7 @@ void sdpa_vector_1pass_fallback(
k.data<DataType>(),
v.data<DataType>(),
o.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
params);
});
});
@@ -518,7 +532,8 @@ void sdpa_vector_2pass_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
cu::AttnParams params{
/* int B = */ q.shape(0),
/* int H = */ q.shape(1),
@@ -559,7 +574,7 @@ void sdpa_vector_2pass_fallback(
encoder.add_temporary(maxs);
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
dispatch_bool(do_causal_, [&](auto do_causal) {
dispatch_bool(do_causal, [&](auto do_causal) {
dispatch_headdim(params.D, [&](auto headdim) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -570,6 +585,10 @@ void sdpa_vector_2pass_fallback(
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
if (sinks) {
encoder.set_input_array(*sinks);
}
encoder.set_output_array(intermediate);
encoder.set_output_array(sums);
encoder.set_output_array(maxs);
@@ -585,6 +604,7 @@ void sdpa_vector_2pass_fallback(
q.data<DataType>(),
k.data<DataType>(),
v.data<DataType>(),
sinks ? (*sinks).data<DataType>() : nullptr,
intermediate.data<float>(),
sums.data<float>(),
maxs.data<float>(),
@@ -627,15 +647,16 @@ void sdpa_vector_fallback(
const array& v,
const float scale,
array& o,
bool do_causal_ = false) {
bool do_causal,
const std::optional<array>& sinks) {
int kL = k.shape(2);
if (kL > 1024) {
return sdpa_vector_2pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
} else {
return sdpa_vector_1pass_fallback(
s, encoder, q, k, v, scale, o, do_causal_);
s, encoder, q, k, v, scale, o, do_causal, sinks);
}
}
@@ -691,7 +712,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -703,6 +724,16 @@ void ScaledDotProductAttention::eval_gpu(
}
};
// Checks that the headdim dimension has stride 1.
auto is_matrix_contiguous = [](const array& arr) {
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
// We are in vector mode, i.e., single query

if (q_pre.shape(2) < 4) {
auto q_copy_unless = [](const array& arr) {
@@ -740,10 +771,6 @@ void ScaledDotProductAttention::eval_gpu(
const auto& k = copy_unless(kv_copy_unless, k_pre);
const auto& v = copy_unless(kv_copy_unless, v_pre);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
// Donate the query if possible
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
o.copy_shared_buffer(q);
@@ -752,22 +779,26 @@ void ScaledDotProductAttention::eval_gpu(
int64_t str_oH = o.shape(3);
int64_t str_oL = o.shape(1) * str_oH;
int64_t str_oB = o.shape(2) * str_oL;
size_t data_size = o.shape(0) * str_oB;
array::Flags flags{
/* bool contiguous = */ 1,
/* bool row_contiguous = */ o.shape(2) == 1,
/* bool col_contiguous = */ 0,
/* bool col_contiguous = */ o.size() == o.shape(3),
};
o.set_data(
allocator::malloc(o.nbytes()),
data_size,
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);
}
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
for (const auto& cp : copies) {
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
}
// Full attention mode should never reach here
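The sink handling above folds a per-head logit into the online softmax: the running maximum starts at log2(e) * sink and the running sum at 1, so the sink contributes to the denominator without adding a value row, and the reciprocal is guarded when the sum is zero. A minimal CPU sketch of that recurrence, with illustrative names only (not the MLX API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>

std::vector<float> attend(
    const std::vector<float>& q,               // [D]
    const std::vector<std::vector<float>>& K,  // [L][D]
    const std::vector<std::vector<float>>& V,  // [L][D]
    float scale,
    const float* sink /* optional per-head logit, may be null */) {
  const float log2e = 1.44269504089f;
  // Online-softmax state, kept in base-2 log space like the kernels.
  float max_score =
      sink ? log2e * (*sink) : std::numeric_limits<float>::lowest();
  float sum_exp = sink ? 1.f : 0.f;
  std::vector<float> o(q.size(), 0.f);
  for (size_t i = 0; i < K.size(); ++i) {
    float s = 0.f;
    for (size_t d = 0; d < q.size(); ++d) s += q[d] * K[i][d];
    s *= scale * log2e;  // score in base-2 units (scale * log2(e))
    float new_max = std::max(max_score, s);
    float factor = std::exp2(max_score - new_max);
    float w = std::exp2(s - new_max);
    for (size_t d = 0; d < o.size(); ++d) o[d] = o[d] * factor + w * V[i][d];
    sum_exp = sum_exp * factor + w;
    max_score = new_max;
  }
  // Guard the reciprocal so fully-masked rows produce 0 instead of NaN.
  float inv = sum_exp == 0.f ? 0.f : 1.f / sum_exp;
  for (auto& x : o) x *= inv;
  return o;
}

int main() {
  std::vector<float> q = {1.f, 0.f};
  std::vector<std::vector<float>> K = {{1.f, 0.f}, {0.f, 1.f}};
  std::vector<std::vector<float>> V = {{1.f, 2.f}, {3.f, 4.f}};
  float sink = 0.f;  // hypothetical sink logit for this head
  auto o = attend(q, K, V, 1.f, &sink);
  std::printf("%f %f\n", o[0], o[1]);
}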

View File

@@ -1,8 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/dtype_utils.h"
#include <numeric>
@@ -27,8 +30,7 @@ void concatenate_gpu(
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
// TODO: Handle concurrent outputs:
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
auto concurrent = cu::get_command_encoder(s).concurrent_context();
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis] * sizes[i];
@@ -38,4 +40,71 @@ void concatenate_gpu(
}
}
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s) {
Dtype dtype = indices.dtype();
int nidx = axes.size();
std::string module_name =
fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
std::string kernel_name = fmt::format(
"mlx::core::cu::compute_dynamic_offset<{}, {}>",
dtype_to_cuda_type(dtype),
nidx);
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
std::string source = R"(
#include "mlx/backend/cuda/device/utils.cuh"
namespace mlx::core::cu {
template <typename T, int NIDX>
__global__ void compute_dynamic_offset(
const T* indices,
int64_t* offset,
const __grid_constant__ Strides strides,
const __grid_constant__ cuda::std::array<int, NIDX> axes) {
int64_t acc = 0;
#pragma unroll
for (int i = 0; i < NIDX; ++i) {
acc += indices[i] * strides[axes[i]];
}
*offset = acc;
}
} // namespace mlx::core::cu
)";
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
});
// Prepare output.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(offset);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
cu::KernelArgs args;
args.append(indices);
args.append(offset);
args.append_ndim(strides);
args.append(axes);
auto kernel = mod.get_kernel(kernel_name);
encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
return offset;
}
} // namespace mlx::core
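The JIT kernel above reduces a handful of dynamic start indices into a single linear element offset. A host-side sketch of the same arithmetic (illustrative names, not the MLX API):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// offset = sum_i indices[i] * strides[axes[i]], exactly what the kernel sums.
int64_t dynamic_offset(
    const std::vector<int32_t>& indices,  // one start index per sliced axis
    const std::vector<int64_t>& strides,  // strides of the array being sliced
    const std::vector<int>& axes) {       // which axis each index applies to
  int64_t acc = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    acc += static_cast<int64_t>(indices[i]) * strides[axes[i]];
  }
  return acc;
}

int main() {
  // Strides {12, 4, 1}; start indices {1, 3} on axes {0, 2} land on
  // element offset 1 * 12 + 3 * 1 = 15.
  std::printf(
      "%lld\n", (long long)dynamic_offset({1, 3}, {12, 4, 1}, {0, 2}));
}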

View File

@@ -9,7 +9,7 @@
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>
#include <cub/device/device_segmented_sort.cuh>
#include <cub/device/device_segmented_radix_sort.cuh>
#include <cassert>
@@ -79,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.add_temporary(discard);
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
nullptr,
size,
in.data<Type>(),
@@ -90,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -104,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
thrust::device_pointer_cast(indices.data<uint32_t>()),
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
temp.data<void>(),
size,
in.data<Type>(),
@@ -115,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
} else {
size_t size;
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
size,
in.data<Type>(),
@@ -127,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
@@ -134,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
// Start capturing after allocations
auto capture = encoder.capture_context();
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
temp.data<void>(),
size,
in.data<Type>(),
@@ -143,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data_size() / nsort,
offsets,
offsets + 1,
0,
sizeof(Type) * 8,
stream));
}
} else {
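The two extra arguments added to each call are the radix sort's bit range: keys are sorted on bits [begin_bit, end_bit), so passing 0 and sizeof(Type) * 8 covers the whole key. A small CPU sketch of a bit-ranged LSD radix sort, assuming the range is a multiple of 8 for brevity (the real device-side work is done by CUB):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void radix_sort_bits(std::vector<uint32_t>& keys, int begin_bit, int end_bit) {
  std::vector<uint32_t> buf(keys.size());
  for (int bit = begin_bit; bit < end_bit; bit += 8) {
    // Stable counting sort on one 8-bit digit (one radix pass).
    size_t count[257] = {0};
    for (uint32_t k : keys) count[((k >> bit) & 0xFFu) + 1]++;
    for (int d = 0; d < 256; ++d) count[d + 1] += count[d];
    for (uint32_t k : keys) buf[count[(k >> bit) & 0xFFu]++] = k;
    keys.swap(buf);
  }
}

int main() {
  std::vector<uint32_t> keys = {42, 7, 1000000, 3, 7};
  radix_sort_bits(keys, 0, 32);  // full 32-bit key, like sizeof(Type) * 8
  for (uint32_t k : keys) std::printf("%u ", k);
  std::printf("\n");
}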

View File

@@ -7,6 +7,7 @@ namespace mlx::core::cu {
Worker::Worker()
: signal_stream_(device(mlx::core::Device::gpu)),
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {}
Worker::~Worker() {

View File

@@ -3,7 +3,6 @@
#pragma once
#include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include <condition_variable>
#include <functional>

View File

@@ -20,8 +20,8 @@ void copy_gpu_inplace(
int64_t o_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
std::optional<array> dynamic_i_offset = std::nullopt,
std::optional<array> dynamic_o_offset = std::nullopt);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
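Dropping the const reference and taking std::optional by value lets callers hand over a temporary offset array with std::move, as the DynamicSlice paths below do. A tiny illustration of the idiom using standard types only:

#include <optional>
#include <string>
#include <utility>

void consume(std::optional<std::string> maybe) {
  // `maybe` owns its payload; no reference to a caller temporary is kept.
  if (maybe) { /* use *maybe */ }
}

int main() {
  std::optional<std::string> tmp = std::string(1024, 'x');
  consume(std::move(tmp));  // moves the payload instead of copying it
  consume(std::nullopt);    // the default "no offset" case
}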

View File

@@ -80,6 +80,74 @@ void Depends::eval_gpu(
eval(inputs, outputs);
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* std::optional<array> dynamic_i_offset = */ std::move(in_offset),
/* std::optional<array> dynamic_o_offset = */ std::nullopt);
}
void DynamicSliceUpdate::eval_gpu(
const std::vector<array>& inputs,
array& out) {
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
auto& start_indices = inputs[2];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Copy or donate input to output
auto s = stream();
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
auto out_offset =
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* std::optional<array> dynamic_i_offset = */ std::nullopt,
/* std::optional<array> dynamic_o_offset = */ std::move(out_offset));
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);

View File

@@ -27,4 +27,10 @@ void pad_gpu(
const Shape& low_pad_size,
const Stream& s);
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s);
} // namespace mlx::core

View File

@@ -33,10 +33,11 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis)
make_jit_source(indexing/scatter_axis)
make_jit_source(hadamard)
if(MLX_METAL_JIT)
@@ -77,7 +78,10 @@ if(MLX_METAL_JIT)
make_jit_source(steel/conv/kernels/steel_conv)
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h)
make_jit_source(quantized)
make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h)
make_jit_source(gemv_masked)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
std::string kname;
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn,
"_channel_",
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
"_filter_",
small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
kname,
out,
bm,
bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_weight_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_input_transform_",
type_to_name(out),
"_bc",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
std::string kname;
kname.reserve(32);
concatenate(
kname,
"winograd_conv_2d_output_transform_",
type_to_name(out),
"_bo",
bc);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
std::string base_name;
base_name.reserve(32);
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
hash_name.reserve(64);
concatenate(
hash_name,
base_name,
"_ker_h_", ker_h,
"_ker_w_", ker_w,
"_str_h_", str_h,
"_str_w_", str_w,
"_tgp_h_", th,
"_tgp_w_", tw,
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
}
}
void depthwise_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
array wt,
array out) {
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
std::string base_name;
base_name.reserve(32);
concatenate(
base_name,
"depthwise_conv_1d_",
large ? "_large" : "",
type_to_name(out));
if (!wt.flags().row_contiguous) {
wt = contiguous_copy_gpu(wt, s);
d.add_temporary(wt, s.index);
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name);
compute_encoder.set_compute_pipeline_state(kernel);
auto B = in.shape(0);
auto Tout = out.shape(1);
auto D = in.shape(2);
auto K = wt.shape(1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
if (large) {
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
compute_encoder.set_bytes(strides, 3, 3);
} else {
int strides[3] = {
static_cast<int>(in.strides(0)),
static_cast<int>(in.strides(1)),
static_cast<int>(in.strides(2))};
compute_encoder.set_bytes(strides, 3, 3);
}
compute_encoder.set_bytes(K, 4);
auto group_dims = get_block_dims(D, Tout, B);
MTL::Size grid_dims = MTL::Size(D, Tout, B);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
bool is_idil_one = in_dilation[0] == 1;
int C = in.shape(2);
int O = wt.shape(0);
const int C_per_group = in.shape(2) / groups;
const int O_per_group = wt.shape(0) / groups;
// Fast path for fully separable 1D convolution
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
depthwise_conv_1D_gpu(s, d, in, wt, out);
return;
}
const int C_per_group = C / groups;
const int O_per_group = O / groups;
// Direct to implicit gemm conv
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&

View File

@@ -20,8 +20,8 @@ void copy_gpu_inplace(
int64_t out_offset,
CopyType ctype,
const Stream& s,
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
std::optional<array> dynamic_i_offset /* = std::nullopt */,
std::optional<array> dynamic_o_offset /* = std::nullopt */) {
if (out.size() == 0) {
return;
}

View File

@@ -172,7 +172,7 @@ std::string write_template(
return template_def.str();
}
MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
const std::string& name,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
threadgroup,
shape_infos,
ensure_row_contiguous,
init_value),
init_value,
std::vector<ScalarArg>{},
false,
0),
std::move(inputs));
};
}

View File

@@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4];
NS::Error* error[5];
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
@@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
return lib;
}
// Try to load resources from the Framework resources if SwiftPM wrapped MLX
// as a dynamic framework.
std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) {
for (int i = 0; i < 5; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}

View File

@@ -52,8 +52,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX;
@@ -61,6 +63,55 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
auto& indices = inputs[1];
std::string kernel_name = fmt::format(
"gather_front{0}_{1}_{2}_{3}",
type_to_name(out),
idx_type_name,
large ? "int64_t" : "int",
work_per_thread);
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather_front();
kernel_source += get_template_definition(
kernel_name,
"gather_front",
get_type_string(out.dtype()),
get_type_string(indices.dtype()),
large ? "int64_t" : "int",
work_per_thread);
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
size_t dim_y = indices.size();
auto group_dims = get_block_dims(dim_x, dim_y, 1);
MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
compute_encoder.set_input_array(src, 0);
compute_encoder.set_input_array(indices, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(slice_size, 3);
compute_encoder.set_bytes(src.shape(0), 4);
compute_encoder.dispatch_threads(grid_dims, group_dims);
return;
}
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
@@ -96,11 +147,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder.set_compute_pipeline_state(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
// Launch 3D grid of threads
// First two dimensions for the indices, the last one for the slice
size_t dim0 = 1;

View File

@@ -19,9 +19,12 @@ const char* binary_two();
const char* copy();
const char* fft();
const char* gather_axis();
const char* gather_front();
const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp4_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();

View File

@@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def) {
const std::string& template_def,
const std::string& mode) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
return kernel_source.str();
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
template_def);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const array& x,
int group_size,
int bits,
const std::string& mode,
int bm,
int bn,
int bk,
@@ -833,22 +840,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
concatenate(
kernel_source,
metal::utils(),
metal::gemm(),
metal::quantized(),
get_template_definition(
lib_name,
"gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") {
concatenate(
kernel_source,
metal::quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);

View File

@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def);
const std::string& template_def,
const std::string& mode);
MTL::ComputePipelineState* get_gather_qmm_kernel(
metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const array& x,
int group_size,
int bits,
const std::string& mode,
int bm,
int bn,
int bk,

View File

@@ -108,7 +108,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

View File

@@ -223,6 +223,11 @@ struct Power {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
// Undefined to raise integer to negative power
if (exp < 0) {
return 0;
}
while (exp) {
if (exp & 1) {
res *= base;
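The added guard returns 0 for negative exponents, since an integer power with a negative exponent has no integer result. A scalar sketch of the full exponentiation-by-squaring loop with that guard:

#include <cstdint>
#include <cstdio>

int64_t ipow(int64_t base, int64_t exp) {
  if (exp < 0) return 0;  // integer result is not representable
  int64_t res = 1;
  while (exp) {
    if (exp & 1) res *= base;
    exp >>= 1;
    if (exp) base *= base;
  }
  return res;
}

int main() {
  // 3^5 = 243; a negative exponent falls back to 0.
  std::printf("%lld %lld\n", (long long)ipow(3, 5), (long long)ipow(2, -1));
}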

View File

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);
template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
const device T* in [[buffer(0)]],
const device T* w [[buffer(1)]],
device T* out [[buffer(2)]],
constant const IdxT strides[3],
constant const int& kernel_size,
uint3 tid [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
w += tid.x * kernel_size;
float acc = 0.0;
for (int i = 0; i < kernel_size; ++i) {
acc += static_cast<float>(in[0]) * w[i];
in += strides[1];
}
*out = static_cast<T>(acc);
}
#define instantiate_depthconv1d(iname, itype) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
instantiate_kernel( \
"depthwise_conv_1d_" #iname "_large", \
depthwise_conv_1d, \
itype, \
int64_t)
instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
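A CPU reference for the depthwise 1-D kernel above: each channel d has its own length-K filter applied along time with stride 1, no padding, no dilation, and no flip, which matches the fast-path condition checked in conv_1D_gpu. A sketch assuming a "valid" output length of T - K + 1 (illustrative names only):

#include <cstddef>
#include <vector>

// in:  [B, T, D] row-major, w: [D, K], returns out: [B, T - K + 1, D]
std::vector<float> depthwise_conv_1d(
    const std::vector<float>& in,
    const std::vector<float>& w,
    int B, int T, int D, int K) {
  int Tout = T - K + 1;
  std::vector<float> out(size_t(B) * Tout * D, 0.f);
  for (int b = 0; b < B; ++b)
    for (int t = 0; t < Tout; ++t)
      for (int d = 0; d < D; ++d) {
        float acc = 0.f;
        // Each channel convolves with its own filter row w[d, :].
        for (int i = 0; i < K; ++i)
          acc += in[(size_t(b) * T + t + i) * D + d] * w[size_t(d) * K + i];
        out[(size_t(b) * Tout + t) * D + d] = acc;
      }
  return out;
}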

File diff suppressed because it is too large.

View File

@@ -0,0 +1,127 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
batched)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
uint8_t, \
aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
aligned, \
batched)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
uint8_t, \
D, \
batched)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
uint8_t, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -35,8 +35,8 @@ struct GEMVKernel {
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
SN == 4 || SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 4, 8, 16, or 32");
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
@@ -511,17 +511,21 @@ template <
axpby)
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
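These blocks describe how a 32-thread simdgroup is arranged as SM x SN threads, each owning TM rows and TN columns of its tile; the new SN = 4 instantiations presumably give narrower reduction widths (and more rows) for small K. A CPU sketch of that tiling for a single simdgroup, illustrative rather than the kernel's exact layout:

#include <cstddef>
#include <vector>

// y[M] = A[M x K] * x[K], computed with an SM x SN thread / TM x TN element tiling.
void gemv_tiled(const std::vector<float>& A, const std::vector<float>& x,
                std::vector<float>& y, int M, int K,
                int SM, int SN, int TM, int TN) {
  for (int row_block = 0; row_block < M; row_block += SM * TM) {
    // Per-thread partial results for this block: [SM][SN][TM].
    std::vector<float> partial(size_t(SM) * SN * TM, 0.f);
    for (int sm = 0; sm < SM; ++sm)
      for (int sn = 0; sn < SN; ++sn)
        for (int k = sn * TN; k < K; k += SN * TN)  // this thread's columns
          for (int tm = 0; tm < TM; ++tm) {
            int m = row_block + sm * TM + tm;
            if (m >= M) continue;
            float acc = 0.f;
            for (int tn = 0; tn < TN && k + tn < K; ++tn)
              acc += A[size_t(m) * K + k + tn] * x[k + tn];
            partial[(size_t(sm) * SN + sn) * TM + tm] += acc;
          }
    // "simd_sum": reduce across the SN threads that share a row.
    for (int sm = 0; sm < SM; ++sm)
      for (int tm = 0; tm < TM; ++tm) {
        int m = row_block + sm * TM + tm;
        if (m >= M) continue;
        float sum = 0.f;
        for (int sn = 0; sn < SN; ++sn)
          sum += partial[(size_t(sm) * SN + sn) * TM + tm];
        y[m] = sum;
      }
  }
}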

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/indexing/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(

View File

@@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing/indexing.h"
template <typename T, typename IdxT, typename LocT, int N>
[[kernel]] void gather_front(
const device T* src,
const device IdxT* indices,
device T* out,
const constant int64_t& stride,
const constant int& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto idx = offset_neg_idx(indices[index.y], size);
LocT src_idx = static_cast<LocT>(stride) * idx;
LocT out_idx = static_cast<LocT>(stride) * index.y;
int s_idx = N * index.x;
for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
out[out_idx + s_idx] = src[src_idx + s_idx];
}
}
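A CPU reference for the gather_front kernel above: each index selects a whole leading-axis row of a row-contiguous source, and the row of `stride` elements is copied to the output; negative indices wrap by the axis size, as offset_neg_idx does. Illustrative only:

#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T, typename IdxT>
void gather_front(
    const std::vector<T>& src,        // [size, stride] row-major
    const std::vector<IdxT>& indices, // row indices into axis 0
    std::vector<T>& out,              // [indices.size(), stride]
    int64_t stride,
    int size) {
  for (size_t j = 0; j < indices.size(); ++j) {
    int64_t idx = indices[j] < 0 ? int64_t(indices[j]) + size : indices[j];
    for (int64_t s = 0; s < stride; ++s) {
      out[j * stride + s] = src[idx * stride + s];
    }
  }
}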

View File

@@ -2,7 +2,7 @@
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/indexing/indexing.h"
template <
typename T,

View File

@@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets(
}
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
[[kernel]] void affine_qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1486,7 +1486,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
[[kernel]] void affine_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1538,7 +1538,7 @@ template <typename T, int group_size, int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
[[kernel]] void affine_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1590,7 +1590,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
[[kernel]] void affine_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1642,7 +1642,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
[[kernel]] void affine_qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1706,7 +1706,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_t(
[[kernel]] void affine_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1764,7 +1764,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_n(
[[kernel]] void affine_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1817,7 +1817,7 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
[[kernel]] void affine_gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1879,7 +1879,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
[[kernel]] void affine_gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -1941,7 +1941,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
[[kernel]] void affine_gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2010,7 +2010,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_t(
[[kernel]] void affine_gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2077,7 +2077,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_n(
[[kernel]] void affine_gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@@ -2138,92 +2138,6 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
template <
typename T,
int group_size,
@@ -2234,7 +2148,7 @@ template <
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
[[kernel]] void affine_gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],

View File

@@ -3,6 +3,7 @@
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
@@ -79,40 +80,40 @@
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(affine_gather_qmv, type, group_size, bits) \
instantiate_quantized(affine_gather_qvm, type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \

View File

@@ -0,0 +1,90 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <
bool rows_aligned,
bool cols_aligned,
bool transpose,
typename T,
typename mma_t,
typename loader_a_t,
typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const int k_iterations,
const short tgp_bm,
const short tgp_bn,
const short tgp_bk) {
for (int k = 0; k < k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup memory
if (rows_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(short2(tgp_bk, tgp_bm));
}
if (cols_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
threadgroup T* As,
threadgroup T* Bs,
thread mma_t& mma_op,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
const short2 tile_a,
const short2 tile_b) {
loader_a.load_safe(tile_a);
loader_b.load_safe(tile_b);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}

View File

@@ -3,14 +3,19 @@
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
constant bool forward [[function_constant(1)]];
constant bool traditional [[function_constant(2)]];
constant bool hs_transpose [[function_constant(3)]];
template <typename T>
void rope_single_impl(
const device T* in,
device T* out,
constant const int& offset,
const float inv_freq,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
uint2 pos,
uint2 grid) {
float L = scale * static_cast<float>(offset);
@@ -46,76 +51,85 @@ void rope_single_impl(
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward>
template <typename T>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
constant const float& base [[buffer(10)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
template <typename T>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& scale,
constant const size_t& stride,
constant const int64_t& stride,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
void rope_impl(
const device T* in,
device T* out,
constant const int& offset,
const device int* offset,
const float inv_freq,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
uint3 pos,
uint3 grid) {
float L = scale * static_cast<float>(pos.y + offset);
auto n_head_up = N * ((n_head + N - 1) / N);
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
auto batch_idx = (pos.z * N) / n_head_up;
auto batch_offset = offset[batch_idx * offset_stride];
float L = scale * static_cast<float>(pos.y + batch_offset);
auto mat_idx = batch_idx * n_head + head_idx;
// Compute costheta, sintheta
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
IdxT in_index_1;
if (hs_transpose) {
IdxT batch_stride = grid.y * IdxT(strides[1]);
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
IdxT in_index_2;
IdxT out_index_1 =
pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
IdxT out_index_2;
if (traditional) {
out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
out_index_2 = out_index_1 + 1;
in_index_1 += 2 * pos.x * IdxT(strides[2]);
in_index_2 = in_index_1 + IdxT(strides[2]);
} else {
out_index_1 += pos.x * IdxT(out_strides[2]);
out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
in_index_1 += pos.x * IdxT(strides[2]);
in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
}
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -130,28 +144,29 @@ void rope_impl(
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
in_index_1 += IdxT(strides[0]);
in_index_2 += IdxT(strides[0]);
out_index_1 += IdxT(out_strides[0]);
out_index_2 += IdxT(out_strides[0]);
}
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
constant const float& base [[buffer(10)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>(
rope_impl<T, IdxT, N>(
in,
out,
offset,
@@ -159,26 +174,28 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
const device int* offset,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
constant const int64_t strides[3],
constant const int64_t out_strides[3],
constant const int64_t& offset_stride,
constant const int& n_head,
const device float* freqs [[buffer(10)]],
constant const size_t& freq_stride [[buffer(11)]],
constant const int64_t& freq_stride [[buffer(11)]],
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
rope_impl<T, IdxT, N>(
in,
out,
offset,
@@ -186,75 +203,27 @@ template <typename T, bool traditional, bool forward, int N = 4>
scale,
strides,
out_strides,
n_batch,
offset_stride,
n_head,
pos,
grid);
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
constant const float& base [[buffer(10)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); \
template [[host_name("rope_freqs_" #name)]] \
[[kernel]] void rope_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_g(name, type) \
instantiate_kernel("rope_" #name, rope, type, int32_t) \
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
constant const float& base [[buffer(10)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]); \
template [[host_name("rope_single_freqs_" #name)]] \
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& scale, \
constant const size_t& stride, \
const device float* freqs [[buffer(10)]], \
constant const size_t& freq_stride [[buffer(11)]], \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type) \
instantiate_kernel("rope_single_" #name, rope_single, type) \
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
#define instantiate_rope(name, type) \
instantiate_rope_s(name, type) \
instantiate_rope_g(name, type)
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on
instantiate_rope(float16, half)
instantiate_rope(bfloat16, bfloat16_t)
instantiate_rope(float32, float) // clang-format on
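Note on the instantiation change above: the per-(traditional, forward) template expansion is gone, and each type is now instantiated only for 32-bit and 64-bit index arithmetic (the "large" variants), with the remaining flags moved to Metal function constants set by the host. A rough host-side fragment, assuming the device/kernel-cache API used elsewhere in this diff (exact constant indices and naming may differ):

// Fragment in assumed host context: pick the index width, then specialize the
// pipeline through function constants instead of template parameters.
bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
std::string kname = std::string("rope_") + (large ? "large_" : "") + type_to_name(in);

metal::MTLFCList func_consts = {
    {&forward_, MTL::DataType::DataTypeBool, 1},
    {&traditional_, MTL::DataType::DataTypeBool, 2},
    {&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};

// The hash name keeps one specialized pipeline per flag combination in the cache.
std::string hash_name = kname + (forward_ ? "" : "_vjp") +
    (traditional_ ? "_traditional" : "") + (head_seq_transpose ? "_transpose" : "");
auto kernel = d.get_kernel(kname, hash_name, func_consts);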

View File

@@ -306,6 +306,7 @@ template <
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
@@ -440,6 +441,7 @@ template <
}
// Read in SM
threadgroup_barrier(mem_flags::mem_threadgroup);
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
read_into[i] = in[index_y * stride + i];
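The two hunks above only add threadgroup barriers around the shared-memory handoff of per-simdgroup partials. An illustrative fragment of the pattern, with the names (partial_sums, my_partial, other_group, U, op) assumed for the sketch rather than taken from the scan kernel:

// Illustrative MSL fragment (assumed kernel context): partials exchanged
// through threadgroup memory need a barrier before the write, so slots still
// being read from the previous round are not clobbered, and another before
// the read, so the new values are visible to every simdgroup.
threadgroup_barrier(mem_flags::mem_threadgroup); // previous readers are done
if (simd_lane_id == 0) {
  partial_sums[simd_group_id] = my_partial;      // publish this group's value
}
threadgroup_barrier(mem_flags::mem_threadgroup); // publish before anyone reads
U combined = op(my_partial, partial_sums[other_group]);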

View File

@@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -31,6 +32,9 @@ template <typename T, int D, int V = D>
[[buffer(14), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(15), function_constant(has_mask)]],
const device T* sinks [[buffer(16), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(17), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -53,24 +57,24 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
bmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
fmask += q_batch_head_idx * mask_head_stride +
simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
out += o_offset * V + simd_gid * v_per_thread;
@@ -83,8 +87,12 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -INFINITY;
U max_score = Limits<U>::finite_min;
U sum_exp_score = 0;
if (has_sinks && simd_gid == 0) {
max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
sum_exp_score = 1;
}
// For each key
for (int i = simd_gid; i < N; i += BN) {
@@ -93,6 +101,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -107,7 +117,7 @@ template <typename T, int D, int V = D>
}
score = simd_sum(score);
if (float_mask) {
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
score += static_cast<U>(fmask[0]);
}
// Update the accumulators
@@ -152,7 +162,8 @@ template <typename T, int D, int V = D>
for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
@@ -187,6 +198,9 @@ template <typename T, int D, int V = D>
[[buffer(16), function_constant(has_mask)]],
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(19), function_constant(has_sinks)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -211,12 +225,12 @@ template <typename T, int D, int V = D>
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.x;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = head_idx * tpg.y + q_seq_idx;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset;
const int kv_head_idx = head_idx / gqa_factor;
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride +
@@ -225,12 +239,12 @@ template <typename T, int D, int V = D>
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) {
bmask += head_idx * mask_head_stride +
bmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += head_idx * mask_head_stride +
fmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
@@ -245,8 +259,13 @@ template <typename T, int D, int V = D>
o[i] = 0;
}
U max_score = -1e9;
U max_score = Limits<U>::finite_min;
U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1;
}
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
@@ -255,6 +274,8 @@ template <typename T, int D, int V = D>
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
@@ -268,6 +289,7 @@ template <typename T, int D, int V = D>
score += q[i] * k[i];
}
score = simd_sum(score);
if (float_mask) {
score += fmask[0];
}
@@ -379,7 +401,8 @@ template <typename T, int D>
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
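For context on the accumulator seeding above: an attention sink behaves like one extra logit that joins the softmax normalization but carries no value vector, which is why max_score starts at the sink, sum_exp_score starts at 1, and the final division now guards against a zero denominator. A minimal CPU sketch of the same streaming softmax, written as plain C++ for clarity; the function and variable names are illustrative only:

#include <algorithm>
#include <cmath>
#include <vector>

// Streaming ("online") softmax with a sink term, mirroring how the kernel
// seeds its accumulators. The sink only inflates the denominator, so it
// scales the attention output down without adding a value vector.
std::vector<float> attend_with_sink(
    const std::vector<float>& scores,              // one logit per key
    const std::vector<std::vector<float>>& values, // one value row per key
    float sink) {
  size_t v_dim = values.empty() ? size_t(0) : values[0].size();
  float max_score = sink; // running max starts at the sink logit
  float sum_exp = 1.0f;   // exp(sink - max_score) == 1
  std::vector<float> o(v_dim, 0.0f);
  for (size_t i = 0; i < scores.size(); ++i) {
    float new_max = std::max(max_score, scores[i]);
    float factor = std::exp(max_score - new_max);
    float exp_s = std::exp(scores[i] - new_max);
    sum_exp = sum_exp * factor + exp_s;
    for (size_t j = 0; j < v_dim; ++j)
      o[j] = o[j] * factor + exp_s * values[i][j];
    max_score = new_max;
  }
  for (auto& x : o)
    x /= sum_exp; // the sink contributed to sum_exp but never to o
  return o;
}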

View File

@@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
template <typename T>
struct TransformScale {
@@ -82,6 +83,7 @@ template <
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -169,7 +171,7 @@ template <
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
TransformScale<T> ts(static_cast<T>(params->scale * M_LOG2E_F));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
@@ -232,6 +234,14 @@ template <
max_score[i] = Limits<AccumType>::finite_min;
}
if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}
int kb_lim = params->NK;
if (do_causal) {
@@ -350,7 +360,7 @@ template <
Stile.frag_at(i, j)[jj] =
mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
} else {
Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]);
}
}
}
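The M_LOG2E_F substitutions above name the literal 1.44269504089 (log2 e): the tile kernel evaluates softmax in base 2 because exp2 is the cheap GPU primitive, so the scale, the additive float mask, and now the sink logit are all pre-multiplied by log2(e). A tiny standalone check of the identity:

#include <cassert>
#include <cmath>

int main() {
  constexpr float kLog2E = 1.4426950408889634f; // log2(e), i.e. M_LOG2E_F
  float x = 0.73f;
  // exp(x) == exp2(x * log2(e)); exp2 is what the GPU computes cheaply.
  assert(std::fabs(std::exp(x) - std::exp2(x * kLog2E)) < 1e-5f);
  return 0;
}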

View File

@@ -698,6 +698,15 @@ void gemv_axbpy(
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
if (K <= 64) {
bm = 1;
sm = 8;
sn = 4;
} else if (K >= 16 * out_vector_len) {
bm = 1;
bn = 8;
}
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
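A standalone restatement of the small-K GEMV tuning branch above, with the launcher context stripped; bm/bn/sm/sn follow the names in the diff, and bn and sm arrive with whatever defaults the surrounding code assigned earlier, since those defaults are not shown here:

// Hedged restatement only; the meaning of the SIMD tile shape (sm, sn) is
// taken from the surrounding launcher and assumed here.
void tune_small_k_gemv(int K, int out_vector_len, int& bm, int& bn, int& sm, int& sn) {
  bm = out_vector_len >= 4096 ? 8 : 4;
  sn = 32;
  if (K <= 64) {
    // Very short reduction dimension: switch to a flatter tiling.
    bm = 1;
    sm = 8;
    sn = 4;
  } else if (K >= 16 * out_vector_len) {
    // Reduction much longer than the output vector: use more column blocks.
    bm = 1;
    bn = 8;
  }
}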

View File

@@ -26,15 +26,15 @@ device_info() {
namespace fast {
MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
const std::string&,
const std::vector<std::string>&,
const std::vector<std::string>&,
const std::string&,
const std::string&,
bool ensure_row_contiguous,
bool atomic_outputs) {
throw std::runtime_error("[metal_kernel] No GPU back-end.");
bool,
bool) {
throw std::runtime_error("[metal_kernel] No Metal back-end.");
}
} // namespace fast

View File

@@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&,
const std::string&) {
return d.get_kernel(kernel_name);
}
@@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
const array&,
int,
int,
const std::string&,
int,
int,
int,

View File

@@ -4,7 +4,6 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
@@ -25,60 +24,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(step, 1);
}
static array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
Stream s) {
auto& d = metal::device(s.device);
// Kernel to compute offset here.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
d.add_temporary(offset, s.index);
auto dtype = indices.dtype();
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
auto lib = d.get_library(lib_name, [dtype]() {
return fmt::format(
R"(
[[kernel]] void compute_dynamic_offset_{0}(
constant const {1}* indices [[buffer(0)]],
device int64_t& offset [[buffer(1)]],
constant const int64_t* strides [[buffer(2)]],
constant const int* axes [[buffer(3)]],
constant const int& n_axes [[buffer(4)]],
uint index [[thread_position_in_grid]]) {{
int64_t acc = 0;
for (int i = 0; i < n_axes; ++i) {{
acc += indices[i] * strides[axes[i]];
}}
offset = acc;
}})",
type_to_name(dtype),
get_type_string(dtype));
});
auto kernel = d.get_kernel(lib_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(indices, 0);
compute_encoder.set_output_array(offset, 1);
compute_encoder.set_vector_bytes(strides, 2);
compute_encoder.set_vector_bytes(axes, 3);
int n_axes = axes.size();
compute_encoder.set_bytes(n_axes, 4);
MTL::Size dims = MTL::Size(1, 1, 1);
compute_encoder.dispatch_threads(dims, dims);
return offset;
}
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
@@ -256,72 +201,6 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& start = inputs[1];
out.set_data(allocator::malloc(out.nbytes()));
auto s = stream();
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
}
void DynamicSliceUpdate::eval_gpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
auto& start_indices = inputs[2];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Copy or donate input to output
auto s = stream();
auto& d = metal::device(s.device);
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
auto out_offset =
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
copy_gpu_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const Shape& data_shape = */ upd.shape(),
/* const Strides& i_strides = */ upd.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ s,
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
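The helper removed above is not dropped: compute_dynamic_offset reappears in the common slicing file later in this diff, and the DynamicSlice/DynamicSliceUpdate eval paths presumably move to a shared GPU implementation not shown in this excerpt. What the helper produces is a single int64 element offset, the sum of the start indices along the sliced axes scaled by the array's strides, as in this plain-C++ restatement:

#include <cstdint>
#include <vector>

// CPU restatement of the offset the generated Metal kernel computes.
int64_t dynamic_offset(
    const std::vector<int64_t>& indices, // one start index per sliced axis
    const std::vector<int64_t>& strides, // strides of the sliced array
    const std::vector<int>& axes) {
  int64_t acc = 0;
  for (size_t i = 0; i < axes.size(); ++i)
    acc += indices[i] * strides[axes[i]];
  return acc;
}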

View File

@@ -1,7 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
@@ -17,6 +15,28 @@ namespace mlx::core {
namespace {
template <typename... Args>
auto get_quantized_kernel_wrapped(
metal::Device& d,
const std::string& name,
const std::string& func,
const std::string& mode,
const std::string& type,
int group_size,
int bits,
Args... args) {
std::string template_def;
auto fname = mode + "_" + func;
if (mode == "affine") {
template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...);
} else {
template_def = get_template_definition(
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
}
return get_quantized_kernel(d, name, template_def, mode);
}
inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) {
@@ -99,7 +119,7 @@ inline int add_strides_and_shapes(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
int offset) {
if (skip) {
return 0;
@@ -109,16 +129,18 @@ inline int add_strides_and_shapes(
int x_batch_ndims = x.ndim() - 2;
int w_batch_ndims = w.ndim() - 2;
compute_encoder.set_bytes(x_batch_ndims, offset);
compute_encoder.set_vector_bytes(x.shape(), offset + 1);
compute_encoder.set_vector_bytes(x.strides(), offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
compute_encoder.set_vector_bytes(w.shape(), offset + 4);
compute_encoder.set_vector_bytes(w.strides(), offset + 5);
compute_encoder.set_vector_bytes(scales.strides(), offset + 6);
compute_encoder.set_vector_bytes(biases.strides(), offset + 7);
compute_encoder.set_bytes(x_batch_ndims, offset++);
compute_encoder.set_vector_bytes(x.shape(), offset++);
compute_encoder.set_vector_bytes(x.strides(), offset++);
compute_encoder.set_bytes(w_batch_ndims, offset++);
compute_encoder.set_vector_bytes(w.shape(), offset++);
compute_encoder.set_vector_bytes(w.strides(), offset++);
compute_encoder.set_vector_bytes(scales.strides(), offset++);
if (biases) {
compute_encoder.set_vector_bytes(biases->strides(), offset++);
}
return 8;
return offset;
}
inline int add_gather_strides_and_shapes(
@@ -130,12 +152,12 @@ inline int add_gather_strides_and_shapes(
lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});
int ndims = shape.size();
compute_encoder.set_bytes(ndims, offset);
compute_encoder.set_vector_bytes(shape, offset + 1);
compute_encoder.set_vector_bytes(strides[0], offset + 2);
compute_encoder.set_vector_bytes(strides[1], offset + 3);
compute_encoder.set_bytes(ndims, offset++);
compute_encoder.set_vector_bytes(shape, offset++);
compute_encoder.set_vector_bytes(strides[0], offset++);
compute_encoder.set_vector_bytes(strides[1], offset++);
return 4;
return offset;
}
} // namespace
@@ -144,7 +166,7 @@ void qmv_quad(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -152,7 +174,8 @@ void qmv_quad(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
constexpr int quads_per_simd = 8;
@@ -165,9 +188,10 @@ void qmv_quad(
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qmv_quad_",
mode + "_qmv_quad_",
type_string,
"_gs_",
group_size,
@@ -176,21 +200,22 @@ void qmv_quad(
"_d_",
K,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qmv_quad", type_string, group_size, bits, K, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto kernel = get_quantized_kernel_wrapped(
d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -199,7 +224,7 @@ void qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -207,7 +232,8 @@ void qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@@ -219,30 +245,40 @@ void qmv(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "qmv_fast_" : "qmv_",
mode + (fast ? "_qmv_fast_" : "_qmv_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel_wrapped(
d,
kname,
(fast ? "qmv_fast" : "qmv"),
mode,
type_string,
group_size,
bits,
B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -251,7 +287,7 @@ void qvm_split_k(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -259,7 +295,8 @@ void qvm_split_k(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int split_k = K > 8192 ? 32 : 8;
int split_D = (K + split_k - 1) / split_k;
int B = out.size() / M / N;
@@ -283,7 +320,6 @@ void qvm_split_k(
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
@@ -297,7 +333,6 @@ void qvm_split_k(
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = K - (split_k - 1) * split_D;
@@ -315,7 +350,7 @@ void qvm_split_k(
kname.reserve(64);
concatenate(
kname,
"qvm_split_k_",
mode + "_qvm_split_k_",
type_string,
"_gs_",
group_size,
@@ -323,31 +358,38 @@ void qvm_split_k(
bits,
"_spk_",
split_k);
auto template_def = get_template_definition(
kname, "qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname, template_def);
auto kernel = get_quantized_kernel_wrapped(
d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder.set_bytes(split_D, 5);
compute_encoder.set_bytes(N, 6);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(intermediate, c++);
compute_encoder.set_bytes(split_D, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15);
compute_encoder.set_bytes(x_batch_ndims, c++);
compute_encoder.set_vector_bytes(x_shape, c++);
compute_encoder.set_vector_bytes(x_strides, c++);
compute_encoder.set_bytes(w_batch_ndims, c++);
compute_encoder.set_vector_bytes(w_shape, c++);
compute_encoder.set_vector_bytes(w_strides, c++);
compute_encoder.set_vector_bytes(s_strides, c++);
if (biases) {
auto b_strides = biases->strides();
b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1));
compute_encoder.set_vector_bytes(b_strides, c++);
}
compute_encoder.set_bytes(final_block_size, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -364,7 +406,7 @@ void qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@@ -372,7 +414,8 @@ void qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@@ -385,28 +428,29 @@ void qvm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qvm_",
mode + "_qvm_",
type_string,
"_gs_",
group_size,
"_b_",
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qvm", type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto kernel = get_quantized_kernel_wrapped(
d, kname, "qvm", mode, type_string, group_size, bits, B > 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -415,7 +459,7 @@ void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
bool transpose,
int group_size,
@@ -424,7 +468,8 @@ void qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@@ -441,7 +486,7 @@ void qmm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "qmm_t_" : "qmm_n_",
mode + (transpose ? "_qmm_t_" : "_qmm_n_"),
type_string,
"_gs_",
group_size,
@@ -450,27 +495,37 @@ void qmm(
transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
batched ? "_batch_1" : "_batch_0");
std::string template_def;
MTL::ComputePipelineState* kernel;
if (transpose) {
template_def = get_template_definition(
kname, "qmm_t", type_string, group_size, bits, aligned, batched);
kernel = get_quantized_kernel_wrapped(
d,
kname,
"qmm_t",
mode,
type_string,
group_size,
bits,
aligned,
batched);
} else {
template_def = get_template_definition(
kname, "qmm_n", type_string, group_size, bits, batched);
kernel = get_quantized_kernel_wrapped(
d, kname, "qmm_n", mode, type_string, group_size, bits, batched);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
compute_encoder.set_bytes(M, 7);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -479,7 +534,7 @@ void gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -490,7 +545,8 @@ void gather_qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@@ -503,44 +559,43 @@ void gather_qmm(
std::string kname;
kname.reserve(64);
bool aligned = N % 32 == 0;
bool batched = B > 1;
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_t_" : "gather_qmm_n_",
mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"),
type_string,
"_gs_",
group_size,
"_b_",
bits,
transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
std::string template_def;
MTL::ComputePipelineState* kernel;
if (transpose) {
template_def = get_template_definition(
kname, "gather_qmm_t", type_string, group_size, bits, aligned);
kernel = get_quantized_kernel_wrapped(
d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned);
} else {
template_def = get_template_definition(
kname, "gather_qmm_n", type_string, group_size, bits);
kernel = get_quantized_kernel_wrapped(
d, kname, "gather_qmm_n", mode, type_string, group_size, bits);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
compute_encoder.set_bytes(M, 9);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 10 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -549,7 +604,7 @@ void gather_qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -559,7 +614,8 @@ void gather_qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@@ -573,36 +629,39 @@ void gather_qmv(
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "gather_qmv_fast_" : "gather_qmv_",
mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"),
type_string,
"_gs_",
group_size,
"_b_",
bits);
auto template_def = get_template_definition(
auto kernel = get_quantized_kernel_wrapped(
d,
kname,
fast ? "gather_qmv_fast" : "gather_qmv",
(fast ? "gather_qmv_fast" : "gather_qmv"),
mode,
type_string,
group_size,
bits);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -611,7 +670,7 @@ void gather_qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@@ -621,7 +680,8 @@ void gather_qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@@ -633,27 +693,32 @@ void gather_qvm(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits);
auto template_def = get_template_definition(
kname, "gather_qvm", type_string, group_size, bits);
auto kernel = get_quantized_kernel(d, kname, template_def);
kname,
mode + "_gather_qvm_",
type_string,
"_gs_",
group_size,
"_b_",
bits);
auto kernel = get_quantized_kernel_wrapped(
d, kname, "gather_qvm", mode, type_string, group_size, bits);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -662,7 +727,7 @@ void gather_qmm_rhs(
const array& x_,
const array& w_,
const array& scales_,
const array& biases_,
const std::optional<array>& biases_,
const array& indices_,
array& out,
bool transpose,
@@ -672,7 +737,8 @@ void gather_qmm_rhs(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string mode) {
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);
@@ -697,7 +763,6 @@ void gather_qmm_rhs(
array x = broadcast_with_indices(x_);
array w = ensure_row_contiguous(w_, d, s);
array scales = ensure_row_contiguous(scales_, d, s);
array biases = ensure_row_contiguous(biases_, d, s);
// TODO: Tune the block sizes
int bm = 16, bn = 32, bk = 32;
@@ -713,7 +778,7 @@ void gather_qmm_rhs(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_",
mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"),
type_string,
"_gs_",
group_size,
@@ -759,6 +824,7 @@ void gather_qmm_rhs(
x,
group_size,
bits,
mode,
bm,
bn,
bk,
@@ -770,15 +836,19 @@ void gather_qmm_rhs(
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(indices, 4);
compute_encoder.set_output_array(out, 5);
compute_encoder.set_bytes(M, 6);
compute_encoder.set_bytes(N, 7);
compute_encoder.set_bytes(K, 8);
int c = 0;
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases_) {
array biases = ensure_row_contiguous(*biases_, d, s);
compute_encoder.set_input_array(biases, c++);
}
compute_encoder.set_input_array(indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(M, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -794,7 +864,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
std::optional<array> biases = std::nullopt;
if (inputs.size() == 4) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
@@ -803,7 +876,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = out.shape(-1);
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);
// It is a matrix-matrix product.
if (M >= vector_limit) {
qmm(x,
@@ -818,30 +891,33 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
// It is a qmv with a small inner dimension so route to qmv_quad kernel
if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) {
qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv_quad(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Run of the mill qmv
if (transpose_) {
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Run of the mill qvm
if (K < 1024) {
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
// Qvm with large dimension so route to a split K kernel for more parallelism
qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm_split_k(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode);
return;
}
@@ -854,9 +930,12 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
const array& lhs_indices = inputs[4];
const array& rhs_indices = inputs[5];
std::optional<array> biases = std::nullopt;
if (inputs.size() == 6) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
const array& lhs_indices = inputs[inputs.size() - 2];
const array& rhs_indices = inputs[inputs.size() - 1];
int K = x.shape(-1);
int M = x.shape(-2);
@@ -864,12 +943,13 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int B = out.size() / M / N;
int E = w.size() / w.shape(-1) / w.shape(-2);
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
auto mode = quantization_mode_to_string(mode_);
// We are walking x in order and w is also in order, so we can batch up the
// matmuls and amortize reading x and w.
//
// TODO: Tune 16 and 8 here a bit better.
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) {
// TODO: Tune 16 and 4 here a bit better.
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 4) {
gather_qmm_rhs(
x,
w,
@@ -884,7 +964,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -905,7 +986,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -924,7 +1006,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
return;
}
@@ -942,10 +1025,11 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode);
}
void fast::AffineQuantize::eval_gpu(
void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
@@ -974,15 +1058,27 @@ void fast::AffineQuantize::eval_gpu(
compute_encoder.set_output_array(biases, 3);
}
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
std::string kname;
concatenate(
kname,
dequantize_ ? "affine_dequantize" : "affine_quantize",
"_",
type_string,
"_gs_",
group_size_,
"_b_",
bits_);
auto kernel = get_quantized_kernel_wrapped(
d,
kname,
dequantize_ ? "dequantize" : "quantize",
"affine",
type_string,
group_size_,
bits_);
compute_encoder.set_compute_pipeline_state(kernel);
// Treat uint32 as uint8 in kernel
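Throughout this file the pattern is the same: kernel names gain the quantization mode as a prefix, biases become optional (only the affine mode binds them), and buffer slots are assigned with a running counter so an absent biases array simply shifts everything after it; the matching kernels are assumed to omit the biases buffer for non-affine modes. A consolidated fragment of that binding convention, in the same assumed encoder context as the diff:

// Fragment: helpers that bind buffers take the next free slot and return the
// slot after the last one they used, so optional arrays never leave holes.
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
  compute_encoder.set_input_array(*biases, c++); // affine mode only
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, /*skip=*/B <= 1, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);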

View File

@@ -18,23 +18,32 @@ void RoPE::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
size_t strides[3];
size_t out_strides[3];
int64_t strides[3];
int64_t out_strides[3];
bool donated = false;
int ndim = in.ndim();
int dispatch_ndim = in.ndim();
int B = in.shape(0);
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
dispatch_ndim--;
}
size_t mat_size = in.shape(-2) * in.shape(-1);
if (dims_ < in.shape(-1)) {
int N = 1;
for (int i = 1; i < (ndim - 2); ++i) {
N *= in.shape(i);
}
bool head_seq_transpose = false;
if (dims_ < D) {
donated = true;
auto ctype =
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
@@ -58,6 +67,17 @@ void RoPE::eval_gpu(
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (
ndim == 4 &&
// batch dim is regularly strided
in.strides()[0] == T * N * D &&
// sequence and head dimensions are transposed
in.strides()[1] == D && in.strides()[2] == N * D) {
head_seq_transpose = true;
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[1];
strides[1] = in.strides()[2];
strides[2] = in.strides()[3];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
@@ -71,39 +91,62 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
// Special case for inference (single time step, contiguous, one offset)
auto& offset = inputs[1];
bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
std::string kname;
concatenate(
kname,
"rope_",
single ? "single_" : "",
(with_freqs) ? "freqs_" : "",
large ? "large_" : "",
type_to_name(in));
std::string hash_name;
concatenate(
hash_name,
kname,
"_",
forward_ ? "" : "vjp_",
traditional_ ? "traditional_" : "",
head_seq_transpose ? "transpose" : "");
metal::MTLFCList func_consts = {
{&forward_, MTL::DataType::DataTypeBool, 1},
{&traditional_, MTL::DataType::DataTypeBool, 2},
{&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};
auto kernel = d.get_kernel(kname, hash_name, func_consts);
auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_input_array(offset, 2);
compute_encoder.set_bytes(scale_, 3);
size_t n_batch = in.size() / mat_size;
MTL::Size group_dims;
MTL::Size grid_dims;
if (single) {
compute_encoder.set_bytes(out_strides, 1, 4);
uint32_t dim0 = dims_ / 2;
group_dims = get_block_dims(dim0, n_batch, 1);
grid_dims = MTL::Size(dim0, n_batch, 1);
group_dims = get_block_dims(dim0, N, 1);
grid_dims = MTL::Size(dim0, N, 1);
} else {
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
compute_encoder.set_bytes(n_batch, 6);
int64_t offset_stride = 0;
if (offset.ndim() > 0) {
offset_stride = offset.strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
uint32_t dim1 = T;
uint32_t dim2 = B * ((N + n_per_thread - 1) / n_per_thread);
group_dims = get_block_dims(dim0, dim1, dim2);
grid_dims = MTL::Size(dim0, dim1, dim2);
}
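The new branch above detects a 4-D input logically shaped (B, N, T, D) whose head and sequence axes are transposed in memory, so RoPE can run in place instead of copying first. A self-contained restatement of the stride test, with plain vectors standing in for the array metadata:

#include <cstdint>
#include <vector>

// Accepts the input when memory is actually laid out as (B, T, N, D),
// i.e. the head and sequence axes are swapped relative to the logical shape.
bool is_head_seq_transposed(
    const std::vector<int>& shape, // {B, N, T, D}
    const std::vector<int64_t>& strides) {
  int64_t N = shape[1], T = shape[2], D = shape[3];
  return strides[0] == T * N * D && // batch dimension regularly strided
      strides[1] == D &&            // logical head axis steps by D
      strides[2] == N * D;          // logical sequence axis steps by N * D
}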

View File

@@ -21,8 +21,9 @@ void sdpa_full_self_attention_metal(
const array& v,
const float scale,
array& o,
bool do_causal_ = false,
const std::optional<array>& mask = std::nullopt) {
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
int wm = 4;
@@ -42,35 +43,49 @@ void sdpa_full_self_attention_metal(
const bool align_Q = (qL % bq) == 0;
const bool align_K = (kL % bk) == 0;
const bool has_mask = !!mask;
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301}};
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
std::ostringstream kname;
// clang-format off
kname << "steel_attention_"
<< type_to_name(q)
<< "_bq" << bq
<< "_bk" << bk
<< "_bd" << bd
<< "_wm" << wm
<< "_wn" << wn
<< "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
std::string base_name;
concatenate(
base_name,
"steel_attention_",
type_to_name(q),
"_bq",
bq,
"_bk",
bk,
"_bd",
bd,
"_wm",
wm,
"_wn",
wn,
"_mask",
type_to_name(has_mask ? *mask : q));
std::string base_name = kname.str();
// clang-format off
kname << "_align_Q_" << (align_Q ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_has_mask_" << (has_mask ? 't' : 'n')
<< "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
std::string hash_name;
concatenate(
hash_name,
base_name,
"_align_Q_",
(align_Q ? 't' : 'n'),
"_align_K_",
(align_K ? 't' : 'n'),
"_has_mask_",
(has_mask ? 't' : 'n'),
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -114,8 +129,8 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
if (mask) {
auto m = *mask;
if (has_mask) {
auto& m = *mask;
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
m.strides(0), m.strides(1), m.strides(2)}};
@@ -123,6 +138,9 @@ void sdpa_full_self_attention_metal(
compute_encoder.set_bytes(mask_params, 5);
compute_encoder.set_input_array(m, 6);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 7);
}
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
@@ -139,7 +157,8 @@ void sdpa_vector(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -153,30 +172,32 @@ void sdpa_vector(
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1);
size_t k_seq_stride = k.strides()[2];
size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1);
size_t v_seq_stride = v.strides()[2];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(B, q.shape(2), 1);
MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -207,6 +228,10 @@ void sdpa_vector(
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 16);
compute_encoder.set_bytes(q.shape(1), 17);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -221,7 +246,8 @@ void sdpa_vector_2pass(
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
const std::optional<array>& mask,
const std::optional<array>& sinks) {
// Set the kernel name
std::string kname;
kname.reserve(64);
@@ -267,17 +293,20 @@ void sdpa_vector_2pass(
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
bool has_sinks = sinks.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
{&has_sinks, MTL::DataType::DataTypeBool, 25},
};
std::string hash_name = kname;
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
hash_name += has_sinks ? "_sinks" : "_nosinks";
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -310,6 +339,10 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
if (has_sinks) {
compute_encoder.set_input_array(*sinks, 18);
compute_encoder.set_bytes(q.shape(1), 19);
}
// Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -394,7 +427,7 @@ void ScaledDotProductAttention::eval_gpu(
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(3);
copies.reserve(inputs.size());
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -411,6 +444,12 @@ void ScaledDotProductAttention::eval_gpu(
return arr.strides(-1) == 1;
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
bool has_arr_mask = inputs.size() > (3 + has_sinks_);
// We are in vector mode ie single query
if (q_pre.shape(2) <= 8) {
auto q_copy_unless = [](const array& arr) {
@@ -462,7 +501,7 @@ void ScaledDotProductAttention::eval_gpu(
(strides[0] == strides[1] * shape[1]);
};
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
: std::nullopt;
@@ -473,9 +512,9 @@ void ScaledDotProductAttention::eval_gpu(
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
}
}
@@ -503,11 +542,12 @@ void ScaledDotProductAttention::eval_gpu(
{str_oB, str_oH, str_oL, str_oD},
flags);
auto mask = inputs.size() > 3
auto mask = has_arr_mask
? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
: std::nullopt;
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
}
d.add_temporaries(std::move(copies), s.index);
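For reference, the indexing above assumes the primitive's inputs arrive as [q, k, v, optional mask, optional sinks], with sinks always last; the mask test must add has_sinks_, otherwise a four-input call carrying sinks but no mask would misread the sinks array as a mask. A condensed fragment of that unpacking, using the helper names from the diff:

// Fragment in assumed context: inputs = [q, k, v, (mask), (sinks)].
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
  sinks = copy_unless(is_matrix_contiguous, inputs.back());
}
bool has_arr_mask = inputs.size() > (3 + has_sinks_);
auto mask = has_arr_mask
    ? std::optional<array>{copy_unless(mask_copy_unless, inputs[3])}
    : std::nullopt;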

View File

@@ -36,14 +36,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
bool contiguous = in.strides()[axis_] == 1;
std::ostringstream kname;
kname << (contiguous ? "contig_" : "strided_");
kname << "scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
std::string reduce_type;
switch (reduce_type_) {
case Scan::Sum:
@@ -62,9 +54,22 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
reduce_type = "logaddexp";
break;
}
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(
d, kname.str(), reverse_, inclusive_, reduce_type, in, out);
std::string kname;
concatenate(
kname,
contiguous ? "contig_" : "strided_",
"scan_",
reverse_ ? "reverse_" : "",
(inclusive_) ? "inclusive_" : "exclusive_",
reduce_type,
"_",
type_to_name(in),
"_",
type_to_name(out));
auto kernel =
get_scan_kernel(d, kname, reverse_, inclusive_, reduce_type, in, out);
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
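The scan change above swaps std::ostringstream for the variadic concatenate helper used elsewhere in this diff, appending into a pre-reserved std::string. A hypothetical minimal version of such a helper, only to show the shape of the call (the real one lives in MLX's utilities and may differ):

#include <string>
#include <type_traits>
#include <utility>

// Hypothetical sketch: append each piece, converting arithmetic arguments
// with std::to_string, without allocating a stringstream per kernel name.
template <typename... Args>
void concatenate(std::string& out, Args&&... args) {
  auto append = [&out](auto&& piece) {
    using T = std::decay_t<decltype(piece)>;
    if constexpr (std::is_arithmetic_v<T>) {
      out += std::to_string(piece);
    } else {
      out += piece;
    }
  };
  (append(std::forward<Args>(args)), ...);
}

// Usage mirroring the scan path above:
//   std::string kname;
//   kname.reserve(64);
//   concatenate(kname, "contig_", "scan_", "inclusive_", "sum", "_", "float32");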


@@ -2,9 +2,12 @@
#include <numeric>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
namespace mlx::core {
@@ -39,4 +42,58 @@ void concatenate_gpu(
}
}
array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
const Stream& s) {
auto& d = metal::device(s.device);
// Kernel to compute offset here.
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
d.add_temporary(offset, s.index);
auto dtype = indices.dtype();
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
auto lib = d.get_library(lib_name, [dtype]() {
return fmt::format(
R"(
[[kernel]] void compute_dynamic_offset_{0}(
constant const {1}* indices [[buffer(0)]],
device int64_t& offset [[buffer(1)]],
constant const int64_t* strides [[buffer(2)]],
constant const int* axes [[buffer(3)]],
constant const int& n_axes [[buffer(4)]],
uint index [[thread_position_in_grid]]) {{
int64_t acc = 0;
for (int i = 0; i < n_axes; ++i) {{
acc += indices[i] * strides[axes[i]];
}}
offset = acc;
}})",
type_to_name(dtype),
get_type_string(dtype));
});
auto kernel = d.get_kernel(lib_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(indices, 0);
compute_encoder.set_output_array(offset, 1);
compute_encoder.set_vector_bytes(strides, 2);
compute_encoder.set_vector_bytes(axes, 3);
int n_axes = axes.size();
compute_encoder.set_bytes(n_axes, 4);
MTL::Size dims = MTL::Size(1, 1, 1);
compute_encoder.dispatch_threads(dims, dims);
return offset;
}
} // namespace mlx::core
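The generated kernel above runs on a single thread and just folds the dynamic indices against the strides of the selected axes. A host-side reference of the same arithmetic, for clarity (illustrative, not part of the library):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side equivalent of the generated Metal kernel: accumulate
// indices[i] * strides[axes[i]] into a single scalar offset.
int64_t dynamic_offset(
    const std::vector<int64_t>& indices,
    const std::vector<int64_t>& strides,
    const std::vector<int>& axes) {
  int64_t acc = 0;
  for (std::size_t i = 0; i < axes.size(); ++i) {
    acc += indices[i] * strides[axes[i]];
  }
  return acc;
}

int main() {
  // A 4 x 5 row-major array has strides {5, 1}; start at element (2, 3).
  std::cout << dynamic_offset({2, 3}, {5, 1}, {0, 1}) << "\n";  // 13
}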


@@ -129,7 +129,7 @@ NO_CPU(Inverse)
NO_CPU(View)
namespace fast {
NO_CPU_MULTI(AffineQuantize)
NO_CPU_MULTI(Quantize)
} // namespace fast
namespace distributed {


@@ -154,7 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(Quantize)
NO_GPU_MULTI(CustomKernel)
} // namespace fast


@@ -727,7 +727,11 @@ void compile_fuse(
}
};
if (arr.has_primitive()) {
// This will be the result of the fused operation so it needs
// a) to not be already computed ie have a primitive
// b) that primitive to not be a broadcast since it will unnecessarily
// cast to a contiguous array potentially blowing up memory
if (arr.has_primitive() && !is_broadcast(arr.primitive())) {
Stream s = arr.primitive().stream();
recurse(arr, 0, s, arr.shape());
}
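The new guard stops a fused graph from ending in a broadcast: if it did, the compiled kernel would write the broadcast's output as a dense contiguous array instead of leaving it a zero-stride view (materializing a single float broadcast to 1024 x 1024 costs 4 MiB instead of 4 bytes). A sketch of what an is_broadcast check can look like; the helper actually used in the codebase may be implemented differently.

#include <cassert>
#include <typeinfo>

// Hypothetical stand-ins for the primitive classes referenced above.
struct Primitive {
  virtual ~Primitive() = default;
};
struct Broadcast : Primitive {};
struct Add : Primitive {};

// True only when the node's primitive is exactly a Broadcast.
bool is_broadcast(const Primitive& p) {
  return typeid(p) == typeid(Broadcast);
}

int main() {
  Broadcast b;
  Add a;
  assert(is_broadcast(b));
  assert(!is_broadcast(a));
}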


@@ -6,3 +6,4 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)


@@ -2,15 +2,21 @@
#include <unordered_map>
#include "mlx/backend/cuda/cuda.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/nccl/nccl.h"
#include "mlx/distributed/ring/ring.h"
namespace mlx::core::distributed {
namespace detail {
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
return group.raw_group()->communication_stream(s);
}
void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}
@@ -37,6 +43,10 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
return to_stream(s);
}
int rank() override {
return 0;
}
@@ -80,7 +90,7 @@ class EmptyGroup : public GroupImpl {
} // namespace detail
bool is_available() {
return mpi::is_available() || ring::is_available();
return mpi::is_available() || ring::is_available() || nccl::is_available();
}
int Group::rank() const {
@@ -105,15 +115,23 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
}
// Create the requested communication group
std::shared_ptr<detail::GroupImpl> group;
std::shared_ptr<detail::GroupImpl> group{nullptr};
std::string bk_ = bk;
if (bk == "mpi") {
group = mpi::init(strict);
} else if (bk == "ring") {
group = ring::init(strict);
} else if (bk == "nccl") {
group = nccl::init(strict);
} else if (bk == "any") {
group = ring::init(false);
bk_ = "ring";
if (mlx::core::cu::is_available()) {
group = nccl::init(false);
bk_ = "nccl";
}
if (group == nullptr) {
group = ring::init(false);
bk_ = "ring";
}
if (group == nullptr) {
group = mpi::init(false);
bk_ = "mpi";


@@ -5,6 +5,7 @@
#include <memory>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {


@@ -13,10 +13,15 @@ class GroupImpl {
public:
virtual ~GroupImpl() {}
// Choose the stream this communication group can operate on
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
// Group operations
virtual int rank() = 0;
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
// Actual communication operations
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
@@ -25,6 +30,9 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Define the MLX stream that the communication should happen in. */
Stream communication_stream(Group group, StreamOrDevice s = {});
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output, Stream stream);
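The new communication_stream hook lets each backend decide where its collectives run, for example a GPU-native backend can insist on a GPU stream while the ring backend stays on the CPU. A toy version of that interface with a single-process implementation, using simplified stand-in types rather than MLX's Stream and array:

#include <iostream>
#include <vector>

// Simplified stand-ins for MLX's Stream and array types.
struct ToyStream {
  int index;
};
using ToyArray = std::vector<float>;

// Each backend both picks the stream collectives run on and implements them.
class ToyGroup {
 public:
  virtual ~ToyGroup() = default;
  virtual ToyStream communication_stream() = 0;
  virtual void all_sum(const ToyArray& input, ToyArray& output, ToyStream s) = 0;
};

// Single-process group: its "all sum" is just a copy, on stream 0.
class SingleProcessGroup : public ToyGroup {
 public:
  ToyStream communication_stream() override { return ToyStream{0}; }
  void all_sum(const ToyArray& input, ToyArray& output, ToyStream) override {
    output = input;  // with one rank the reduction is the identity
  }
};

int main() {
  SingleProcessGroup g;
  ToyArray in{1.0f, 2.0f}, out;
  g.all_sum(in, out, g.communication_stream());
  std::cout << out[0] + out[1] << "\n";  // 3
}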
