Compare commits


93 Commits

Author SHA1 Message Date
Awni Hannun
bf7cd29970 version bump (#698) 2024-02-16 08:44:08 -08:00
Nripesh Niketan
a000d2288c feat: update black pre-commit hook to 24.2.0 (#696) 2024-02-16 06:01:59 -08:00
Mike Drob
165abf0e4c Auto-run PRs from contributors (#692) 2024-02-15 17:30:35 -08:00
Srimukh Sripada
818cda16bc Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
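For orientation, a minimal sketch of how a learning-rate schedule from this PR might be used. The exact set of schedules and the `cosine_decay` name are assumptions; the pattern is a factory that returns a step-to-rate callable accepted as `learning_rate`:

```
import mlx.optimizers as optim

# Assumed API: cosine_decay(init, decay_steps) returns a callable the
# optimizer evaluates on every update step.
lr_schedule = optim.cosine_decay(1e-1, 1000)
optimizer = optim.SGD(learning_rate=lr_schedule)
```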
toji
85143fecdd improved error msg for invalid axis(mx.split) (#685)
* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8 Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
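A rough illustration of the context manager added here: ops issued inside the block run on the given device or stream, and the previous default is restored on exit.

```
import mlx.core as mx

a = mx.random.uniform(shape=(64, 64))
# Scope this computation to the CPU regardless of the default device.
with mx.stream(mx.cpu):
    b = mx.exp(a)
mx.eval(b)
```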
Awni Hannun
ccf1645995 Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32 Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Awni Hannun
1eb04aa23f Fix empty array construction in cpp (#684) 2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91 Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0

* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00
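A small sketch of the edge case this change targets; the resulting shape is my reading of the intended behavior:

```
import mlx.core as mx

x = mx.array([1, 2, 3])
y = mx.repeat(x, 0)   # previously problematic; should now be an empty array
print(y.shape)        # expected: (0,)
```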
Vijay Krish
2fdc2462c3 Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
these kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Hinrik Snær Guðmundsson
be6e9d6a9f Fixed wording in extensions.rst (#678)
changed "learn how add" -> "learn how to add"
2024-02-13 08:39:02 -08:00
Gabrijel Boduljak
e54cbb7ba6 Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
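A minimal usage sketch for the new pooling layers, assuming they follow MLX's channels-last (NHWC) convention:

```
import mlx.core as mx
import mlx.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = mx.random.uniform(shape=(1, 32, 32, 3))  # N, H, W, C
y = pool(x)
print(y.shape)  # expected: (1, 16, 16, 3)
```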
Angelos Katharopoulos
40c108766b Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Mike Drob
4cc70290f7 PR Builder Workflow (#659) 2024-02-12 17:47:21 -08:00
Awni Hannun
74caa68d02 nit in readme (#675) 2024-02-12 12:25:04 -08:00
Awni Hannun
3756381358 Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6 quote file name (#670) 2024-02-11 10:33:30 -08:00
Nripesh Niketan
0dbc4c7547 feat: Update pre-commit-config.yaml (#667) 2024-02-11 06:08:20 -08:00
Vijay Krish
06072601ce Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Angelos Katharopoulos
11d2c8f7a1 Linux build for CI of other packages (#660) 2024-02-09 18:17:04 -08:00
Awni Hannun
7f3f8d8f8d Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185 Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it a little more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
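A hedged sketch of the metadata round trip. The `metadata` keyword on save and the `return_metadata` flag on load are my assumptions about how the feature is exposed:

```
import mlx.core as mx

arrays = {"weight": mx.zeros((4, 4))}
# Assumed: metadata is a dict of string keys/values stored in the safetensors header.
mx.save_safetensors("model.safetensors", arrays, metadata={"framework": "mlx"})

# Assumed: return_metadata=True also returns the stored header metadata.
arrays, metadata = mx.load("model.safetensors", return_metadata=True)
print(metadata["framework"])
```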
Angelos Katharopoulos
221f8d3fc2 Bump the version to 0.2 (#656) 2024-02-08 11:27:12 -08:00
Awni Hannun
5c03efaf29 Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
LeonEricsson
7dccd42133 updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
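For context, a minimal `mx.compile` sketch. The input/output capture described in this PR is about letting compiled functions read and update state (model parameters, optimizer state, RNG) that is not passed explicitly, so treat this as a simplified illustration of the basic API only:

```
import mlx.core as mx

def step(x, y):
    return mx.square(x) + mx.exp(y)

compiled_step = mx.compile(step)   # traces on first call, reuses the graph afterwards
out = compiled_step(mx.array(1.0), mx.array(2.0))
mx.eval(out)
```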
Awni Hannun
e5e816a5ef fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Angelos Katharopoulos
28eac18571 Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Noah Farr
5fd11c347d Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
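Sketch of the new parameters, assuming they follow the usual location/scale convention:

```
import mlx.core as mx

# Samples from N(loc, scale^2) instead of the standard normal.
x = mx.random.normal(shape=(2, 3), loc=1.0, scale=0.5)
```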
Aryan Gupta
ef73393a19 Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
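A hedged example of the per-example weighting added here, assuming the loss takes logits and a `weights` array broadcastable to the targets:

```
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([0.2, -1.3, 2.0])
targets = mx.array([1.0, 0.0, 1.0])
weights = mx.array([1.0, 0.5, 2.0])  # up- or down-weight individual examples
loss = nn.losses.binary_cross_entropy(logits, targets, weights=weights, reduction="mean")
```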
Angelos Katharopoulos
ea406d5e33 CI change (#645)
* CI update

* Skip large binary test for now

* Upgrade pip

* Add proper env variable skipping

* Update the CI

* Fix workflow name

* Set the low memory flag for the tests

* Change build process

* Add pip upgrade

* Use a venv

* Add a missing env activate

* Add setuptools

* Add twine upload back

* Re-enable automatic release builds
2024-02-07 06:04:34 -08:00
Awni Hannun
146bd69470 Skip compile when transforming (#635)
* skip compile when transforming

* simplify message
2024-02-05 21:28:37 -08:00
Jagrit Digani
316ff490b3 Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Avikant Srivastava
31fea3758e feat: enhancement of the error message for mlx.core.mean (#608)
* add error message
2024-02-05 01:21:49 -08:00
Awni Hannun
e319383ef9 Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
Awni Hannun
5c3ac52dd7 fix test (#627) 2024-02-04 16:18:03 -08:00
David Koski
ebfd3618b0 fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Avikant Srivastava
11a9fd40f0 fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
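Sketch of the edge case; the expected output is my reading of the fix, matching NumPy's behavior for a single sample point:

```
import mlx.core as mx

print(mx.linspace(0, 1, num=1))  # expected: array([0], dtype=float32) rather than an error
```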
Daniel Strobusch
4fd2fb84a6 make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19 fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
minghuaw
16750f3c51 Fix typo in CMakeLists.txt (#616) 2024-02-03 05:59:26 -08:00
Awni Hannun
95b5fb8245 minor changes (#613) 2024-02-02 11:48:35 -08:00
AtomicVar
83f63f2184 Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
Awni Hannun
cb6156d35d Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Piotr Rybiec
506d43035c typo fix (#607) 2024-02-01 17:39:55 -08:00
Angelos Katharopoulos
36cff34701 Bump the version (#604) 2024-02-01 11:41:38 -08:00
Awni Hannun
e88e474fd1 Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00
David Koski
601c6d6aa8 Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365 Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
Vijay Krish
fcc5ac1c64 Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
nathan
bad67fec37 Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
199aebcf77 Change the variance computation (#319) 2024-01-30 19:28:56 -08:00
Angelos Katharopoulos
0de5988f92 Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5 Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Jagrit Digani
375446453e Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20 Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Awni Hannun
09b9275027 Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454 Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jacket
3f7aba8498 Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
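Rough usage sketch of the two ops referenced above:

```
import mlx.core as mx

a = mx.arange(9).reshape(3, 3)
d = mx.diagonal(a)                 # extract the main diagonal of a matrix
m = mx.diag(mx.array([1, 2, 3]))   # build a diagonal matrix from a vector
```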
Angelos Katharopoulos
65d0b8df9f Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345 Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
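Illustration of the intended semantics (hedged): elementwise binary ops such as `maximum`/`minimum` should now return NaN when either operand is NaN, matching NumPy:

```
import mlx.core as mx

a = mx.array([1.0, float("nan")])
b = mx.array([2.0, 2.0])
print(mx.maximum(a, b))  # expected: array([2, nan], dtype=float32)
```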
Angelos Katharopoulos
37d98ba6ff No gil eval (#565) 2024-01-26 22:03:52 -08:00
Awni Hannun
8993382aaa Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
Awni Hannun
07f35c9d8a Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002 Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9 Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
David Koski
874b739f3c Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
taher
077c1ee64a QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
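A hedged sketch of the new factorization; at the time, `mlx.core.linalg` routines ran on the CPU stream, so the example pins the stream explicitly:

```
import mlx.core as mx

A = mx.array([[2.0, 3.0], [1.0, 2.0]])
# A == Q @ R with Q orthonormal and R upper-triangular.
Q, R = mx.linalg.qr(A, stream=mx.cpu)
```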
Rifur13
2463496471 [Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
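Sketch of the behavior this fixes, assuming NumPy-like semantics for `isclose`/`allclose` with infinite values and the new `equal_nan` flag:

```
import mlx.core as mx

a = mx.array([float("inf"), 1.0, float("nan")])
b = mx.array([float("inf"), 1.0, float("nan")])
print(mx.isclose(a, b, equal_nan=True))   # expected: [True, True, True]
print(mx.allclose(a, b, equal_nan=True))  # expected: True
```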
Angelos Katharopoulos
87b7fa9ba2 Bump the version (#554) 2024-01-25 11:01:05 -08:00
Danilo Peixoto
624065c074 Fix package installation for CI (#521)
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-25 09:43:34 -08:00
Awni Hannun
f27ec5e097 More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
2024-01-24 09:58:33 -08:00
Awni Hannun
f30e63353a Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Juarez Bochi
4fe2fa2a64 GGUF: Avoid dequantization when format is compatible (#426)
* GGUF: Don't dequantize q4_1

* Fix weight order. First in low bits

* Add unpacking for q4_0

* Don't dequantize q8_0

* rebase quants and split file

* don't quantize every weight

* reapply patch

* error handling

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:43:57 -08:00
Hazem Essam
37fc9db82c Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because its not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137 Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
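A hedged sketch of both target forms, assuming the loss dispatches on the shape of `targets` (class indices versus per-class probabilities):

```
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0, 0.5], [0.1, 1.5, -0.2]])

hard = mx.array([0, 1])                                # class indices
soft = mx.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]])  # dense (probability) targets

l1 = nn.losses.cross_entropy(logits, hard, reduction="mean")
l2 = nn.losses.cross_entropy(logits, soft, reduction="mean")
```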
LeonEricsson
6b4b30e3fc Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename packakge to nn.innit

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
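Rough usage sketch, assuming each initializer is a factory that returns a function mapping an input array (used for shape and dtype) to freshly initialized values:

```
import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.glorot_uniform()
w = init_fn(mx.zeros((16, 8)))  # same shape, Glorot-uniform values
```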
Awni Hannun
86e0c79467 remove stale benchmarks (#527) 2024-01-22 22:17:58 -08:00
Awni Hannun
98c37d3a22 use axes in tensordot (#525) 2024-01-22 21:17:00 -08:00
Sugato Ray
f326dd8334 Update README.md (#524)
Add conda install option in docs.
2024-01-22 20:53:54 -08:00
Jagrit Digani
6d3bee3364 Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Danilo Peixoto
ecb174ca9d Type annotations for mlx.core module (#512) 2024-01-21 12:53:12 -08:00
Awni Hannun
7a34e46677 Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
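Sketch of quantizing with the newly supported group size. The `(w_q, scales, biases)` tuple layout and the `quantized_matmul` call are my assumptions of the usual convention:

```
import mlx.core as mx

w = mx.random.normal(shape=(512, 512))
# group_size=32 groups 32 consecutive weights per scale/bias pair.
w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)

x = mx.random.normal(shape=(1, 512))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=32, bits=4)
```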
Nripesh Niketan
92c22c1ea3 feat: Update isort version to 5.13.2 (#514) 2024-01-21 06:11:48 -08:00
Awni Hannun
d52383367a format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d Add ValuError message for Adamax (#508)
* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Awni Hannun
b207c2c86b Power VJP fix for 0 (#505) 2024-01-20 01:17:40 -08:00
185 changed files with 11332 additions and 3544 deletions

View File

@@ -1,5 +1,8 @@
 version: 2.1
 
+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 parameters:
   nightly_build:
     type: boolean
@@ -7,6 +10,9 @@ parameters:
   weekly_build:
     type: boolean
     default: false
+  test_release:
+    type: boolean
+    default: false
 
 jobs:
   linux_build_and_test:
@@ -26,18 +32,23 @@ jobs:
           command: |
             pip install --upgrade cmake
             pip install --upgrade pybind11[global]
+            pip install pybind11-stubgen
             pip install numpy
             sudo apt-get update
-            sudo apt-get install libblas-dev
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
-          name: Build python package
+          name: Install Python package
           command: |
             CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
             CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
       - run:
-          name: Run the python tests
+          name: Generate package stubs
           command: |
-            python3 -m unittest discover python/tests
+            python3 setup.py generate_stubs
+      - run:
+          name: Run Python tests
+          command: |
+            python3 -m unittest discover python/tests -v
       # TODO: Reenable when extension api becomes stable
       # - run:
       #     name: Build example extension
@@ -52,169 +63,180 @@ jobs:
           command: ./build/tests/tests
 
   mac_build_and_test:
-    machine: true
-    resource_class: ml-explore/m-builder
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.large.gen1
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=3.9
-            conda activate runner-env
+            brew install python@3.9
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
             pip install --upgrade pybind11[global]
+            pip install pybind11-stubgen
             pip install numpy
             pip install torch
             pip install tensorflow
             pip install unittest-xml-reporting
       - run:
-          name: Build python package
+          name: Install Python package
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
-            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
       - run:
-          name: Run the python tests
+          name: Generate package stubs
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
-            DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
+            source env/bin/activate
+            python setup.py generate_stubs
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+            LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
       # TODO: Reenable when extension api becomes stable
       # - run:
       #     name: Build example extension
       #     command: |
-      #       eval "$(conda shell.bash hook)"
-      #       conda activate runner-env
-      #       cd examples/extensions && python -m pip install .
+      #       cd examples/extensions && python3.11 -m pip install .
       - store_test_results:
           path: test-results
       - run:
          name: Build CPP only
           command: |
+            source env/bin/activate
             mkdir -p build && cd build && cmake .. && make -j
       - run:
           name: Run CPP tests
-          command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+          command: |
+            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+            DEVICE=cpu ./build/tests/tests
 
   build_release:
-    machine: true
-    resource_class: ml-explore/m-builder
     parameters:
       python_version:
         type: string
         default: "3.9"
-      macos_version:
+      xcode_version:
         type: string
-        default: "14"
+        default: "15.2.0"
+      build_env:
+        type: string
+        default: ""
+    macos:
+      xcode: << parameters.xcode_version >>
+    resource_class: macos.m1.large.gen1
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
+            brew install python@<< parameters.python_version >>
+            python<< parameters.python_version >> -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
             pip install --upgrade pybind11[global]
+            pip install --upgrade setuptools
+            pip install pybind11-stubgen
             pip install numpy
             pip install twine
+            pip install build
       - run:
-          name: Build package
+          name: Install Python package
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            PYPI_RELEASE=1 \
+            source env/bin/activate
+            DEV_RELEASE=1 \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-            twine upload dist/* --repository mlx
+            pip install . -v
+      - run:
+          name: Generate package stubs
+          command: |
+            source env/bin/activate
+            python setup.py generate_stubs
+      - run:
+          name: Build Python package
+          command: |
+            source env/bin/activate
+            << parameters.build_env >> \
+            CMAKE_BUILD_PARALLEL_LEVEL="" \
+            python -m build -w
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload package
+                command: |
+                  source env/bin/activate
+                  twine upload dist/*
       - store_artifacts:
           path: dist/
 
-  build_dev_release:
-    machine: true
-    resource_class: ml-explore/m-builder
+  build_linux_test_release:
     parameters:
       python_version:
         type: string
         default: "3.9"
-      macos_version:
+      extra_env:
         type: string
-        default: "14"
+        default: "DEV_RELEASE=1"
+    docker:
+      - image: ubuntu:20.04
     steps:
       - checkout
       - run:
-          name: Install dependencies
+          name: Build wheel
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
+            PYTHON=python<< parameters.python_version >>
+            apt-get update
+            apt-get upgrade -y
+            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
+            apt-get install -y apt-utils
+            apt-get install -y software-properties-common
+            add-apt-repository -y ppa:deadsnakes/ppa
+            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
+            apt-get install -y build-essential git
+            $PYTHON -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
             pip install --upgrade pybind11[global]
+            pip install --upgrade setuptools
+            pip install pybind11-stubgen
             pip install numpy
-            pip install twine
-      - run:
-          name: Build package
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            DEV_RELEASE=1 \
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-            twine upload dist/* --repository mlx
-      - store_artifacts:
-          path: dist/
-
-  build_package:
-    machine: true
-    resource_class: ml-explore/m-builder
-    parameters:
-      python_version:
-        type: string
-        default: "3.9"
-      macos_version:
-        type: string
-        default: "14"
-    steps:
-      - checkout
-      - run:
-          name: Install dependencies
-          command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
-            pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
-            pip install numpy
-            pip install twine
-      - run:
-          name: Build package
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
+            pip install . -v
+            python setup.py generate_stubs
+            << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
+            python -m build --wheel
+            auditwheel show dist/*
+            auditwheel repair dist/* --plat manylinux_2_31_x86_64
       - store_artifacts:
-          path: dist/
+          path: wheelhouse/
 
 workflows:
   build_and_test:
     when:
       and:
+        - matches:
+            pattern: "^(?!pull/)[-\\w]+$"
+            value: << pipeline.git.branch >>
         - not: << pipeline.parameters.nightly_build >>
         - not: << pipeline.parameters.weekly_build >>
+        - not: << pipeline.parameters.test_release >>
     jobs:
-      - linux_build_and_test
       - mac_build_and_test
+      - linux_build_and_test
      - build_release:
           filters:
             tags:
@@ -224,20 +246,53 @@ workflows:
           matrix:
             parameters:
               python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              xcode_version: ["14.3.1", "15.2.0"]
+              build_env: ["PYPI_RELEASE=1"]
+  prb:
+    when:
+      matches:
+        pattern: "^pull/\\d+(/head)?$"
+        value: << pipeline.git.branch >>
+    jobs:
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
+      - linux_build_and_test:
+          requires: [ hold ]
   nightly_build:
-    when: << pipeline.parameters.nightly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.nightly_build >>
     jobs:
-      - build_package:
+      - build_release:
           matrix:
             parameters:
               python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              xcode_version: ["14.3.1", "15.2.0"]
   weekly_build:
-    when: << pipeline.parameters.weekly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.weekly_build >>
     jobs:
-      - build_dev_release:
+      - build_release:
           matrix:
             parameters:
               python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              xcode_version: ["14.3.1", "15.2.0"]
+              build_env: ["DEV_RELEASE=1"]
+  linux_test_release:
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.test_release >>
+    jobs:
+      - build_linux_test_release:
+          matrix:
+            parameters:
+              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              extra_env: ["PYPI_RELEASE=1"]

View File

@@ -5,11 +5,11 @@ repos:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.12.1
+    rev: 24.2.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         args:

View File

@@ -10,8 +10,8 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
-- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
+- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.0.10)
+  set(MLX_VERSION 0.3.0)
 endif()
 
 # --------------------- Processor tests -------------------------
@@ -31,13 +31,13 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
     message(FATAL_ERROR
       "Building for x86_64 on macOS is not supported."
      " If you are on an Apple silicon system, check the build"
       " documentation for possible fixes: "
       "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
   elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
     message(WARNING
       "Building for x86_64 on macOS is not supported."
       " If you are on an Apple silicon system, "
       " make sure you are building for arm64.")
   elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
@@ -75,7 +75,7 @@ elseif (MLX_BUILD_METAL)
     COMMAND_ERROR_IS_FATAL ANY)
   message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
   if (${MACOS_VERSION} GREATER_EQUAL 14.2)
     set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
   elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
@@ -123,16 +123,27 @@ else()
     /usr/include
     /usr/local/include
     $ENV{BLAS_HOME}/include)
-  message(STATUS ${BLAS_LIBRARIES})
-  message(STATUS ${BLAS_INCLUDE_DIRS})
+  message(STATUS "Blas lib " ${BLAS_LIBRARIES})
+  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
   target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
   target_link_libraries(mlx ${BLAS_LIBRARIES})
+  find_package(LAPACK REQUIRED)
+  if (NOT LAPACK_FOUND)
+    message(FATAL_ERROR "Must have LAPACK installed")
+  endif()
+  find_path(LAPACK_INCLUDE_DIRS lapacke.h
+    /usr/include
+    /usr/local/include)
+  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
+  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
+  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
+  target_link_libraries(mlx ${LAPACK_LIBRARIES})
 endif()
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 
 target_include_directories(
   mlx
   PUBLIC
   $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
   $<INSTALL_INTERFACE:include>

View File

@@ -1,3 +1,4 @@
 include CMakeLists.txt
 recursive-include mlx/ *
 include python/src/*
+include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,8 +6,8 @@
 [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
 
-MLX is an array framework for machine learning on Apple silicon, brought to you
-by Apple machine learning research.
+MLX is an array framework for machine learning research on Apple silicon,
+brought to you by Apple machine learning research.
 
 Some key features of MLX include:
@@ -68,10 +68,18 @@ in the documentation.
 MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
 
+**With `pip`**:
+
 ```
 pip install mlx
 ```
 
+**With `conda`**:
+
+```
+conda install -c conda-forge mlx
+```
+
 Checkout the
 [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
 for more information on building the C++ and Python APIs from source.

View File

@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
 quant_matmul = {
+    "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
+    "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
+    "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
     "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
     "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
     "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -84,6 +87,15 @@ quant_matmul = {
     "quant_matmul_128_8": partial(
         _quant_matmul, transpose=False, group_size=128, bits=8
     ),
+    "quant_matmul_t_32_2": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=2
+    ),
+    "quant_matmul_t_32_4": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=4
+    ),
+    "quant_matmul_t_32_8": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=8
+    ),
     "quant_matmul_t_64_2": partial(
         _quant_matmul, transpose=True, group_size=64, bits=2
     ),

View File

@@ -80,10 +80,8 @@ if __name__ == "__main__":
     _filter = make_predicate(args.filter, args.negative_filter)
 
     if args.mlx_dtypes:
-        compare_filtered = (
-            lambda x: compare_mlx_dtypes(
-                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
-            )
+        compare_filtered = lambda x: (
+            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
             if _filter(x)
             else None
         )

View File

@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,118 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.mps
import torch.nn as nn
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -0,0 +1,56 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def scatter(dst, x, idx):
dst[idx] = x
mx.eval(dst)
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def gather(dst, x, idx, device):
dst[idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)
 
 
+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":
     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import time
 
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 100
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters

docs/.gitignore
View File

@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/

View File

@@ -1,19 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@@ -12,7 +12,7 @@ import mlx.core as mx
 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = ".".join(mx.__version__.split()[:-1])
+version = ".".join(mx.__version__.split(".")[:3])
 release = version
 
 # -- General configuration ---------------------------------------------------
@@ -26,6 +26,7 @@ extensions = [
 python_use_unqualified_type_names = True
 autosummary_generate = True
+autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
 
 intersphinx_mapping = {
     "https://docs.python.org/3": None,

View File

@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
You would really like the part of your applications that does this operation You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose ``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework. over the ins-and-outs of the MLX framework.
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python Binding to Python
^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings We use PyBind11_ to build a Python API for the C++ library. Since bindings for
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc. components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
are already provided, adding our :meth:`axpby` becomes very simple! already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++ .. code-block:: C++
@@ -927,18 +927,18 @@ Results:
We see some modest improvements right away! We see some modest improvements right away!
This operation is now good to be used to build other operations, This operation is now good to be used to build other operations, in
in :class:`mlx.nn.Module` calls, and also as a part of graph :class:`mlx.nn.Module` calls, and also as a part of graph transformations like
transformations such as :meth:`grad` and :meth:`simplify`! :meth:`grad`!
Scripts Scripts
------- -------
.. admonition:: Download the code .. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_. The full example code is available in `mlx <code>`_.
.. code: `TODO_LINK/extensions`_ .. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc .. _Metal: https://developer.apple.com/documentation/metal?language=objc

View File

@@ -41,6 +41,7 @@ are the CPU and GPU.
usage/indexing usage/indexing
usage/saving_and_loading usage/saving_and_loading
usage/function_transforms usage/function_transforms
usage/compile
usage/numpy usage/numpy
usage/using_streams usage/using_streams

View File

@@ -9,9 +9,10 @@ Devices and Streams
:toctree: _autosummary :toctree: _autosummary
Device Device
Stream
default_device default_device
set_default_device set_default_device
Stream
default_stream default_stream
new_stream new_stream
set_default_stream set_default_stream
stream

View File

@@ -9,3 +9,4 @@ Linear Algebra
:toctree: _autosummary :toctree: _autosummary
norm norm
qr

View File

@@ -180,3 +180,4 @@ In detail:
nn/layers nn/layers
nn/functions nn/functions
nn/losses nn/losses
nn/init

View File

@@ -19,5 +19,6 @@ simple functions.
prelu prelu
relu relu
selu selu
softshrink
silu silu
step step

View File

@@ -0,0 +1,45 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform
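Every initializer listed above follows the same pattern of returning a function that maps an input array to initialized values; a small sketch (default arguments are assumed, not taken from this page):

.. code:: python

import mlx.core as mx
import mlx.nn as nn

# Glorot (Xavier) normal initialization of a 16 x 8 weight matrix
init_fn = nn.init.glorot_normal()
w = init_fn(mx.zeros((16, 8)))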

View File

@@ -10,6 +10,8 @@ Layers
:template: nn-module-template.rst :template: nn-module-template.rst
ALiBi ALiBi
AvgPool1d
AvgPool2d
BatchNorm BatchNorm
Conv1d Conv1d
Conv2d Conv2d
@@ -22,6 +24,8 @@ Layers
InstanceNorm InstanceNorm
LayerNorm LayerNorm
Linear Linear
MaxPool1d
MaxPool2d
Mish Mish
MultiHeadAttention MultiHeadAttention
PReLU PReLU
@@ -33,5 +37,6 @@ Layers
Sequential Sequential
SiLU SiLU
SinusoidalPositionalEncoding SinusoidalPositionalEncoding
Softshrink
Step Step
Transformer Transformer
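The newly listed pooling layers (``AvgPool1d``, ``AvgPool2d``, ``MaxPool1d``, ``MaxPool2d``) follow the usual module-call pattern; a hedged sketch (the constructor argument names and the channels-last input layout are assumptions, not shown on this page):

.. code-block:: python

import mlx.core as mx
import mlx.nn as nn

# 2x2 max pooling with stride 2 over an NHWC batch of feature maps
pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = mx.random.uniform(shape=(4, 16, 16, 8))
y = pool(x)  # expected spatial downsampling to (4, 8, 8, 8)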

View File

@@ -18,6 +18,7 @@ Loss Functions
kl_div_loss kl_div_loss
l1_loss l1_loss
log_cosh_loss log_cosh_loss
margin_ranking_loss
mse_loss mse_loss
nll_loss nll_loss
smooth_l1_loss smooth_l1_loss

View File

@@ -11,6 +11,7 @@ Module
:toctree: _autosummary :toctree: _autosummary
Module.training Module.training
Module.state
.. rubric:: Methods .. rubric:: Methods

View File

@@ -35,6 +35,8 @@ Operations
cos cos
cosh cosh
dequantize dequantize
diag
diagonal
divide divide
divmod divmod
equal equal

View File

@@ -29,19 +29,8 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state. # Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state) mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers .. toctree::
.. autosummary:: optimizers/optimizer
:toctree: _autosummary optimizers/common_optimizers
:template: optimizers-template.rst optimizers/schedulers
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -0,0 +1,20 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
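Each optimizer above is constructed from its hyperparameters and then driven through ``Optimizer.update``; a minimal sketch (hyperparameter names and values are assumptions, not taken from this page):

.. code-block:: python

import mlx.optimizers as optim

# AdamW with an explicit weight decay; the other optimizers listed above
# are constructed the same way from their own hyperparameters.
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)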

View File

@@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update
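A minimal sketch of how these attributes and methods fit together, following the training-loop pattern shown elsewhere on this page (the model and SGD settings are assumptions):

.. code-block:: python

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(learning_rate=0.1)

def apply_step(grads):
    # Optimizer.update applies the gradients to the model's parameters,
    # and Optimizer.state is evaluated alongside them.
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)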

View File

@@ -0,0 +1,13 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
step_decay
exponential_decay
cosine_decay
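Schedulers return a callable learning-rate schedule that can be handed to an optimizer in place of a fixed value; a hedged sketch (argument order and values are assumptions, not taken from this page):

.. code-block:: python

import mlx.optimizers as optim

# Learning rate that decays by a factor of 0.9 as optimizer steps accumulate
lr_schedule = optim.exponential_decay(1e-1, 0.9)
optimizer = optim.SGD(learning_rate=lr_schedule)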

View File

@@ -9,9 +9,11 @@ Transforms
:toctree: _autosummary :toctree: _autosummary
eval eval
compile
disable_compile
enable_compile
grad grad
value_and_grad value_and_grad
jvp jvp
vjp vjp
vmap vmap
simplify

docs/src/usage/compile.rst (new file, 430 lines)
View File

@@ -0,0 +1,430 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
    return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
    return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
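For example, the following calls illustrate when the cache is reused and when retracing happens (an illustrative sketch of the rules above):

.. code-block:: python

def fun(x, y):
    return mx.exp(-x) + y

compiled_fun = mx.compile(fun)

compiled_fun(mx.array(1.0), mx.array(2.0))        # traced and compiled
compiled_fun(mx.array(3.0), mx.array(4.0))        # same shapes and types: cached
compiled_fun(mx.zeros((2, 2)), mx.zeros((2, 2)))  # new shapes: part of the stack reruns
compiled_fun(mx.array(1), mx.array(2))            # new types: the full stack reruns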
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
    mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
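The fix is to compile once, outside the loop, so the cached function is reused (a small sketch):

.. code-block:: python

a = mx.array(1.0)

# Compile once and reuse the same compiled function at each iteration
fn = mx.compile(lambda x: mx.exp(mx.abs(x)))
for _ in range(5):
    fn(a)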
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
    # warm up
    for _ in range(10):
        mx.eval(fun(x))
    tic = time.perf_counter()
    for _ in range(100):
        mx.eval(fun(x))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 100
    print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
    z = -x
    print(z) # Crash
    return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
    z = -x
    print(z) # Okay
    return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
    z = x + y
    state.append(z)
    return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
    return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
    return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
    logits = model(x).squeeze()
    return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss
# Perform 10 steps of gradient descent
for it in range(10):
    loss = step(x, y)
    # Evaluate the model and optimizer state
    mx.eval(state)
    print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
    return mx.exp(-mx.abs(x))

def outer(x):
    inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -5,9 +5,12 @@ Function Transforms
.. currentmodule:: mlx.core .. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation and MLX uses composable function transformations for automatic differentiation,
vectorization. The key idea behind composable function transformations is that vectorization, and compute graph optimizations. To see the complete list of
every transformation returns a function which can be further transformed. function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
Here is a simple example: Here is a simple example:
@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives. getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check-out the depth. See the following sections for more information on :ref:`automatic
:ref:`API documentation <transforms>`. See the following sections for more differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
information on :ref:`automatic differentiation <auto diff>` and For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
:ref:`automatic vectorization <vmap>`.
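As a small illustration of the higher-order derivatives mentioned above (not part of the diff):

import mlx.core as mx

# Second derivative of sin, obtained by composing grad with itself
d2_sin = mx.grad(mx.grad(mx.sin))
print(d2_sin(mx.array(0.5)))  # approximately -sin(0.5)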
Automatic Differentiation Automatic Differentiation
------------------------- -------------------------

View File

@@ -20,7 +20,7 @@ Transforming Compute Graphs
Lazy evaluation lets us record a compute graph without actually doing any Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`. :func:`vmap` and graph optimizations.
Currently, MLX does not compile and rerun compute graphs. They are all Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to generated dynamically. However, lazy evaluation makes it much easier to

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.24) cmake_minimum_required(VERSION 3.27)
project(mlx_sample_extensions LANGUAGES CXX) project(mlx_sample_extensions LANGUAGES CXX)
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif() endif()

View File

@@ -3,9 +3,10 @@ target_sources(
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
@@ -20,7 +21,7 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE) if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else() else()
target_sources( target_sources(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <functional> #include <functional>
@@ -82,6 +82,13 @@ array::array(std::initializer_list<float> data)
init(data.begin()); init(data.begin());
} }
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array( array::array(
allocator::Buffer data, allocator::Buffer data,
@@ -97,11 +104,13 @@ void array::detach() {
s.array_desc_->inputs.clear(); s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear(); s.array_desc_->siblings.clear();
s.array_desc_->position = 0; s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr; s.array_desc_->primitive = nullptr;
} }
array_desc_->inputs.clear(); array_desc_->inputs.clear();
array_desc_->siblings.clear(); array_desc_->siblings.clear();
array_desc_->position = 0; array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
@@ -155,6 +164,14 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
} }
void array::move_shared_buffer(array other) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype) array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) { : shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape); std::tie(size, strides) = cum_prod(shape);
@@ -170,9 +187,11 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)), primitive(std::move(primitive)),
inputs(inputs) { inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape); std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) { for (auto& in : this->inputs) {
is_tracer |= in.is_tracer(); is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
} }
depth++;
} }
array::ArrayDesc::ArrayDesc( array::ArrayDesc::ArrayDesc(
@@ -185,9 +204,11 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)), primitive(std::move(primitive)),
inputs(std::move(inputs)) { inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape); std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) { for (auto& in : this->inputs) {
is_tracer |= in.is_tracer(); is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
} }
depth++;
} }
array::ArrayIterator::ArrayIterator(const array& arr, int idx) array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -41,6 +41,9 @@ class array {
/* Special case so empty lists default to float32. */ /* Special case so empty lists default to float32. */
array(std::initializer_list<float> data); array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T> template <typename T>
array( array(
std::initializer_list<T> data, std::initializer_list<T> data,
@@ -121,6 +124,9 @@ class array {
template <typename T> template <typename T>
T item(); T item();
template <typename T>
T item() const;
struct ArrayIterator { struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t; using difference_type = size_t;
@@ -240,6 +246,11 @@ class array {
return array_desc_->inputs; return array_desc_->inputs;
} }
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */ /** The array's siblings. */
const std::vector<array>& siblings() const { const std::vector<array>& siblings() const {
return array_desc_->siblings; return array_desc_->siblings;
@@ -262,6 +273,11 @@ class array {
return outputs; return outputs;
}; };
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */ /** Detach the array from the graph. */
void detach(); void detach();
@@ -282,6 +298,12 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
}; };
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T> template <typename T>
T* data() { T* data() {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
@@ -322,6 +344,8 @@ class array {
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) { void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_; array_desc_ = other.array_desc_;
} }
@@ -364,6 +388,9 @@ class array {
// The arrays position in the output list // The arrays position in the output list
uint32_t position{0}; uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype); explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc( explicit ArrayDesc(
@@ -382,7 +409,7 @@ class array {
// The ArrayDesc contains the details of the materialized array including the // The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes // shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs // the primitive which knows how to compute the array's data from its inputs
// and a the list of array's inputs for the primitive. // and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr}; std::shared_ptr<ArrayDesc> array_desc_{nullptr};
}; };
@@ -433,6 +460,18 @@ T array::item() {
return *data<T>(); return *data<T>();
} }
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
return *data<T>();
}
template <typename It> template <typename It>
void array::init(It src) { void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype()))); set_data(allocator::malloc(size() * size_of(dtype())));

View File

@@ -46,6 +46,14 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1); size_t N = b.shape(-1);
size_t K = a.shape(-1); size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) { for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm( cblas_sgemm(
CblasRowMajor, CblasRowMajor,
@@ -89,6 +97,14 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1); size_t N = b.shape(-1);
size_t K = a.shape(-1); size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype()); BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{ const BNNSLayerParametersBroadcastMatMul gemm_params{
@@ -201,4 +217,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_); return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@@ -33,8 +33,12 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate) DEFAULT(Concatenate)
DEFAULT(Copy) DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(Equal) DEFAULT(Equal)
DEFAULT(Erf) DEFAULT(Erf)
DEFAULT(ErfInv) DEFAULT(ErfInv)
@@ -50,11 +54,15 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd) DEFAULT(LogicalAnd)
DEFAULT(LogicalOr) DEFAULT(LogicalOr)
DEFAULT(LogAddExp) DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual) DEFAULT(NotEqual)
DEFAULT(Pad) DEFAULT(Pad)
DEFAULT(Partition) DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits) DEFAULT(RandomBits)
DEFAULT(Reshape) DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round) DEFAULT(Round)
DEFAULT(Scatter) DEFAULT(Scatter)
DEFAULT(Sigmoid) DEFAULT(Sigmoid)
@@ -64,27 +72,16 @@ DEFAULT_MULTI(Split)
DEFAULT(Sort) DEFAULT(Sort)
DEFAULT(StopGradient) DEFAULT(StopGradient)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) { void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) { } else if (in.dtype() == int32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) { } else if (is_unsigned(in.dtype())) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@@ -137,12 +134,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size); vvacosf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -153,12 +146,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size); vvacoshf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -169,12 +158,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size); vvasinf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -185,12 +170,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size); vvasinhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -201,12 +182,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size); vvatanf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -217,12 +194,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size); vvatanhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -234,30 +207,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().contiguous) { if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible // Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) { if (in.dtype() == float32 && out.dtype() == uint32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfixu32( vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size()); in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return; return;
} else if (in.dtype() == float32 && out.dtype() == int32) { } else if (in.dtype() == float32 && out.dtype() == int32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size()); vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return; return;
} else if (in.dtype() == uint32 && out.dtype() == float32) { } else if (in.dtype() == uint32 && out.dtype() == float32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfltu32( vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size()); in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return; return;
} else if (in.dtype() == int32 && out.dtype() == float32) { } else if (in.dtype() == int32 && out.dtype() == float32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size()); vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return; return;
} }
@@ -269,12 +235,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size); vvcosf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -285,12 +247,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size); vvcoshf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -335,55 +293,12 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
} }
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) { void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); }); unary_fp(in, out, [](auto x) { return std::exp(x); });
@@ -410,12 +325,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) { switch (base_) {
case Base::e: case Base::e:
vvlogf( vvlogf(
@@ -439,12 +350,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (is_floating_point(out.dtype())) {
@@ -456,47 +363,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
} }
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) { void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@@ -526,13 +392,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
unary(in, out, [](auto x) { return -x; }); unary(in, out, [](auto x) { return -x; });
} }
@@ -545,7 +406,13 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous && if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) { b.flags().row_contiguous) {
int size = a.size(); int size = a.size();
out.set_data(allocator::malloc_or_wait(out.nbytes())); if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size); vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -587,12 +454,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size); vvsinf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -603,12 +466,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size); vvsinhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -619,12 +478,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size); vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
unary(in, out, [](auto x) { return x * x; }); unary(in, out, [](auto x) { return x * x; });
@@ -635,12 +490,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) { if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size); vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else { } else {
@@ -695,12 +546,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size); vvtanf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@@ -711,12 +558,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size); vvtanhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);

View File

@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous // Make sure that the last dimension is contiguous
auto check_input = [](array x) { auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) { bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -3,6 +3,7 @@ target_sources(
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
@@ -10,10 +11,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
) )

View File

@@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
struct RemainderFn { struct RemainderFn {
template <typename T> template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()( std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator, T numerator,
T denominator) { T denominator) {
return std::fmod(numerator, denominator); return numerator % denominator;
} }
template <typename T> template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()( std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator, T numerator,
T denominator) { T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator; return numerator % denominator;
} }
}; };
@@ -233,14 +251,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
} }
void Minimum::eval(const std::vector<array>& inputs, array& out) { void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x < y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
} }
void Multiply::eval(const std::vector<array>& inputs, array& out) { void Multiply::eval(const std::vector<array>& inputs, array& out) {

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@@ -40,29 +39,83 @@ void set_binary_op_output_data(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
BinaryOpType bopt) { BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) { switch (bopt) {
case ScalarScalar: case ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case ScalarVector: case ScalarVector:
out.set_data( if (b.is_donatable() && b.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(b.data_size() * out.itemsize()), if (donate_with_move) {
b.data_size(), out.move_shared_buffer(b);
b.strides(), } else {
b.flags()); out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
break; break;
case VectorScalar: case VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case VectorVector: case VectorVector:
out.set_data( if (a.is_donatable() && a.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(a.data_size() * out.itemsize()), if (donate_with_move) {
a.data_size(), out.move_shared_buffer(a);
a.strides(), } else {
a.flags()); out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break; break;
case General: case General:
out.set_data(allocator::malloc_or_wait(out.nbytes())); if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break; break;
} }
} }

View File

@@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.
#include <queue>
#include "mlx/primitives.h"
namespace mlx::core {
// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
const std::vector<array>& trace_tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
std::queue<array> tape;
for (auto& a : trace_tape) {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
tape.push(
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
trace_to_real.insert({a.id(), tape.back()});
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return {tape, outputs};
}
void Compiled::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make a real tape from the tracers
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
// Run the tape
while (!tape.empty()) {
auto a = std::move(tape.front());
tape.pop();
auto outputs = a.outputs();
a.primitive().eval_cpu(a.inputs(), outputs);
a.detach();
}
// Copy results into outputs
for (int o = 0; o < real_outputs.size(); ++o) {
outputs[o].copy_shared_buffer(real_outputs[o]);
}
}
} // namespace mlx::core
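Compiled::eval above rebuilds a concrete tape by mapping the ids of the traced placeholder arrays onto the real inputs and then replaying the recorded primitives front to back. A toy, self-contained version of the same id-remapping idea (all types here are illustrative, not MLX's array or primitive classes):

// Toy illustration of remapping traced placeholders to concrete values and
// replaying a recorded tape in order.
#include <cstdio>
#include <functional>
#include <queue>
#include <unordered_map>
#include <vector>

struct Node {
  int id;                      // id assigned at trace time
  std::vector<int> input_ids;  // ids of the node's traced inputs
  std::function<float(const std::vector<float>&)> op;
};

float replay(const std::vector<Node>& tape,
             const std::unordered_map<int, float>& trace_inputs_to_real,
             int output_id) {
  std::unordered_map<int, float> values = trace_inputs_to_real;
  std::queue<const Node*> work;
  for (auto& n : tape) work.push(&n);
  while (!work.empty()) {
    const Node* n = work.front();
    work.pop();
    std::vector<float> args;
    for (int in : n->input_ids) args.push_back(values.at(in));
    values[n->id] = n->op(args);  // evaluate with the real inputs
  }
  return values.at(output_id);
}

int main() {
  // Tape for: out = (x + y) * 2, with traced ids x=0, y=1, sum=2, out=3.
  std::vector<Node> tape = {
      {2, {0, 1}, [](const std::vector<float>& a) { return a[0] + a[1]; }},
      {3, {2}, [](const std::vector<float>& a) { return a[0] * 2.0f; }},
  };
  std::printf("%f\n", replay(tape, {{0, 1.5f}, {1, 2.5f}}, 3));  // 8.0
  return 0;
}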

View File

@@ -3,7 +3,7 @@
#include <cassert> #include <cassert>
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h> #include <Accelerate/Accelerate.h>
#else #else
#include <cblas.h> #include <cblas.h>
#endif #endif

View File

@@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output // Allocate the output
switch (ctype) { switch (ctype) {
case CopyType::Vector: case CopyType::Vector:
dst.set_data( if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
allocator::malloc_or_wait(src.data_size() * dst.itemsize()), dst.copy_shared_buffer(src);
src.data_size(), } else {
src.strides(), auto size = src.data_size();
src.flags()); dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break; break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::General: case CopyType::General:

View File

@@ -1,11 +1,13 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h> #include <Accelerate/Accelerate.h>
#else #else
#include <cblas.h> #include <cblas.h>
#endif #endif
#include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@@ -39,12 +41,16 @@ DEFAULT(ArgSort)
DEFAULT(AsType) DEFAULT(AsType)
DEFAULT(AsStrided) DEFAULT(AsStrided)
DEFAULT(Broadcast) DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil) DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate) DEFAULT(Concatenate)
DEFAULT(Convolution) DEFAULT(Convolution)
DEFAULT(Copy) DEFAULT(Copy)
DEFAULT(Cos) DEFAULT(Cos)
DEFAULT(Cosh) DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide) DEFAULT(Divide)
DEFAULT(Remainder) DEFAULT(Remainder)
DEFAULT(Equal) DEFAULT(Equal)
@@ -74,6 +80,7 @@ DEFAULT(NotEqual)
DEFAULT(Pad) DEFAULT(Pad)
DEFAULT(Partition) DEFAULT(Partition)
DEFAULT(Power) DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul) DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits) DEFAULT(RandomBits)
DEFAULT(Reduce) DEFAULT(Reduce)
@@ -96,7 +103,6 @@ DEFAULT(Subtract)
DEFAULT(Tan) DEFAULT(Tan)
DEFAULT(Tanh) DEFAULT(Tanh)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
namespace { namespace {
@@ -126,6 +132,13 @@ inline void matmul_common_general(
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(

View File

@@ -232,22 +232,38 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) { void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { switch (out.dtype()) {
case float32: case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); }); unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break; break;
case float16: case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) { unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x))); return static_cast<float16_t>(std::erf(static_cast<float>(x)));
}); });
break; break;
case bfloat16: case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) { unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x))); return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
}); });
@@ -264,17 +280,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { switch (out.dtype()) {
case float32: case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); }); unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break; break;
case float16: case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) { unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x))); return static_cast<float16_t>(erfinv(static_cast<float>(x)));
}); });
break; break;
case bfloat16: case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) { unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x))); return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
}); });

mlx/backend/common/qrf.cpp (new file, 153 lines)
View File

@@ -0,0 +1,153 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core
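qrf_impl relies on the standard LAPACK workspace-query convention: call the routine once with lwork = -1 so it reports the optimal workspace size in work[0], allocate that much, then call again to do the real factorization. A minimal standalone sketch of the pattern, assuming a linkable LAPACK (or Accelerate) that provides sgeqrf_:

// Minimal sketch of the LAPACK lwork = -1 workspace query used by qrf_impl.
// Assumes a Fortran LAPACK providing sgeqrf_; link with -llapack or Accelerate.
#include <algorithm>
#include <vector>

extern "C" void sgeqrf_(
    const int* m, const int* n, float* a, const int* lda,
    float* tau, float* work, const int* lwork, int* info);

int main() {
  const int m = 4, n = 3, lda = 4;
  std::vector<float> a(lda * n, 1.0f);   // column-major input, overwritten in place
  std::vector<float> tau(std::min(m, n));

  // 1) Workspace query: lwork = -1 makes LAPACK write the optimal size to work[0].
  float optimal_work = 0.0f;
  int lwork = -1, info = 0;
  sgeqrf_(&m, &n, a.data(), &lda, tau.data(), &optimal_work, &lwork, &info);

  // 2) Allocate the reported workspace and do the real factorization.
  lwork = static_cast<int>(optimal_work);
  std::vector<float> work(lwork);
  sgeqrf_(&m, &n, a.data(), &lda, tau.data(), work.data(), &lwork, &info);
  return info;  // 0 on success
}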

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cassert> #include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -119,6 +118,12 @@ void _qmm_dispatch_typed(
switch (bits) { switch (bits) {
case 2: { case 2: {
switch (group_size) { switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64: case 64:
if (transposed_w) { if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K); return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -135,6 +140,12 @@ void _qmm_dispatch_typed(
} }
case 4: { case 4: {
switch (group_size) { switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64: case 64:
if (transposed_w) { if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K); return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -151,6 +162,12 @@ void _qmm_dispatch_typed(
} }
case 8: { case 8: {
switch (group_size) { switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64: case 64:
if (transposed_w) { if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K); return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);

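The added case 32 branches dispatch the existing quantized matmul templates for a group size of 32 as well. As a generic illustration of what the group size means here (not MLX's exact bit-packing or kernel layout): each group of weights shares one scale and one bias, and dequantization is w ≈ scale * q + bias:

// Generic sketch of group-wise affine dequantization (w ≈ scale * q + bias).
// Illustrative only; MLX's actual packing and kernels differ.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> dequantize(
    const std::vector<uint8_t>& q,     // one low-bit code per weight (unpacked here)
    const std::vector<float>& scales,  // one scale per group
    const std::vector<float>& biases,  // one bias per group
    int group_size) {
  std::vector<float> w(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    size_t g = i / group_size;
    w[i] = scales[g] * static_cast<float>(q[i]) + biases[g];
  }
  return w;
}

int main() {
  // Two groups of size 4 with different scales/biases.
  std::vector<uint8_t> q = {0, 1, 2, 3, 0, 1, 2, 3};
  auto w = dequantize(q, {0.5f, 2.0f}, {-1.0f, 0.0f}, 4);
  for (float v : w) std::printf("%g ", v);  // -1 -0.5 0 0.5 0 2 4 6
  std::printf("\n");
  return 0;
}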
View File

@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

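The tightened check above only skips the copy when the last dimension has stride 1 and, for more than one dimension, the second-to-last stride is either 0 (broadcast) or equal to the last dimension's size, i.e. rows are packed back to back. The same check on plain shape/stride vectors, as a standalone sketch:

// Standalone version of the "is the last dimension packed?" check used above.
// Works on plain shape/stride vectors; illustrative only.
#include <cstdio>
#include <vector>

bool last_dim_contiguous(const std::vector<int>& shape,
                         const std::vector<size_t>& strides) {
  size_t ndim = shape.size();
  bool no_copy = strides[ndim - 1] == 1;  // elements of a row are adjacent
  if (ndim > 1) {
    size_t s = strides[ndim - 2];
    // Rows may be broadcast (stride 0) or packed back to back.
    no_copy &= (s == 0 || s == static_cast<size_t>(shape.back()));
  }
  return no_copy;
}

int main() {
  std::printf("%d\n", last_dim_contiguous({2, 3}, {3, 1}));  // 1: row-contiguous
  std::printf("%d\n", last_dim_contiguous({2, 3}, {6, 1}));  // 0: padded rows need a copy
  std::printf("%d\n", last_dim_contiguous({2, 3}, {1, 2}));  // 0: transposed layout
  return 0;
}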
View File

@@ -64,15 +64,24 @@ struct RoundOp {
} }
}; };
void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
}
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);

View File

@@ -1,7 +1,28 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${CMAKE_SOURCE_DIR}
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
@@ -11,10 +32,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
) )
if (NOT MLX_METAL_PATH) if (NOT MLX_METAL_PATH)

View File

@@ -0,0 +1,484 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
inline auto get_type_string(Dtype d) {
switch (d) {
case float32:
return "float";
case float16:
return "half";
case bfloat16:
return "bfloat16_t";
case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
default: {
std::ostringstream msg;
msg << "Unsupported compilation type " << d;
throw std::runtime_error(msg.str());
}
}
}
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
os << x.item<T>();
}
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
return print_float_constant<float>(os, x);
case float16:
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
return print_int_constant<int32_t>(os, x);
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:
return print_int_constant<uint32_t>(os, x);
case uint64:
return print_int_constant<uint64_t>(os, x);
case bool_:
os << std::boolalpha << x.item<bool>();
return;
default:
throw std::runtime_error("Unsupported constant type");
}
}
inline std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
std::ostringstream os;
std::ostringstream constant_hasher;
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
a.primitive().print(os);
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << ((x.size() == 1) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim,
bool dynamic_dims) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
// For scalars we skip the indexing entirely and just read at offset 0
auto is_scalar = [](const array& x) { return x.size() == 1; };
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
// Add the input arguments
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
}
os << "];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
// Finish the kernel
os << "}" << std::endl;
if (cnt > 31) {
std::ostringstream msg;
msg << "[compile] Too many inputs/outputs fused in the Metal Compile "
<< "primitive which exhausted the available argument buffers for "
<< "the kernel. Please file an issue with the function that results "
<< "in this error. The name of the kernel is '" << kernel_name << "'";
throw std::runtime_error(msg.str());
}
}
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the name for the kernel library
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i),
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
}
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
kernel_source_ = kernel.str();
lib = d.get_library(kernel_lib_, kernel_source_);
}
// Allocate space for the outputs
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
x.size() > 1) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides;
initial_strides.push_back(outputs[0].strides());
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
// Skip scalar inputs.
if (x.size() <= 1) {
continue;
}
// Broadcast the inputs to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides);
}
// Get the kernel from the lib
int ndim = shape.size();
bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
if (dynamic) {
kernel_name += "dynamic";
} else {
kernel_name += std::to_string(shape.size());
}
}
auto kernel = d.get_kernel(kernel_name, lib);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && x.size() > 1) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
stride_idx++;
}
}
// Put the outputs in
for (auto& x : outputs) {
set_array_buffer(compute_encoder, x, cnt++);
}
// Put the output shape and strides in
if (!contiguous) {
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
}
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].size();
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core
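The strided kernels generated above recover each input's memory offset from the flat thread index by combining per-axis indices with that input's strides (the index_i * strides[i] terms, with an elem_to_loc fallback for dynamic rank). The underlying arithmetic, written out in plain C++ as a reference sketch rather than Metal code:

// Reference version of the index -> offset arithmetic the generated strided
// kernels rely on (an elem_to_loc-style helper); illustrative only.
#include <cstdio>
#include <vector>

size_t elem_to_loc(size_t index,
                   const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  // Peel off coordinates from the innermost (fastest-moving) axis outward.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: strides {1, 2} instead of {3, 1}.
  std::vector<int> shape = {2, 3};
  std::vector<size_t> strides = {1, 2};
  for (size_t i = 0; i < 6; ++i) {
    std::printf("flat %zu -> offset %zu\n", i, elem_to_loc(i, shape, strides));
  }
  return 0;
}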

View File

@@ -0,0 +1,9 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -2,7 +2,6 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <iostream>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>

View File

@@ -12,11 +12,15 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
@@ -67,7 +71,8 @@ void copy_gpu_inplace(
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-24 Apple Inc.
#include <dlfcn.h> #include <dlfcn.h>
#include <cstdlib> #include <cstdlib>
@@ -26,7 +26,8 @@ static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() { auto load_device() {
auto devices = MTL::CopyAllDevices(); auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0)); auto device = static_cast<MTL::Device*>(devices->object(0))
?: MTL::CreateSystemDefaultDevice();
if (!device) { if (!device) {
throw std::runtime_error("Failed to load device"); throw std::runtime_error("Failed to load device");
} }
@@ -214,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
return eit->second; return eit->second;
} }
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library( void Device::register_library(
const std::string& lib_name, const std::string& lib_name,
const std::string& lib_path) { const std::string& lib_path) {
@@ -242,37 +234,127 @@ void Device::register_library(
} }
} }
MTL::ComputePipelineState* Device::get_kernel( MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib // Search for cached metal lib
MTL::Library* mtl_lib; MTL::Library* mtl_lib;
if (auto it = library_map_.find(name); it != library_map_.end()) { if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second; mtl_lib = it->second;
} else { // Look for metallib alongside library } else { // Look for metallib alongside library
register_library(lib_name); register_library(lib_name);
mtl_lib = library_map_[lib_name]; mtl_lib = library_map_[lib_name];
} }
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Function* Device::get_function_(
const std::string& name,
MTL::Library* mtl_lib) {
// Pull kernel from library // Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding); auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name); auto mtl_function = mtl_lib->newFunction(ns_name);
return mtl_function;
}
MTL::Function* Device::get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib) {
if (func_consts.empty() && (specialized_name == name)) {
return get_function_(name, mtl_lib);
}
// Prepare function constants
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
for (auto [value, type, index] : func_consts) {
mtl_func_consts->setConstantValue(value, type, index);
}
// Prepare function desc
auto desc = MTL::FunctionDescriptor::functionDescriptor();
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
desc->setSpecializedName(
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
desc->setConstantValues(mtl_func_consts);
// Pull kernel from library
NS::Error* error = nullptr;
auto mtl_function = mtl_lib->newFunction(desc, &error);
// Throw error if unable to build metal function
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load function " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function) {
// Compile kernel to compute pipeline // Compile kernel to compute pipeline
NS::Error* error = nullptr; NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel; MTL::ComputePipelineState* kernel;
if (mtl_function) { if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error); kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
} }
// Throw error if unable to compile metal function
if (!mtl_function || !kernel) { if (!mtl_function || !kernel) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n"; msg << "[metal::Device] Unable to load kernel " << name << "\n";
@@ -282,11 +364,175 @@ MTL::ComputePipelineState* Device::get_kernel(
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
// Add kernel to cache
kernel_map_.insert({name, kernel});
return kernel; return kernel;
} }
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions) {
// Check inputs
if (!linked_functions) {
return get_kernel_(name, mtl_function);
}
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
throw std::runtime_error(msg.str());
}
// Prepare compute pipeline state descriptor
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
desc->setComputeFunction(mtl_function);
desc->setLinkedFunctions(linked_functions);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
auto kernel = device_->newComputePipelineState(
desc, MTL::PipelineOptionNone, nullptr, &error);
// Throw error if unable to compile metal function
if (!kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::Library* Device::get_library(const std::string& name) {
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
return nullptr;
}
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
std::vector<NS::Object*> objs(funcs.size());
for (int i = 0; i < funcs.size(); i++) {
objs[i] = funcs[i];
}
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
lfuncs->setPrivateFunctions(funcs_arr);
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
}
Device& device(mlx::core::Device) { Device& device(mlx::core::Device) {
static Device metal_device; static Device metal_device;
return metal_device; return metal_device;

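All of the get_kernel and get_library overloads above funnel into the same look-up-or-build-then-cache pattern keyed on the (possibly specialized) name. The pattern in miniature, with a plain map and a stand-in Pipeline type instead of kernel_map_ and MTL::ComputePipelineState:

// Minimal look-up-or-build cache, mirroring the kernel_map_ pattern above.
// Pipeline is an illustrative stand-in; expensive builds happen once per key.
#include <cstdio>
#include <string>
#include <unordered_map>

struct Pipeline {
  std::string name;
};

Pipeline* get_kernel(std::unordered_map<std::string, Pipeline*>& cache,
                     const std::string& name) {
  if (auto it = cache.find(name); it != cache.end()) {
    return it->second;  // cache hit: reuse the already-built pipeline
  }
  Pipeline* built = new Pipeline{name};  // stand-in for the expensive compile
  cache.insert({name, built});
  return built;
}

int main() {
  std::unordered_map<std::string, Pipeline*> cache;
  Pipeline* a = get_kernel(cache, "add_float32");
  Pipeline* b = get_kernel(cache, "add_float32");
  std::printf("same object: %d\n", a == b);  // 1
  for (auto& kv : cache) delete kv.second;
  return 0;
}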
View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-24 Apple Inc.
#pragma once #pragma once
@@ -31,6 +31,9 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
return mtllib_path; return mtllib_path;
} }
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
class Device { class Device {
public: public:
Device(); Device();
@@ -59,14 +62,73 @@ class Device {
const std::function<std::string(const std::string&)>& lib_path_func = const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path); get_colocated_mtllib_path);
MTL::ComputePipelineState* get_kernel( MTL::Library* get_library(const std::string& name);
MTL::Library* get_library(
const std::string& name, const std::string& name,
const std::string& lib_name = "mlx"); const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ArgumentEncoder* argument_encoder( MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private: private:
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
MTL::Function* get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib);
MTL::LinkedFunctions* get_linked_functions_(
const std::vector<MTL::Function*>& funcs);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::Device* device_; MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_; std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
@@ -39,114 +39,75 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname; std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx; kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
}
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1; size_t slice_size = 1;
for (auto s : slice_sizes_) { for (auto s : slice_sizes_) {
slice_size *= s; slice_size *= s;
} }
size_t ndim = src.ndim(); // Launch 2D grid of threads: indices x slice
size_t nthreads = out.size(); size_t dim0 = out.size() / slice_size;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t dim1 = slice_size;
if (thread_group_size > nthreads) { auto group_dims = get_block_dims(dim0, dim1, 1);
thread_group_size = nthreads; MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); // Collect all idx shapes and strides into one place
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy( idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(), inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(), inputs[i + 1].shape().end());
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy( idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(), inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(), inputs[i + 1].strides().end());
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
} }
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
// Set all the buffers // Set all the buffers
set_array_buffer(compute_encoder, src, 0); set_array_buffer(compute_encoder, src, 0);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1); set_array_buffer(compute_encoder, out, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has an
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
} }
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -211,82 +172,35 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the // Collect all idx shapes and strides into one place
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0; int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); std::vector<int> idx_shapes;
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
std::copy( idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(), inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(), inputs[i + 1].shape().end());
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy( idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(), inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(), inputs[i + 1].strides().end());
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
} }
// Allocate the argument buffer // Set all the buffers
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
// Register data with the encoder // Set update info
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
size_t upd_ndim = upd.ndim(); size_t upd_ndim = upd.ndim();
size_t upd_size = 1; size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) { for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i); upd_size *= upd.shape(i);
} }
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
if (upd_ndim == 0) { if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
@@ -301,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim(); size_t out_ndim = out.ndim();
if (out_ndim == 0) { if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
@@ -316,16 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
compute_encoder->dispatchThreads(grid_dims, group_dims); // Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Cleanup temporaries // Set index buffers
arg_enc->release(); for (int i = 1; i < nidx + 1; ++i) {
d.get_command_buffer(s.index)->addCompletedHandler( set_array_buffer(compute_encoder, inputs[i], 20 + i);
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { }
allocator::free(arg_buf);
allocator::free(idx_shapes_buf); // Launch grid
allocator::free(idx_strides_buf); MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
}); MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core
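Rather than encoding each index array's shape and strides into a Metal argument buffer, the updated gather and scatter paths concatenate them into two flat host-side vectors and hand those to setBytes. The concatenation step on its own, with hypothetical shapes and strides:

// Standalone version of flattening several index arrays' shapes/strides into
// two contiguous vectors (the data handed to setBytes above); illustrative only.
#include <cstdio>
#include <vector>

int main() {
  // Two hypothetical index arrays, both 2-D.
  std::vector<std::vector<int>> shapes = {{4, 1}, {1, 5}};
  std::vector<std::vector<size_t>> strides = {{1, 1}, {5, 1}};

  std::vector<int> idx_shapes;
  std::vector<size_t> idx_strides;
  for (size_t i = 0; i < shapes.size(); ++i) {
    idx_shapes.insert(idx_shapes.end(), shapes[i].begin(), shapes[i].end());
    idx_strides.insert(idx_strides.end(), strides[i].begin(), strides[i].end());
  }

  // One contiguous blob per attribute: its byte size is what setBytes receives.
  std::printf("%zu shape ints, %zu stride words\n",
              idx_shapes.size(), idx_strides.size());  // 4, 4
  return 0;
}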

View File

@@ -6,6 +6,7 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
) )
@@ -22,11 +23,13 @@ set(
"quantized" "quantized"
"random" "random"
"reduce" "reduce"
"rope"
"scan" "scan"
"softmax" "softmax"
"sort" "sort"
"unary" "unary"
"indexing" "gather"
"scatter"
) )
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)

View File

@@ -63,18 +63,6 @@ struct ArgMax {
} }
}; };
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U> template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) { IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>( return IndexValPair<U>(


@@ -0,0 +1,231 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
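
The integer specialization of Power above uses binary exponentiation (square-and-multiply) rather than repeated multiplication. A standalone C++ rendering of the same loop, for illustration only:

#include <cassert>
#include <cstdint>

// Square-and-multiply: fold in `base` for every set bit of `exp`, squaring
// `base` as we walk the bits from least to most significant.
uint64_t ipow(uint64_t base, uint64_t exp) {
  uint64_t res = 1;
  while (exp) {
    if (exp & 1) {
      res *= base; // this bit contributes base^(2^k)
    }
    exp >>= 1;
    base *= base; // prepare base^(2^(k+1))
  }
  return res;
}

int main() {
  assert(ipow(3, 0) == 1);
  assert(ipow(3, 5) == 243);
  assert(ipow(2, 10) == 1024);
  return 0;
}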


@@ -1,145 +1,6 @@
- // Copyright © 2023 Apple Inc.
+ // Copyright © 2023-2024 Apple Inc.
- #include <metal_integer>
- #include <metal_math>
- #include "mlx/backend/metal/kernels/utils.h"
- #include "mlx/backend/metal/kernels/bf16.h"
+ #include "mlx/backend/metal/kernels/binary.h"
struct Add {
template <typename T> T operator()(T x, T y) { return x + y; }
};
struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
struct NaNEqual {
template <typename T> bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real)
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T> bool operator()(T x, T y) { return x > y; }
};
struct GreaterEqual {
template <typename T> bool operator()(T x, T y) { return x >= y; }
};
struct Less {
template <typename T> bool operator()(T x, T y) { return x < y; }
};
struct LessEqual {
template <typename T> bool operator()(T x, T y) { return x <= y; }
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf) ? maxval :
(maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x >= y ? x : y;
}
};
struct Minimum {
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x <= y ? x : y;
}
};
struct Multiply {
template <typename T> T operator()(T x, T y) { return x * y; }
};
struct NotEqual {
template <typename T> bool operator()(T x, T y) { return x != y; }
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) { return x && y; };
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) { return x || y; };
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
@@ -389,4 +250,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)


@@ -14,10 +14,29 @@ struct FloorDivide {
};
struct Remainder {
- template <typename T> T operator()(T x, T y) { return x % y; }
- template <> float operator()(float x, float y) { return fmod(x, y); }
- template <> half operator()(half x, half y) { return fmod(x, y); }
- template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
+ template <typename T>
+ metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
+   return x % y;
+ }
+ template <typename T>
+ metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
+   auto r = x % y;
+   if (r != 0 && (r < 0 != y < 0)) {
+     r += y;
+   }
+   return r;
+ }
+ template <typename T>
+ metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
+   T r = fmod(x, y);
+   if (r != 0 && (r < 0 != y < 0)) {
+     r += y;
+   }
+   return r;
+ }
+ template <> complex64_t operator()(complex64_t x, complex64_t y) {
+   return x % y;
+ }
};
template <typename T, typename U, typename Op1, typename Op2>


@@ -0,0 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/unary.h"


@@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
+ if (real != 0 && (real < 0 != b.real < 0)) {
+   real += b.real;
+ }
+ if (imag != 0 && (imag < 0 != b.imag < 0)) {
+   imag += b.imag;
+ }
return {real, imag};
}
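
Both the scalar Remainder specializations and the complex operator% above shift the raw C-style remainder by the divisor when the signs disagree, which turns it into a floored (Python-style) remainder. A hedged C++ illustration of that adjustment, separate from the MLX sources:

#include <cassert>
#include <cmath>

// Floored remainder for floating point: fmod takes the dividend's sign, so
// shift the result into the divisor's sign when they disagree.
double floored_fmod(double x, double y) {
  double r = std::fmod(x, y);
  if (r != 0 && ((r < 0) != (y < 0))) {
    r += y;
  }
  return r;
}

// Same adjustment for signed integers on top of C's truncating %.
long floored_mod(long x, long y) {
  long r = x % y;
  if (r != 0 && ((r < 0) != (y < 0))) {
    r += y;
  }
  return r;
}

int main() {
  assert(floored_mod(-7, 3) == 2);   // C's % would give -1
  assert(floored_mod(7, -3) == -2);  // C's % would give 1
  assert(floored_fmod(-7.5, 3.0) == 1.5);
  return 0;
}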


@@ -0,0 +1,187 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
device T *out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
device T *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
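
The gather kernel above resolves every index and slice offset through elem_to_loc, which unravels a flat element index over a shape and dots the result with the strides. A hedged, host-side C++ sketch of that mapping (the Metal version lives in the kernels' utils header; this mirrors the usual row-major unraveling and is written here only for illustration):

#include <cassert>
#include <cstddef>

// Map a flat element index to a strided memory offset: peel off the innermost
// dimension first, multiplying each coordinate by its stride.
size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 row-major array has strides {3, 1}; flat element 4 is (1, 1) -> offset 4.
  int shape[2] = {2, 3};
  size_t strides[2] = {3, 1};
  assert(elem_to_loc(4, shape, strides, 2) == 4);
  // The same element viewed through transposed strides {1, 2} -> offset 3.
  size_t t_strides[2] = {1, 2};
  assert(elem_to_loc(4, shape, t_strides, 2) == 3);
  return 0;
}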


@@ -121,8 +121,18 @@ struct GEMVKernel {
for(int tm = 0; tm < TM; tm++) {
// Load for the row
- for(int tn = 0; tn < TN; tn++) {
-   inter[tn] = mat[tm * in_vec_size + bn + tn];
+ if(bn + TN <= in_vec_size) {
+   #pragma clang loop unroll(full)
+   for(int tn = 0; tn < TN; tn++) {
+     inter[tn] = mat[tm * in_vec_size + bn + tn];
+   }
+ } else { // Edgecase
+   #pragma clang loop unroll(full)
+   for(int tn = 0; tn < TN; tn++) {
+     int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
+     inter[tn] = mat[tm * in_vec_size + col_idx];
+   }
}
// Accumulate results
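
The edge-case branch above keeps every lane's read in bounds by clamping out-of-range column indices to the last valid column, rather than branching per element. A hedged, host-side C++ sketch of that clamped tail load; names are illustrative, and how the duplicated lanes are neutralized during accumulation is outside this sketch:

#include <algorithm>
#include <cassert>
#include <vector>

// Load a fixed-width tile of TN columns starting at `bn`, clamping reads that
// would run past `in_vec_size` so no out-of-bounds access is issued.
template <int TN>
void load_row_tile_clamped(const std::vector<float>& row, int bn, int in_vec_size, float (&inter)[TN]) {
  for (int tn = 0; tn < TN; ++tn) {
    int col_idx = std::min(bn + tn, in_vec_size - 1);
    inter[tn] = row[col_idx];
  }
}

int main() {
  std::vector<float> row = {1.f, 2.f, 3.f, 4.f, 5.f};
  float inter[4];
  load_row_tile_clamped<4>(row, 3, static_cast<int>(row.size()), inter);
  // Columns 3 and 4 are real; the last two lanes are clamped to the final element.
  assert(inter[0] == 4.f && inter[1] == 5.f && inter[2] == 5.f && inter[3] == 5.f);
  return 0;
}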


@@ -0,0 +1,54 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const int ndim;
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {
return (idx < 0) ? idx + size : idx;
}
}
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
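
offset_neg_idx above lets unsigned index types pass through untouched (they cannot be negative) and wraps negative signed indices by the axis size. A hedged C++ analogue of the same rule, written for illustration:

#include <cassert>
#include <cstddef>
#include <type_traits>

template <typename IdxT>
size_t offset_neg_idx(IdxT idx, size_t size) {
  if (std::is_unsigned<IdxT>::value) {
    return idx; // unsigned indices are already valid offsets
  } else {
    return (idx < 0) ? idx + size : idx; // -1 means the last element, and so on
  }
}

int main() {
  assert(offset_neg_idx<int>(-1, 5) == 4);
  assert(offset_neg_idx<int>(2, 5) == 2);
  assert(offset_neg_idx<unsigned>(3u, 5) == 3);
  return 0;
}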


@@ -1,254 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_texture>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<device IdxT*, NIDX> buffers [[id(0)]];
device int* shapes [[id(NIDX + 1)]];
device size_t* strides [[id(NIDX + 2)]];
const int ndim [[id(NIDX + 3)]];
};
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const device int *upd_shape [[buffer(3)]],
const device size_t *upd_strides [[buffer(4)]],
const device size_t& upd_ndim [[buffer(5)]],
const device size_t& upd_size [[buffer(6)]],
const device int *out_shape [[buffer(7)]],
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
const device type *updates [[buffer(1)]], \
device mlx_atomic<type> *out [[buffer(2)]], \
const device int *upd_shape [[buffer(3)]], \
const device size_t *upd_strides [[buffer(4)]], \
const device size_t& upd_ndim [[buffer(5)]], \
const device size_t& upd_size [[buffer(6)]], \
const device int *out_shape [[buffer(7)]], \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)


@@ -15,6 +15,14 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
+ template <typename T> struct AccT {
+   typedef T acc_t;
+ };
+ template <> struct AccT<bfloat16_t> {
+   typedef float acc_t;
+ };
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
@@ -31,21 +39,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
+ (void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / group_size;
- constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
- threadgroup T scales_block[BM * groups_per_block];
- threadgroup T biases_block[BM * groups_per_block];
- threadgroup T x_block[colgroup];
+ typedef typename AccT<T>::acc_t U;
+ threadgroup U scales_block[BM * groups_per_block];
+ threadgroup U biases_block[BM * groups_per_block];
+ threadgroup U x_block[colgroup];
thread uint32_t w_local;
- thread T result = 0;
- thread T scale = 1;
- thread T bias = 0;
- thread T x_thread[el_per_thread];
+ thread U result = 0;
+ thread U scale = 1;
+ thread U bias = 0;
+ thread U x_thread[el_per_thread];
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -57,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
+ if (out_row >= out_vec_size) {
+   return;
+ }
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
- if (simd_gid < simdgroups_fetching_vec) {
-   x_block[lid] = x[lid + i];
+ if (simd_gid == 0) {
+   #pragma clang loop unroll(full)
+   for (int j=0; j<el_per_thread; j++) {
+     x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
+   }
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
@@ -90,7 +107,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
- result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
+ result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= bits;
}
}
@@ -100,7 +117,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Store the result
if (simd_lid == 0) {
- y[out_row] = result;
+ y[out_row] = static_cast<T>(result);
}
}
@@ -129,23 +146,25 @@ template <typename T, const int BM, const int BN, const int group_size, const in
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
- threadgroup T scales_block[BM * groups_per_block];
- threadgroup T biases_block[BM * groups_per_block];
- threadgroup T x_block[BM];
+ typedef typename AccT<T>::acc_t U;
+ threadgroup U scales_block[BM * groups_per_block];
+ threadgroup U biases_block[BM * groups_per_block];
+ threadgroup U x_block[BM];
thread uint32_t w_local;
- thread T result[el_per_int] = {0};
- thread T scale = 1;
- thread T bias = 0;
- thread T x_local = 0;
+ thread U result[el_per_int] = {0};
+ thread U scale = 1;
+ thread U bias = 0;
+ thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
- int out_col = (tid.y * BN + simd_gid) * el_per_int;
+ int out_col_start = tid.y * (BN * el_per_int);
+ int out_col = out_col_start + simd_gid * el_per_int;
w += out_col / el_per_int;
- scales += out_col / group_size;
- biases += out_col / group_size;
+ scales += out_col_start / group_size;
+ biases += out_col_start / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
@@ -155,26 +174,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
- int offset = simd_lid + i;
- bool thread_in_bounds = offset < in_vec_size;
+ int offset_lid = simd_lid + i;
+ int offset_gid = simd_gid + i;
+ bool thread_in_bounds = offset_lid < in_vec_size;
+ bool group_in_bounds = offset_gid < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
- x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
+ x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
- if (simd_gid == 0) {
-   #pragma clang loop unroll(full)
-   for (int j=0; j<groups_per_block; j++) {
-     scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
-   }
-   #pragma clang loop unroll(full)
-   for (int j=0; j<groups_per_block; j++) {
-     biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
-   }
+ if (simd_lid < groups_per_block && group_in_bounds) {
+   scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
+   biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -184,12 +199,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
- w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
+ w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
- result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
+ result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
}
@@ -204,7 +219,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
- y[k] = result[k];
+ y[k] = static_cast<T>(result[k]);
}
}
}
@@ -243,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
@@ -306,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
- if (y_col + offset_col < N) {
+ if (y_row + offset_row < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -421,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
- if (num_els < BM) {
-   loader_x.load_safe(short2(BK, num_els));
+ short num_k = min(BK, K - k);
+ if (num_els < BM || num_k < BK) {
+   loader_x.load_safe(short2(num_k, num_els));
} else {
loader_x.load_unsafe();
}
@@ -450,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
- if (k + BK >= K) {
+ if (num_k < BK) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
@@ -543,6 +558,9 @@ instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
+ instantiate_qmv_types( 32, 2)
+ instantiate_qmv_types( 32, 4)
+ instantiate_qmv_types( 32, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -570,6 +588,9 @@ instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
+ instantiate_qvm_types( 32, 2)
+ instantiate_qvm_types( 32, 4)
+ instantiate_qvm_types( 32, 8)
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
@@ -601,6 +622,9 @@ instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
+ instantiate_qmm_t_types( 32, 2)
+ instantiate_qmm_t_types( 32, 4)
+ instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -629,3 +653,6 @@ instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
+ instantiate_qmm_n_types( 32, 2)
+ instantiate_qmm_n_types( 32, 4)
+ instantiate_qmm_n_types( 32, 8)
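
The quantized kernels above unpack 32/bits elements from each uint32 word with a bitmask and shift, apply the per-group scale and bias, and, after this change, accumulate in AccT<T>::acc_t (float when T is bfloat16) before casting back to T on store. A hedged, host-side C++ sketch of the per-word dequantization, for illustration only:

#include <cassert>
#include <cstdint>
#include <vector>

// Recover the elements packed into one quantized word:
// value_k = scale * ((w >> (k * bits)) & bitmask) + bias, accumulated in float.
std::vector<float> dequantize_word(uint32_t w, int bits, float scale, float bias) {
  const int el_per_word = 32 / bits;
  const uint32_t bitmask = (1u << bits) - 1;
  std::vector<float> out(el_per_word);
  for (int k = 0; k < el_per_word; ++k) {
    out[k] = scale * static_cast<float>(w & bitmask) + bias;
    w >>= bits;
  }
  return out;
}

int main() {
  // 4-bit example: the word 0x21 packs the elements {1, 2, 0, ...}.
  auto vals = dequantize_word(0x21u, 4, 0.5f, 1.0f);
  assert(vals.size() == 8);
  assert(vals[0] == 1.5f && vals[1] == 2.0f && vals[2] == 1.0f);
  return 0;
}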


@@ -24,11 +24,59 @@ template <typename T, typename Op>
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
@@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
- // NB: this kernel assumes threads_per_threadgroup is at most
- // 1024. This way with a simd_size of 32, we are guaranteed to
- // complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
- U total_val = Op::init;
- in += gid * N_READS;
+ U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
@@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
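
The *_no_atomics kernels above replace op.simd_reduce with an explicit shuffle-down butterfly, since the built-in simd reductions are not available for 64-bit types, and each threadgroup then writes its partial result to out[thread_group_id] for a follow-up reduction pass. A hedged, CPU-side emulation of that butterfly, written only to show the access pattern:

#include <cassert>
#include <cstdint>
#include <vector>

// At each step, lane i combines its value with lane i + offset; after log2(N)
// steps lane 0 holds the reduction of the whole simd group.
int64_t simd_reduce_emulated(std::vector<int64_t> lanes) {
  const size_t simd_size = lanes.size(); // assume a power of two, e.g. 32
  for (size_t offset = simd_size / 2; offset > 0; offset /= 2) {
    for (size_t lane = 0; lane + offset < simd_size; ++lane) {
      lanes[lane] += lanes[lane + offset]; // op(total_val, simd_shuffle_down(total_val, offset))
    }
  }
  return lanes[0];
}

int main() {
  std::vector<int64_t> lanes(32);
  int64_t expected = 0;
  for (size_t i = 0; i < lanes.size(); ++i) {
    lanes[i] = static_cast<int64_t>(i);
    expected += lanes[i];
  }
  assert(simd_reduce_emulated(lanes) == expected);
  return 0;
}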
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
@@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
@@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
- // Loop over the reduction size within thread group
+ U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize.x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
@@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
@@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
- inline void _contiguous_strided_reduce(
+ inline U _contiguous_strided_reduce(
const device T *in,
- device mlx_atomic<U> *out,
- threadgroup U *local_data,
- uint in_idx,
- uint out_idx,
- uint reduction_size,
- uint reduction_stride,
- uint2 tid,
- uint2 lid,
+ threadgroup U *local_data,
+ uint in_idx,
+ uint reduction_size,
+ uint reduction_stride,
+ uint2 tid,
+ uint2 lid,
uint2 lsize) {
Op op;
- T local_vals[N_READS];
+ U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]); uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
} }
local_data[lsize.y * lid.x + lid.y] = total_val; local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) { if(lid.y == 0) {
U val = op.init; // Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) { for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]); val = op(val, local_data[lsize.y * lid.x + i]);
} }
op.atomic_update(out, val, out_idx);
} }
return val;
} }
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
@@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(
@@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
      ndim
  );
+ Op op;
  if(out_idx < out_size) {
-   _contiguous_strided_reduce<T, U, Op, N_READS>(
+   U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
-       out,
        local_data,
        in_idx,
-       out_idx,
        reduction_size,
        reduction_stride,
        tid.xy,
        lid.xy,
        lsize.xy);
+   // Write out reduction results generated by threadgroups working on specific output element, contiguously.
+   if (lid.y == 0) {
+     op.atomic_update(out, val, out_idx);
+   }
  }
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
}
}
@@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
@@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
@@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
@@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
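The *_no_atomics instantiations above cover only int64 and uint64 because those reductions cannot be written through mlx_atomic outputs; each threadgroup instead writes its partial result to its own slot and a second dispatch folds the partials (the matching host-side changes appear in the reduce.cpp diff further down). A minimal C++ sketch of that two-pass idea, with an arbitrary block size standing in for the threadgroup size:

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Two-pass sum for a type with no usable atomic add: pass 1 produces one
    // partial per block (one per "threadgroup"), pass 2 reduces the partials.
    int64_t two_pass_sum(const std::vector<int64_t>& in, size_t block = 1024) {
      std::vector<int64_t> partials;
      for (size_t start = 0; start < in.size(); start += block) {
        size_t end = std::min(in.size(), start + block);
        partials.push_back(
            std::accumulate(in.begin() + start, in.begin() + end, int64_t{0}));
      }
      // Second pass over the contiguous partials, mirroring the second dispatch.
      return std::accumulate(partials.begin(), partials.end(), int64_t{0});
    }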

View File

@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
constant const size_t strides[3],
constant const int& offset,
constant const float& base,
constant const float& scale,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1 = x1 * costheta - x2 * sintheta;
float rx2 = x1 * sintheta + x2 * costheta;
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
}
#define instantiate_rope(name, type, traditional) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
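For reference, the per-pair arithmetic of the RoPE kernel above can be written as a scalar C++ function that mirrors the math shown (theta = scale * (position + offset) * exp2(-d * base), followed by a 2-D rotation); which two elements form a pair (traditional vs. split-half layout) is left to the caller. Illustrative only, not MLX API:

    #include <cmath>
    #include <utility>

    // CPU sketch of the RoPE rotation applied to one (x1, x2) pair above.
    // d is the normalized feature index, i.e. pos.x / grid.x in the kernel.
    std::pair<float, float> rope_pair(
        float x1, float x2, int position, int offset,
        float base, float scale, float d) {
      float L = scale * static_cast<float>(position + offset);
      float theta = L * std::exp2(-d * base);
      float c = std::cos(theta);
      float s = std::sin(theta);
      return {x1 * c - x2 * s, x1 * s + x2 * c};
    }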

View File

@@ -0,0 +1,194 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int *upd_shape [[buffer(3)]],
const constant size_t *upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int *out_shape [[buffer(7)]],
const constant size_t *out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
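The heart of scatter_impl above is the flat output index it builds: each index array contributes idx_val * out_strides[axes[i]], negative indices are wrapped by offset_neg_idx, and the remaining update coordinates are mapped through the output strides before the op is applied atomically. A simplified, single-axis, row-major C++ sketch of that indexing (scatter-add along axis 0), for illustration only:

    #include <cstddef>
    #include <vector>

    // Scatter-add along axis 0 of a row-major (rows x cols) output, mirroring
    // the index arithmetic above with one index array and axes = {0}:
    // out_idx = row_index * out_strides[0], out_offset = j * out_strides[1].
    // `updates` is assumed to be idx.size() x cols, row-major.
    void scatter_add_rows(
        std::vector<float>& out, int rows, int cols,
        const std::vector<int>& idx,
        const std::vector<float>& updates) {
      for (size_t i = 0; i < idx.size(); ++i) {
        int r = idx[i] < 0 ? idx[i] + rows : idx[i];  // offset_neg_idx behaviour
        for (int j = 0; j < cols; ++j) {
          out[static_cast<size_t>(r) * cols + j] += updates[i * cols + j];
        }
      }
    }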

View File

@@ -89,20 +89,9 @@ struct GEMMKernel {
// Appease the compiler
(void)l;
- thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
- thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
- if (!M_aligned) {
-   short2 tile_dims_A =
-       transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
-   loader_a.set_mask(tile_dims_A, mask_A);
- }
- if (!N_aligned) {
-   short2 tile_dims_B =
-       transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
-   loader_b.set_mask(tile_dims_B, mask_B);
- }
+ short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+ short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
  threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -110,13 +99,13 @@ struct GEMMKernel {
if (M_aligned) {
  loader_a.load_unsafe();
} else {
- loader_a.load_safe(mask_A);
+ loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
  loader_b.load_unsafe();
} else {
- loader_b.load_safe(mask_B);
+ loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -137,11 +126,8 @@ struct GEMMKernel {
short2 tile_dims_B_last =
    transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
- loader_a.set_mask(tile_dims_A_last, mask_A);
- loader_b.set_mask(tile_dims_B_last, mask_B);
- loader_a.load_safe(mask_A);
- loader_b.load_safe(mask_B);
+ loader_a.load_safe(tile_dims_A_last);
+ loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -218,14 +204,8 @@ struct GEMMKernel {
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
- thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
- thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
- loader_a.set_mask(tile_dims_A, mask_A);
- loader_b.set_mask(tile_dims_B, mask_B);
- loader_a.load_safe(mask_A);
- loader_b.load_safe(mask_B);
+ loader_a.load_safe(tile_dims_A);
+ loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@@ -112,14 +112,8 @@ template <typename T,
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
- thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
- thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
- loader_a.set_mask(tile_dims_A, mask_A);
- loader_b.set_mask(tile_dims_B, mask_B);
- loader_a.load_safe(mask_A);
- loader_b.load_safe(mask_B);
+ loader_a.load_safe(tile_dims_A);
+ loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@@ -67,24 +67,22 @@ struct BlockLoader {
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void set_mask(
thread const short2& src_tile_dims,
thread bool mask[n_rows][vec_size]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
mask[i][j] =
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
  src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
@@ -117,39 +115,6 @@ struct BlockLoader {
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
simdgroup_barrier(mem_flags::mem_none);
// Use fast thread memory for bound checks
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
}
simdgroup_barrier(mem_flags::mem_none);
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
}
simdgroup_barrier(mem_flags::mem_none);
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
  src += tile_stride;
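With set_mask removed, load_safe now takes the valid tile extent directly, zero-fills out-of-bounds elements, and returns early when a thread's whole block is out of range. A small host-side C++ sketch of that bounds-checked tile copy, with names chosen for illustration rather than taken from the loader:

    // Copy a (rows x cols) tile from src into dst, zero-filling anything that
    // falls outside valid_rows x valid_cols -- the same effect as load_safe
    // with explicit tile dims, without a precomputed mask array.
    void load_tile_safe(
        const float* src, int src_ld,      // source pointer and leading dim
        float* dst, int dst_ld,            // destination tile and leading dim
        int rows, int cols,                // full tile size
        int valid_rows, int valid_cols) {  // in-bounds extent
      for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
          bool in_bounds = (i < valid_rows) && (j < valid_cols);
          dst[i * dst_ld + j] = in_bounds ? src[i * src_ld + j] : 0.0f;
        }
      }
    }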

View File

@@ -0,0 +1,376 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"
struct Abs {
template <typename T>
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return metal::precise::acos(x);
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return metal::precise::acosh(x);
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return metal::precise::asin(x);
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return metal::precise::asinh(x);
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return metal::precise::atan(x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return metal::precise::atanh(x);
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Cos {
template <typename T>
T operator()(T x) {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(erf(static_cast<float>(x)));
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(erfinv(static_cast<float>(x)));
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Log {
template <typename T>
T operator()(T x) {
return metal::precise::log(x);
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return metal::precise::log2(x);
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return metal::precise::log10(x);
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
};
struct Round {
template <typename T>
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
};
struct Sin {
template <typename T>
T operator()(T x) {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return metal::precise::sqrt(x);
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return metal::precise::rsqrt(x);
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};

View File

@@ -1,223 +1,6 @@
- // Copyright © 2023 Apple Inc.
+ // Copyright © 2023-2024 Apple Inc.
- #include <metal_integer>
+ #include "mlx/backend/metal/kernels/unary.h"
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Abs {
template <typename T> T operator()(T x) { return metal::abs(x); };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
template <> complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};
struct ArcCosh {
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};
struct ArcSin {
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};
struct ArcSinh {
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};
struct ArcTan {
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};
struct ArcTanh {
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};
struct Ceil {
template <typename T> T operator()(T x) { return metal::ceil(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Cos {
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Cosh {
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Erf {
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};
struct ErfInv {
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};
struct Exp {
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
template <> complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Floor {
template <typename T> T operator()(T x) { return metal::floor(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Log {
template <typename T> T operator()(T x) { return metal::precise::log(x); };
};
struct Log2 {
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};
struct Log10 {
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};
struct Log1p {
template <typename T> T operator()(T x) { return log1p(x); };
};
struct LogicalNot {
template <typename T> T operator()(T x) { return !x; };
};
struct Negative {
template <typename T> T operator()(T x) { return -x; };
};
struct Round {
template <typename T> T operator()(T x) { return metal::rint(x); };
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
template <> uint32_t operator()(uint32_t x) { return x != 0; };
};
struct Sin {
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Sinh {
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Square {
template <typename T> T operator()(T x) { return x * x; };
};
struct Sqrt {
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};
struct Rsqrt {
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};
struct Tan {
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {
(tan_a - tanh_b * t1) / denom,
(tanh_b + tan_a * t1) / denom
};
};
};
struct Tanh {
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {
(tanh_a + tan_b * t1) / denom,
(tan_b - tanh_a * t1) / denom
};
};
};
template <typename T, typename Op>
[[kernel]] void unary_op_v(

View File

@@ -12,10 +12,10 @@
template <typename U>
struct Limits {
- static const constant U max;
- static const constant U min;
- static const constant U finite_max;
- static const constant U finite_min;
+ static const constant U max = metal::numeric_limits<U>::max();
+ static const constant U min = metal::numeric_limits<U>::min();
+ static const constant U finite_max = metal::numeric_limits<U>::max();
+ static const constant U finite_min = metal::numeric_limits<U>::min();
};
#define instantiate_default_limit(type) \
@@ -71,7 +71,7 @@ inline size_t elem_to_loc(
    device const size_t* strides,
    int ndim) {
  size_t loc = 0;
- for (int i = ndim - 1; i >= 0; --i) {
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
@@ -84,7 +84,7 @@ inline size_t elem_to_loc(
    constant const size_t* strides,
    int ndim) {
  size_t loc = 0;
- for (int i = ndim - 1; i >= 0; --i) {
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
@@ -235,12 +235,42 @@ inline size_t ceildiv(size_t N, size_t M) {
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
  float xp1 = 1.0f + x;
- return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
+ if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast<float>(x);
- bfloat16_t ret =
-     (xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
- return ret;
+ if (xp1 == Limits<float>::max) {
+   return Limits<bfloat16_t>::max;
+ }
if (xp1 == 1.0f) {
return x;
}
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
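metal::simd_shuffle_down has no 64-bit overload, so the new helpers above reinterpret the value as a uint2, shuffle both 32-bit words by the same delta, and reinterpret back. A host-side sketch of that bit-level round trip using memcpy (as_type is Metal-only), purely to illustrate the reinterpretation:

    #include <cstdint>
    #include <cstring>
    #include <utility>

    // Split a 64-bit value into its two 32-bit words (in memory order) and
    // reassemble it, mirroring the as_type<uint2> round trip used above; in
    // the kernel both words are shuffled with the same delta before rejoining.
    std::pair<uint32_t, uint32_t> split_u64(uint64_t v) {
      uint32_t halves[2];
      std::memcpy(halves, &v, sizeof(v));
      return {halves[0], halves[1]};
    }

    uint64_t join_u64(uint32_t w0, uint32_t w1) {
      uint32_t halves[2] = {w0, w1};
      uint64_t v;
      std::memcpy(&v, halves, sizeof(v));
      return v;
    }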

View File

@@ -0,0 +1,28 @@
#!/bin/bash
#
# This script generates a C++ function that provides the Metal unary and binary
# ops at runtime for use with kernel generation.
#
# Copyright © 2023-24 Apple Inc.
OUTPUT_FILE=$1
CC=$2
SRCDIR=$3
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
// Copyright © 2023-24 Apple Inc.
namespace mlx::core::metal {
const char* get_kernel_preamble() {
return R"preamble(
$CONTENT
)preamble";
}
} // namespace mlx::core::metal
EOF
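The generated file exposes get_kernel_preamble(), which returns the preprocessed header text as one raw string. A hypothetical consumer would simply prepend that preamble to a generated kernel body before compiling it at runtime; build_kernel_source below is an illustration of that idea, not an MLX function:

    #include <string>

    namespace mlx::core::metal {
    const char* get_kernel_preamble();  // provided by the generated file
    }

    // Hypothetical helper: prepend the preamble so a runtime-generated kernel
    // can use the unary/binary op structs and utilities it declares.
    std::string build_kernel_source(const std::string& kernel_body) {
      std::string source = mlx::core::metal::get_kernel_preamble();
      source += "\n";
      source += kernel_body;
      return source;
    }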

View File

@@ -1,4 +1,4 @@
- // Copyright © 2023 Apple Inc.
+ // Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
- assert(inputs.size() == 2);
+ assert(inputs.size() == 3);
  if (!is_floating_point(out.dtype())) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");

View File

@@ -63,15 +63,32 @@ std::function<void()> make_task(
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
- arr.primitive().eval_gpu(arr.inputs(), outputs);
+ {
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.push_back(s.data_shared_ptr());
}
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
  metal::device(s.device).end_encoding(s.index);
  scheduler::notify_new_task(s);
  command_buffer->addCompletedHandler(
-     [s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
-       if (!arr.is_tracer()) {
-         arr.detach();
-       }
+     [s, buffers = std::move(buffers), p = std::move(p)](
+         MTL::CommandBuffer* cbuf) {
        p->set_value();
        scheduler::notify_task_completion(s);
        check_error(cbuf);
@@ -79,10 +96,7 @@ std::function<void()> make_task(
  metal::device(s.device).commit_command_buffer(s.index);
} else {
  command_buffer->addCompletedHandler(
-     [s, arr](MTL::CommandBuffer* cbuf) mutable {
-       if (!arr.is_tracer()) {
-         arr.detach();
-       }
+     [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
        check_error(cbuf);
      });
}
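The change above captures shared_ptrs to the input and sibling buffers in the command buffer's completion handler, so detaching (and therefore donating) an array can no longer release memory the GPU may still be reading. A minimal C++ sketch of that ownership pattern with a stand-in buffer type and callback; none of these names are MLX API:

    #include <functional>
    #include <memory>
    #include <vector>

    struct Buffer { /* stand-in for array::Data */ };

    // Hold references to buffers inside the completion callback so they are
    // released only after the (simulated) GPU work has finished with them.
    std::function<void()> make_completion_handler(
        std::vector<std::shared_ptr<Buffer>> buffers) {
      return [buffers = std::move(buffers)]() mutable {
        buffers.clear();  // last references dropped here, after completion
      };
    }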

View File

@@ -1,4 +1,4 @@
- // Copyright © 2023 Apple Inc.
+ // Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -27,8 +27,8 @@ void binary_op(
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
- set_binary_op_output_data(a, b, outputs[0], bopt);
- set_binary_op_output_data(a, b, outputs[1], bopt);
+ set_binary_op_output_data(a, b, outputs[0], bopt, true);
+ set_binary_op_output_data(a, b, outputs[1], bopt, true);
  auto& out = outputs[0];
  if (out.size() == 0) {
@@ -60,7 +60,7 @@ void binary_op(
      break;
  }
  kname << op << type_to_name(a);
- if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
+ if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }
@@ -69,8 +69,14 @@ void binary_op(
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
- set_array_buffer(compute_encoder, a, 0);
- set_array_buffer(compute_encoder, b, 1);
+ // - If a is donated it goes to the first output
+ // - If b is donated it goes to the first output if a was not donated
+ //   otherwise it goes to the second output
+ bool donate_a = a.data_shared_ptr() == nullptr;
+ bool donate_b = b.data_shared_ptr() == nullptr;
+ set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
+ set_array_buffer(
+     compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
  set_array_buffer(compute_encoder, outputs[0], 2);
  set_array_buffer(compute_encoder, outputs[1], 3);
@@ -122,7 +128,7 @@ void binary_op(
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
- set_binary_op_output_data(a, b, out, bopt);
+ set_binary_op_output_data(a, b, out, bopt, true);
  if (out.size() == 0) {
    return;
  }
@@ -152,7 +158,7 @@ void binary_op(
      break;
  }
  kname << op << type_to_name(a);
- if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
+ if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }
@@ -161,8 +167,10 @@ void binary_op(
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
- set_array_buffer(compute_encoder, a, 0);
- set_array_buffer(compute_encoder, b, 1);
+ bool donate_a = a.data_shared_ptr() == nullptr;
+ bool donate_b = b.data_shared_ptr() == nullptr;
+ set_array_buffer(compute_encoder, donate_a ? out : a, 0);
+ set_array_buffer(compute_encoder, donate_b ? out : b, 1);
  set_array_buffer(compute_encoder, out, 2);
  if (bopt == General) {
@@ -212,11 +220,15 @@ void unary_op(
  auto& in = inputs[0];
  bool contig = in.flags().contiguous;
  if (contig) {
-   out.set_data(
-       allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-       in.data_size(),
-       in.strides(),
-       in.flags());
+   if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+     out.move_shared_buffer(in);
+   } else {
+     out.set_data(
+         allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+         in.data_size(),
+         in.strides(),
+         in.flags());
+   }
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
@@ -240,7 +252,8 @@ void unary_op(
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
- set_array_buffer(compute_encoder, in, 0);
+ set_array_buffer(
+     compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  if (!contig) {
    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
@@ -473,6 +486,18 @@ void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cosh");
}
void CustomVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}
@@ -769,4 +794,10 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
}
} // namespace mlx::core
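The donation changes in this file follow one rule: an input whose storage has already been moved out (data_shared_ptr() == nullptr) is bound to an output slot instead, and unary_op additionally reuses the input buffer in place when it is donatable and the element sizes match. A small C++ sketch of those two checks with placeholder types, for illustration only:

    #include <cstddef>

    // Placeholder for an array whose storage may have been donated
    // (data == nullptr means the buffer was already moved to an output).
    struct Arr {
      const void* data;
      size_t itemsize;
      bool donatable;
    };

    // Decide which buffer backs the kernel's first input slot: the output that
    // received the donated storage, or the input itself.
    const Arr& pick_input_binding(const Arr& input, const Arr& output) {
      bool donated = (input.data == nullptr);
      return donated ? output : input;
    }

    // unary_op-style check: reuse the input buffer in place only when it is
    // donatable and element sizes match; otherwise allocate fresh output storage.
    bool can_reuse_in_place(const Arr& in, const Arr& out) {
      return in.donatable && in.itemsize == out.itemsize;
    }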

View File

@@ -1,4 +1,4 @@
- // Copyright © 2023 Apple Inc.
+ // Copyright © 2023-2024 Apple Inc.
#include <cassert>
@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  int bo = std::min(32, O);
  int bd = 32;
  MTL::Size group_dims = MTL::Size(bd, bo, 1);
- MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+ MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
  set_array_buffer(compute_encoder, w, 0);
  set_array_buffer(compute_encoder, scales, 1);
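The grid fix above is a ceiling division: with O = 40 and bo = 32, the old O / bo launches a single threadgroup and leaves the tail outputs uncovered, while (O + bo - 1) / bo launches two. The helper below just spells out that arithmetic:

    #include <cstddef>

    // Ceiling division: number of blocks of size `block` needed to cover `n`.
    // With n = 40, block = 32 this returns 2, where n / block would return 1.
    constexpr size_t ceil_div(size_t n, size_t block) {
      return (n + block - 1) / block;
    }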

View File

@@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) {
  return safe_div(n, m) * m;
}
inline bool is_64b_int(Dtype dtype) {
return dtype == int64 || dtype == uint64;
}
// All Reduce
void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    MTL::ComputeCommandEncoder* compute_encoder,
-   metal::Device& d) {
- // Get kernel and encode buffers
- size_t in_size = in.size();
- auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
+   metal::Device& d,
+   const Stream& s) {
+ Dtype out_dtype = out.dtype();
+ bool is_out_64b_int = is_64b_int(out_dtype);
+ auto kernel = (is_out_64b_int)
+     ? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
+     : d.get_kernel("all_reduce_" + op_name + type_to_name(in));
  compute_encoder->setComputePipelineState(kernel);
- set_array_buffer(compute_encoder, in, 0);
- set_array_buffer(compute_encoder, out, 1);
- compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
- // Set grid dimensions
  // We make sure each thread has enough to do by making it read in
  // at least n_reads inputs
  int n_reads = REDUCE_N_READS;
size_t in_size = in.size();
  // mod_in_size gives us the groups of n_reads needed to go over the entire
  // input
  uint mod_in_size = (in_size + n_reads - 1) / n_reads;
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  thread_group_size =
      mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
  // If the number of thread groups needed exceeds 1024, we reuse threads groups
  uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
@@ -66,7 +71,52 @@ void all_reduce_dispatch(
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
- compute_encoder->dispatchThreads(grid_dims, group_dims);
+ // Encode buffers and dispatch
if (is_out_64b_int == false || n_thread_groups == 1) {
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
size_t intermediate_size = n_thread_groups;
array intermediate =
array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// First dispatch
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse threads
// groups
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
void row_reduce_general_dispatch(
@@ -76,22 +126,31 @@ void row_reduce_general_dispatch(
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    MTL::ComputeCommandEncoder* compute_encoder,
-   metal::Device& d) {
- auto kernel =
-     d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
+   metal::Device& d,
+   const Stream& s) {
+ Dtype out_dtype = out.dtype();
+ bool is_out_64b_int = is_64b_int(out_dtype);
+ auto kernel = (is_out_64b_int)
+     ? d.get_kernel(
+           "row_reduce_general_no_atomics_" + op_name + type_to_name(in))
+     : d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
+ compute_encoder->setComputePipelineState(kernel);
  // Prepare the arguments for the kernel
  int n_reads = REDUCE_N_READS;
  size_t reduction_size = plan.shape.back();
- size_t out_size = out.size();
  auto shape = plan.shape;
  auto strides = plan.strides;
  shape.pop_back();
  strides.pop_back();
  size_t non_row_reductions = 1;
  for (auto s : shape) {
    non_row_reductions *= static_cast<size_t>(s);
  }
+ size_t out_size = out.size();
  auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
  for (auto s : rem_shape) {
    shape.push_back(s);
@@ -101,16 +160,6 @@ void row_reduce_general_dispatch(
  }
  int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
  // Each thread group is responsible for 1 output
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  thread_group_size =
@@ -127,7 +176,88 @@ void row_reduce_general_dispatch(
  MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
- compute_encoder->dispatchThreads(grid_dims, group_dims);
+ if (is_out_64b_int == false || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
array intermediate = array(
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they will be different after the partial
// reduction done by the first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
// Set the arguments for the kernel
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
void strided_reduce_general_dispatch(
@@ -137,9 +267,16 @@ void strided_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@@ -162,19 +299,7 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
@@ -183,14 +308,22 @@ void strided_reduce_general_dispatch(
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumulates for that output
// Threads with same lid.x, i.e. each column of threads work on same output
uint threadgroup_dim_x = std::min(out_size, 128ul);
// Number of threads along y, is dependent on number of reductions needed.
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
// Derive number of thread groups along x, based on how many threads we need
// along x
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
// Derive number of thread groups along y based on how many threads we need
// along y
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
@@ -199,18 +332,122 @@ void strided_reduce_general_dispatch(
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
if (is_out_64b_int == false) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
// groups
array intermediate = array(
{static_cast<int>(out.size()),
static_cast<int>(n_threadgroups_y * non_col_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
// collectively work on each output element.
reduction_size = n_threadgroups_y * non_col_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they will be different after the partial
// reduction done by the first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
auto row_reduce_kernel = d.get_kernel(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
compute_encoder->setComputePipelineState(row_reduce_kernel);
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
size_t n_reads = REDUCE_N_READS;
size_t thread_group_size =
row_reduce_kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = row_reduce_kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
uint n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
}
} // namespace
@@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
@@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
// At least the last two dimensions are contiguous and we are doing a
@@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
if (!copies.empty()) {
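Note on the two-pass path added above: 64-bit integer outputs cannot use the atomic reduction kernels, so the _no_atomics variants write one partial value per (output, non-row-reduction) pair into an intermediate array and a second dispatch reduces each row of partials. The stride fix-up for that second pass is ordinary row-major stride computation over the non-reduced axes; a standalone sketch of just that step (the function name and example values are illustrative, not from this diff):

#include <cstddef>
#include <vector>

// Row-major strides for the non-reduced axes when the innermost axis now
// advances over `reduction_size` partial values per output element.
std::vector<size_t> partial_reduction_strides(
    const std::vector<int>& rem_shape,
    size_t reduction_size) {
  if (rem_shape.empty()) {
    return {};
  }
  std::vector<size_t> strides(rem_shape.size(), 0);
  strides.back() = reduction_size;
  for (int i = static_cast<int>(rem_shape.size()) - 2; i >= 0; i--) {
    strides[i] = rem_shape[i + 1] * strides[i + 1];
  }
  return strides;
}

// Example: rem_shape = {2, 3} with reduction_size = 4 gives strides {12, 4}.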


@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() != 3) {
throw std::runtime_error(
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
}
if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
bool donated = in.data_shared_ptr() == nullptr;
float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3);
compute_encoder->setBytes(&base, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5);
int dim0 = in.shape(2) / 2;
int dim1 = in.shape(1);
int dim2 = in.shape(0);
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast
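For reference, the rotation this kernel applies matches the fallback graph added in mlx/fast.cpp later in this diff: each position p (offset included, then scaled) and frequency index j give an angle theta = p * scale * base^(-j / half_dims), and the pair (x1, x2) is rotated by that angle. A scalar sketch of that math, as an illustration only (the helper name is made up; the Metal kernel works on whole rows and is passed log2(base) rather than base):

#include <cmath>
#include <utility>

// Rotate one (x1, x2) pair, mirroring the fallback graph in mlx/fast.cpp:
// theta = (position * scale) * base^(-j / half_dims).
std::pair<float, float> rope_rotate_pair(
    float x1, float x2, float position, int j, int half_dims,
    float base, float scale) {
  float freq = std::exp(-j * std::log(base) / half_dims);
  float theta = position * scale * freq;
  float c = std::cos(theta);
  float s = std::sin(theta);
  return {x1 * c - x2 * s, x1 * s + x2 * c};
}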


@@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
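The relaxed check above lets Softmax skip the copy for any input whose last axis is contiguous and whose second-to-last stride is either 0 (a broadcast row) or exactly the row length, rather than requiring a fully contiguous array. A minimal sketch of the same test written standalone, with illustrative shapes (element strides and a non-empty shape assumed):

#include <cstddef>
#include <vector>

// Same condition as the check_input lambda above.
bool softmax_no_copy(
    const std::vector<int>& shape, const std::vector<size_t>& strides) {
  size_t n = shape.size();
  bool no_copy = strides[n - 1] == 1;
  if (n > 1) {
    size_t s = strides[n - 2];
    no_copy &= (s == 0 || s == static_cast<size_t>(shape.back()));
  }
  return no_copy;
}

// softmax_no_copy({4, 8}, {8, 1})  -> true  (rows packed back to back)
// softmax_no_copy({4, 8}, {0, 1})  -> true  (first axis broadcast)
// softmax_no_copy({4, 8}, {16, 1}) -> false (padding between rows, copy first)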


@@ -9,20 +9,6 @@ namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* compute_encoder,
MTL::ArgumentEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
// MTL::Resource usage through argument buffer needs to be explicitly
// flagged to enable hazard tracking
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
@@ -117,16 +103,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
@@ -142,21 +130,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
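The refactor above keeps the array overload of collapse_contiguous_dims but routes it through a new shape/strides overload, so callers that only have raw shapes and strides can use it too. A small self-contained sketch of the merge rule it applies, with illustrative values (not taken from this diff):

#include <cstddef>
#include <vector>

// Axis i can merge into axis i - 1 exactly when
// strides[i] * shape[i] == strides[i - 1], which is the test used above.
// For shape {2, 3, 4}:
//   strides {12, 4, 1} -> fully contiguous -> collapses to shape {24}, strides {1}
//   strides {24, 4, 1} -> only the last two axes merge -> shape {2, 12}, strides {24, 1}
bool can_merge(
    const std::vector<int>& shape, const std::vector<size_t>& st, int i) {
  return st[i] * static_cast<size_t>(shape[i]) == st[i - 1];
}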


@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/fast.h"
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
@@ -32,12 +33,16 @@ NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
@@ -67,6 +72,7 @@ NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
@@ -89,6 +95,9 @@ NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
namespace fast {
NO_GPU_MULTI(RoPE)
} // namespace fast
} // namespace mlx::core


@@ -1,36 +1,168 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/compile.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
constexpr int max_compile_depth = 10;
bool is_unary(const Primitive& p) {
return (
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
typeid(p) == typeid(Tanh));
}
bool is_binary(const Primitive& p) {
return (
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
typeid(p) == typeid(Subtract));
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
}
namespace detail {
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs);
} // namespace detail
Compiled::Compiled(
Stream stream,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape,
std::unordered_set<uintptr_t> constant_ids)
: Primitive(stream),
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
std::vector<array> Compiled::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
}
std::vector<array> Compiled::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
}
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
}
bool Compiled::is_equivalent(const Primitive& other) const {
const Compiled& a_other = static_cast<const Compiled&>(other);
return std::equal(
tape_.begin(),
tape_.end(),
a_other.tape_.begin(),
a_other.tape_.end(),
[](const array& a1, const array& a2) {
auto& p1 = a1.primitive();
auto& p2 = a2.primitive();
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
});
}
void Compiled::print(std::ostream& os) {
os << "Compiled";
for (auto& a : tape_) {
a.primitive().print(os);
}
}
namespace detail {
CompileMode& compile_mode() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return CompileMode::disabled;
} else {
return CompileMode::enabled;
}
};
static CompileMode compile_mode_ = get_val();
return compile_mode_;
}
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination
void merge(array& dst, array& src, ParentsMap& parents_map) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dst
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
@@ -54,14 +186,12 @@ struct CompilerCache {
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
return false;
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
@@ -205,28 +335,6 @@ void compile_simplify(
}
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
@@ -254,33 +362,32 @@ void compile_simplify(
return pa.is_equivalent(pb);
};
// Merge scalars
std::vector<array> new_tape;
for (auto& arr : tape) {
// Check if we can merge scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
merge(scalar->second, arr, parents_map);
// Don't keep orphaned scalars in the tape
continue;
}
}
new_tape.push_back(std::move(arr));
}
tape = std::move(new_tape);
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
// Helper to check if we can merge the parents of the
// given array
auto maybe_merge_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
@@ -296,7 +403,7 @@ void compile_simplify(
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
merge(dst, src, parents_map);
mask[j] = true;
}
}
@@ -313,9 +420,9 @@ void compile_simplify(
}
};
bool discard = maybe_merge_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_merge_parents(s);
}
// If an array and its siblings have no parents, and none of them are
// outputs, it is safe to remove it from the tape
@@ -327,6 +434,216 @@ void compile_simplify(
}
}
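What the merge passes above do in practice: duplicate scalar constants and depth-1-equivalent sub-expressions collapse to a single node before fusion runs, so the tape that reaches compile_fuse is already deduplicated. A hedged usage sketch (the function body is an arbitrary example, not taken from this diff):

#include "mlx/mlx.h"

using namespace mlx::core;

// Both branches build the same sub-graph (ins[0] * 2), so after the merge
// passes the compiled tape should contain a single multiply feeding one add.
std::vector<array> duplicated(const std::vector<array>& ins) {
  auto a = multiply(ins[0], array(2.0f));
  auto b = multiply(ins[0], array(2.0f));
  return {add(a, b)};
}

int main() {
  auto fn = compile(duplicated);
  auto out = fn({array({1.0f, 2.0f, 3.0f})})[0];
  eval({out}); // {4, 8, 12}
}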
// Extract sub-graphs of the graph that can be compiled
// and replace them with a Compiled Primitive.
void compile_fuse(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Track outputs to replace with new compiled outputs
std::unordered_map<uintptr_t, array> output_map;
for (auto& o : outputs) {
output_map.insert({o.id(), o});
}
// Set of inputs to distinguish constants
std::unordered_set<uintptr_t> input_ids;
for (auto& in : inputs) {
input_ids.insert(in.id());
}
// Go through the tape in reverse order and check for fusable sub-graphs
std::vector<array> new_tape;
std::unordered_set<uintptr_t> global_cache;
for (int i = tape.size() - 1; i >= 0; --i) {
auto& arr = tape[i];
// Already compiled
if (global_cache.find(arr.id()) != global_cache.end()) {
continue;
}
// Two pass recursion:
// First pass:
// - Collect all the primitives which we can fuse with
// - Keeps a cache of fusable primitives which may be added out of
// DAG order. We have to determine if all of a fused primitive's
// outputs are also in the fused section, and this may not be the
// case the first time we visit it.
// Second pass:
// - Collect inputs to the new compiled primitive
// - Add fusable primitives to a tape in the correct order
std::function<void(const array&, int, const Stream&)> recurse;
std::unordered_set<uintptr_t> cache;
recurse = [&](const array& a, int depth, const Stream& s) {
if (cache.find(a.id()) != cache.end()) {
return;
}
// Stop fusing if:
// - Depth limit exceeded
// - Constant input
// - Stream mismatch
// - Non fusable primitive
if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
return;
}
bool all_parents_in = true;
if (depth > 0) {
// Guaranteed to have a parent since nested in the
// recursion.
auto& parents = parents_map.at(a.id());
for (auto& [p, idx] : parents) {
auto in_cache = cache.find(p.id()) != cache.end();
if (!in_cache) {
all_parents_in = false;
break;
}
}
}
// Arrays with a mix of parents outside the compilable section
// are not fusable
if (!all_parents_in) {
return;
}
cache.insert({a.id()});
for (auto& in : a.inputs()) {
recurse(in, depth + 1, s);
}
};
if (arr.has_primitive()) {
Stream s = arr.primitive().stream();
recurse(arr, 0, s);
}
// Not worth fusing a single primitive
if (cache.size() <= 1) {
new_tape.push_back(arr);
continue;
}
// Recurse a second time to build the tape in the right
// order and collect the inputs
std::unordered_set<uintptr_t> input_set;
std::vector<array> inputs;
std::vector<array> fused_tape;
std::unordered_set<uintptr_t> tape_set;
std::function<void(const array&)> recurse_tape;
recurse_tape = [&](const array& a) {
if (cache.find(a.id()) == cache.end()) {
if (input_set.find(a.id()) == input_set.end()) {
input_set.insert(a.id());
inputs.push_back(a);
}
return;
}
if (tape_set.find(a.id()) != tape_set.end()) {
return;
}
tape_set.insert(a.id());
for (auto& in : a.inputs()) {
recurse_tape(in);
}
fused_tape.push_back(a);
};
recurse_tape(arr);
std::vector<array> old_outputs;
// Add to global cache and add any global outputs to outputs
// of new primitive
for (int j = 0; j < fused_tape.size() - 1; ++j) {
auto& f = fused_tape[j];
if (output_map.find(f.id()) != output_map.end()) {
old_outputs.push_back(f);
// Parents are now siblings, update the parent map
auto& pairs = parents_map[f.id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) {
return cache.find(p.first.id()) != cache.end();
}),
pairs.end());
} else {
// Remove inner fused arrays parents from the parents map
// to keep the parents map in a valid state
parents_map.erase(f.id());
}
global_cache.insert({f.id()});
}
old_outputs.push_back(arr);
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
for (auto& o : old_outputs) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
std::unordered_set<uintptr_t> constant_ids;
for (auto& in : inputs) {
// Scalar constant
if (in.size() == 1 && !in.has_primitive() &&
input_ids.find(in.id()) == input_ids.end()) {
constant_ids.insert(in.id());
}
}
auto compiled_outputs = array::make_arrays(
shapes,
types,
std::make_shared<Compiled>(
old_outputs.back().primitive().stream(),
inputs,
old_outputs,
std::move(fused_tape),
std::move(constant_ids)),
inputs);
// One output per primitive
new_tape.push_back(compiled_outputs.back());
// Replace inputs old parents with compiled_outputs
for (int i = 0; i < inputs.size(); ++i) {
auto& pairs = parents_map[inputs[i].id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
pairs.end());
for (auto& o : compiled_outputs) {
pairs.push_back({o, i});
}
}
// - Update outputs parents to point to compiled outputs
// - Update any overall graph outputs to be compiled outputs
for (int o = 0; o < old_outputs.size(); ++o) {
merge(compiled_outputs[o], old_outputs[o], parents_map);
if (auto it = output_map.find(old_outputs[o].id());
it != output_map.end()) {
it->second = compiled_outputs[o];
}
}
}
std::reverse(new_tape.begin(), new_tape.end());
tape = std::move(new_tape);
// Replace output with potentially compiled output
for (auto& o : outputs) {
o = output_map.at(o.id());
}
}
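The traversal above only grows a fusable region through unary, binary, broadcast, and no-op primitives, stays on a single stream, stops at max_compile_depth, and rejects nodes with parents outside the region. So in a sketch like the following (an illustrative function, not from this diff), the elementwise chain is a candidate for one Compiled primitive while the reduction ends the region:

#include "mlx/mlx.h"

using namespace mlx::core;

std::vector<array> mixed(const std::vector<array>& ins) {
  // abs, exp, multiply and add are all covered by is_unary()/is_binary(),
  // so this chain can be collected into a single Compiled primitive.
  auto y = exp(abs(ins[0]));
  auto z = add(multiply(y, array(0.5f)), array(1.0f));
  // Reduce is not in is_fusable(), so sum() stays outside the fused region.
  return {sum(z)};
}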
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
@@ -380,10 +697,17 @@ std::vector<array> compile_replace(
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
if (compile_mode() == CompileMode::disabled) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
})) {
return fun(inputs);
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
@@ -402,10 +726,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
if (compile_mode() != CompileMode::no_simplify) {
compile_simplify(
entry.tape, parents_map, entry.outputs, /* passes */ 3);
}
// Kernel fusion to generate Compiled primitives. The tape and
// new outputs must be updated accordingly
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
}
// At this point we must have a tape, now replace the placeholders
@@ -422,7 +752,7 @@ void compile_erase(size_t fun_id) {
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
@@ -430,11 +760,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
}
void disable_compile() {
detail::compile_mode() = CompileMode::disabled;
}
void enable_compile() {
detail::compile_mode() = CompileMode::enabled;
}
void set_compile_mode(CompileMode mode) {
detail::compile_mode() = mode;
}
} // namespace mlx::core

mlx/compile.h (new file, 28 lines)

@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
/** Set the compiler mode to the given value. */
void set_compile_mode(CompileMode mode);
} // namespace mlx::core
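A small usage sketch of the new knob (illustrative only): CompileMode::no_fuse keeps the graph simplification passes but skips emitting Compiled primitives, which is how compile.cpp above gates compile_fuse.

#include "mlx/compile.h"

using namespace mlx::core;

void configure_compilation(bool debugging_fused_kernels) {
  if (debugging_fused_kernels) {
    // Keep constant/common-subexpression merging, but generate no Compiled nodes.
    set_compile_mode(CompileMode::no_fuse);
  } else {
    set_compile_mode(CompileMode::enabled);
  }
}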

mlx/fast.cpp (new file, 128 lines)

@@ -0,0 +1,128 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/transforms.h"
namespace mlx::core::fast {
std::vector<array> Custom::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
}
std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
}
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
if (x.ndim() != 3) {
std::ostringstream msg;
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (traditional && x.shape(-1) != dims) {
throw std::invalid_argument(
"[rope] Does not support partial traditional application.");
}
auto fallback = [dims, traditional, base, scale, offset, s](
const std::vector<array>& inputs) {
auto& x = inputs[0];
auto t = x.dtype();
auto N = x.shape(1) + offset;
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
if (traditional) {
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
for (auto& o : outs) {
o = expand_dims(o, 3, s);
}
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
}
return std::vector<array>{concatenate(outs, 2, s)};
}
};
// TODO change to condition for using custom prim
auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) {
return array(
x.shape(),
x.dtype(),
std::make_unique<RoPE>(
stream, fallback, dims, traditional, base, scale, offset),
{x});
}
return fallback({x})[0];
}
bool RoPE::is_equivalent(const Primitive& other) const {
const RoPE& a_other = static_cast<const RoPE&>(other);
return (
dims_ == a_other.dims_ && base_ == a_other.base_ &&
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
offset_ == a_other.offset_);
}
} // namespace mlx::core::fast
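A hedged usage sketch based on the signature above: the input must be batch x sequence x dims, and with a GPU stream and a full-width rotation the RoPE primitive (and its Metal kernel) is used, otherwise the fallback graph is traced. The default stream argument is assumed here.

#include "mlx/fast.h"
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // (batch, sequence, feature) input; values are placeholders.
  auto x = zeros({2, 16, 64}, float32);
  auto y = fast::rope(
      x,
      /* dims = */ 64,
      /* traditional = */ false,
      /* base = */ 10000.0f,
      /* scale = */ 1.0f,
      /* offset = */ 0);
  eval({y});
}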

Some files were not shown because too many files have changed in this diff.