Compare commits

...

81 Commits

Author SHA1 Message Date
Alex Barron
f5b0f11968 add fast::quantized_kv_update 2024-10-26 00:24:49 -07:00
Alex Barron
b509c2ad76 update bench 2024-10-25 12:10:24 -07:00
Alex Barron
852336b8a2 clean 2024-10-25 12:10:24 -07:00
Alex Barron
6649244686 revert sdpa 2024-10-25 12:10:24 -07:00
Alex Barron
047a584e3d 8 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
ef14b1e9c3 4 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
5824626c0b start 2024-10-25 12:10:24 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
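A minimal sketch of how the new wiring control might look from Python; the binding name `mx.metal.set_wired_limit` and its return value are assumptions inferred from the commit notes ("expose residency sets as wire/unwire", "returns wired size"), not confirmed API:

```python
import mlx.core as mx

# Assumption: pins up to this many bytes of MLX buffer memory in a Metal
# residency set and returns the previously configured wired size.
old_size = mx.metal.set_wired_limit(4 << 30)  # wire up to 4 GiB
```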
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
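A hedged sketch of the new initializer; the exact signature of `mlx.nn.init.sparse` (here assumed to be `sparse(sparsity, mean, std)`) should be checked against the PR:

```python
import mlx.core as mx
import mlx.nn as nn

# Assumption: returns an initializer that zeroes a `sparsity` fraction of
# the entries and draws the rest from N(mean, std^2); applies to 2-D arrays.
init_fn = nn.init.sparse(sparsity=0.5, mean=0.0, std=0.1)
w = init_fn(mx.zeros((8, 8)))
```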
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
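With `__format__` wired up, scalar arrays accept format specs directly in f-strings:

```python
import mlx.core as mx

loss = mx.array(0.123456)
print(f"loss = {loss:.3f}")  # the spec is forwarded to the scalar value
```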
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
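A quick check that scans no longer overflow for 64-bit inputs whose partial sums exceed the 32-bit range:

```python
import mlx.core as mx

x = mx.full((4,), 2**31, dtype=mx.int64)
print(mx.cumsum(x))  # partial sums grow past 2**32 without wrapping
```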
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge Eigvalsh and Eigh into a single primitive

* use the same primitive with the flag

* fix primitives

* use MULTI

* fix eval_gpu

* fix declaration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
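The decompositions land as `mx.linalg.eigvalsh` and `mx.linalg.eigh`; like the other LAPACK-backed ops they are expected to run on the CPU stream (an assumption worth verifying against the docs). A short sketch:

```python
import mlx.core as mx

a = mx.array([[2.0, 1.0], [1.0, 2.0]])
w = mx.linalg.eigvalsh(a, stream=mx.cpu)  # eigenvalues only -> [1, 3]
w, v = mx.linalg.eigh(a, stream=mx.cpu)   # eigenvalues and eigenvectors
```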
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
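Usage is unchanged; quantized matmuls simply gain a batched fast path (plus a fast QMV kernel for small dimensions). A sketch:

```python
import mlx.core as mx

w = mx.random.normal((512, 512))
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)

x = mx.random.normal((2, 8, 512))  # batched activations now hit a fast path
y = mx.quantized_matmul(x, wq, scales, biases, transpose=True, bits=4)
print(y.shape)  # (2, 8, 512)
```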
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
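The fusion targets the decode case, where the query carries a single token (grouped-query head layouts included):

```python
import mlx.core as mx

q = mx.random.normal((1, 32, 1, 128))    # one query token, 32 heads
k = mx.random.normal((1, 8, 1024, 128))  # 8 KV heads (GQA)
v = mx.random.normal((1, 8, 1024, 128))
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=128**-0.5, mask=None)
print(o.shape)  # (1, 32, 1, 128)
```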
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
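The new ops extract the components of complex arrays:

```python
import mlx.core as mx

z = mx.array([1 + 2j, 3 - 4j])
print(mx.real(z))  # [1, 3]
print(mx.imag(z))  # [2, -4]
```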
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fix for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
c21331d47f version bump (#1477) 2024-10-10 13:05:17 -07:00
Awni Hannun
e1c9600da3 Add mx.random.permutation (#1471)
* random permutation

* comment
2024-10-08 19:42:19 -07:00
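`mx.random.permutation` accepts either an integer (giving a shuffled `arange`) or an array to shuffle; the `axis` keyword shown here is assumed from the PR:

```python
import mlx.core as mx

print(mx.random.permutation(10))         # shuffled arange(10)
x = mx.arange(12).reshape(3, 4)
print(mx.random.permutation(x, axis=0))  # rows in random order
```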
Awni Hannun
1fa0d20a30 consistently handle all -inf in softmax (#1470) 2024-10-08 09:54:02 -07:00
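For reference, a row that is entirely `-inf` previously gave backend-dependent results; it is now handled the same way everywhere:

```python
import mlx.core as mx

x = mx.full((4,), -float("inf"))
print(mx.softmax(x))  # same well-defined output on CPU and GPU
```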
Awni Hannun
3274c6a087 Fix array is_available race cases (#1468) 2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739 Add the roll op (#1455) 2024-10-07 17:21:42 -07:00
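`mx.roll` mirrors the NumPy op, flattening when no axis is given:

```python
import mlx.core as mx

x = mx.arange(5)
print(mx.roll(x, 2))          # [3, 4, 0, 1, 2]
m = mx.arange(9).reshape(3, 3)
print(mx.roll(m, 1, axis=0))  # rows shifted down by one
```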
Awni Hannun
f374b6ca4d Bump nanobind to 2.2 (#1461)
* bump nanobind

* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40 Fix deep recursion with siblings (#1462)
* fix recursion with siblings

* fix

* add test

* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
95d04805b3 Fix complex power on Metal (#1460) 2024-10-06 19:58:30 -07:00
Awni Hannun
e4534dac17 Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
fef3c4ec1d Fix mpi test in CI (#1456)
* Fix mpi test in CI

* Set bind to none
2024-10-06 06:09:17 -07:00
Awni Hannun
1bdc038bf9 fix argpartition + faster {arg} sorts / partitions (#1453) 2024-10-03 14:21:25 -07:00
Awni Hannun
5523d9c426 faster cpu indexing (#1450) 2024-10-03 13:53:47 -07:00
Angelos Katharopoulos
d878015228 Fix normalization check_input (#1452) 2024-10-03 13:26:56 -07:00
Cheng
5900e3249f Fix building on Linux (#1446) 2024-09-30 07:00:39 -07:00
Angelos Katharopoulos
bacced53d3 Fix row reduce with very few rows (#1447) 2024-09-29 20:00:35 -07:00
Lucas Newman
4a64d4bff1 Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
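With the weight-shape fix, `nn.Conv1d` accepts a `groups` argument (inputs are NLC):

```python
import mlx.core as mx
import mlx.nn as nn

conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, groups=4)
y = conv(mx.random.normal((1, 100, 16)))
print(y.shape)  # (1, 98, 32)
```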
Awni Hannun
b1e2b53c2d bump (#1445) 2024-09-27 13:53:02 -07:00
Awni Hannun
11354d5bff Avoid io timeout for large arrays (#1442) 2024-09-27 13:32:14 -07:00
Awni Hannun
718aea3f1d allow take to work with integer index (#1440) 2024-09-26 15:58:03 -07:00
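A plain Python int is now a valid index for `mx.take`:

```python
import mlx.core as mx

x = mx.array([[1, 2], [3, 4]])
print(mx.take(x, 1, axis=0))  # integer index selects the second row
```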
Awni Hannun
5b6f38df2b Faster cpu ops (#1434)
* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
2024-09-26 09:19:13 -07:00
Awni Hannun
0b4a58699e Some overhead reductions in mx.fast.metal_kernel (#1437)
* some overhead reductions

* fix

* use +=

* use more +=
2024-09-25 17:25:21 -07:00
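This range also reworks the `mx.fast.metal_kernel` calling convention from dicts to positional lists (see the documentation diff near the end of this compare). Putting the pieces together, an elementwise `exp` kernel under the new API looks roughly like:

```python
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
    name="myexp", input_names=["inp"], output_names=["out"], source=source
)
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
out = kernel(
    inputs=[a],
    template=[("T", mx.float16)],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)[0]
```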
Awni Hannun
4f9f9ebb6f Faster Metal unary and binary for general case (#1431)
* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
2024-09-25 12:07:43 -07:00
Awni Hannun
afc9c0ec1b dtype is copy assignable (#1436) 2024-09-25 12:07:13 -07:00
Awni Hannun
195b429d99 Put along axis + fix for partition grad (#1430)
* put along axis, fixes for partition grad

* zeros for arg reduce
2024-09-23 10:03:38 -07:00
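`mx.put_along_axis` is the functional counterpart of `take_along_axis`, returning a copy with `values` scattered at `indices`:

```python
import mlx.core as mx

x = mx.zeros((3, 3))
idx = mx.array([[0], [2], [1]])
vals = mx.array([[1.0], [2.0], [3.0]])
print(mx.put_along_axis(x, idx, vals, axis=1))
```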
Luke Carlson
2b878e9dd7 Create CITATION.cff (#1425) 2024-09-20 11:39:46 -07:00
Awni Hannun
67b6bf530d Optimization for general ND copies (#1421) 2024-09-17 17:59:51 -07:00
Nripesh Niketan
6af5ca35b2 feat: add cross_product (#1252)
* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
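The op is assumed to live under `mx.linalg.cross`, mirroring NumPy along the last axis:

```python
import mlx.core as mx

a = mx.array([1.0, 0.0, 0.0])
b = mx.array([0.0, 1.0, 0.0])
print(mx.linalg.cross(a, b))  # [0, 0, 1]
```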
Awni Hannun
4f46e9c997 More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
2024-09-17 12:46:31 -07:00
Awni Hannun
c6739ba7f3 Faster RNN layers (#1419)
* faster rnn

* use addmm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9 Data parallel helper (#1407) 2024-09-16 18:17:21 -07:00
jjuang-apple
8d68a3e805 remove fmt dependencies from MLX install (#1417) 2024-09-16 13:32:28 -07:00
jjuang-apple
6bbcc453ef avoid using find_library to make install truly portable (#1416) 2024-09-16 13:21:32 -07:00
Awni Hannun
d5ed4d7a71 override class function (#1418) 2024-09-16 13:21:04 -07:00
Nripesh Niketan
669c27140d Chore: add pre-commit hook for cmake (#1362)
* reset and lint

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-16 12:53:01 -07:00
Max-Heinrich Laves
adcc88e208 Conv cpu improvements (#1410) 2024-09-15 18:45:10 -07:00
Awni Hannun
d6492b0163 fix clip (#1415) 2024-09-14 16:09:09 -07:00
Awni Hannun
b3f52c9fbe ensure io/comm streams are active before eval (#1412) 2024-09-14 06:17:36 -07:00
c0g
bd8396fad8 Fix typo in transformer docs (#1414) 2024-09-14 06:05:15 -07:00
Angelos Katharopoulos
d0c58841d1 Patch bump (#1408) 2024-09-12 16:44:23 -07:00
Angelos Katharopoulos
881f09b2e2 Allow querying the allocator for the buffer size (#1404) 2024-09-11 21:02:16 -07:00
Awni Hannun
8b30acd7eb fix module attribute set, reset, set (#1403) 2024-09-11 16:30:42 -07:00
Awni Hannun
02efb310ca Xcode 16.0 (#1384)
* xcode 16.0 with debug tests

* limit nproc for builds

* vmap bug

* assert bug

* run python tests in debug mode

* fix view, bool copies preserve bits'

* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
e7e59c6f05 Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu

* Another copy scalar changed to fill

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-09-09 15:54:08 -07:00
Awni Hannun
3ae6aabe9f throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
xnorai
dc627dcb5e Replace the use of result_of_t with invoke_result_t (#1397)
* Fix C++20 incompatibility

* Fix C++20 incompatibility
2024-09-06 19:52:57 -07:00
Max-Heinrich Laves
efeb9c0f02 Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
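The ops land as `mx.conv_transpose1d/2d/3d` with NHWC inputs and `(O, kH, kW, C)` weights (the layout used by the benchmark scripts added in this PR). A sketch:

```python
import mlx.core as mx

x = mx.random.normal((1, 8, 8, 3))   # NHWC input
w = mx.random.normal((16, 3, 3, 3))  # (O, kH, kW, C) weights
y = mx.conv_transpose2d(x, w, stride=2, padding=1)
print(y.shape)  # (1, 15, 15, 16): (H - 1) * stride - 2 * pad + kH
```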
Awni Hannun
ba3e913c7a Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
206 changed files with 10133 additions and 5761 deletions

View File

@@ -13,8 +13,62 @@ parameters:
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@@ -31,15 +85,19 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run:
name: Generate package stubs
command: |
@@ -53,7 +111,9 @@ jobs:
- run:
name: Build CPP only
command: |
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
@@ -71,13 +131,13 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@3.8
brew install python@3.9
brew install openmpi
python3.8 -m venv env
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install numpy
pip install torch
pip install tensorflow
@@ -86,7 +146,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
- run:
name: Generate package stubs
command: |
@@ -99,7 +159,7 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
@@ -113,7 +173,7 @@ jobs:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
@@ -123,14 +183,23 @@ jobs:
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release:
parameters:
@@ -157,7 +226,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@@ -167,7 +236,7 @@ jobs:
command: |
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
name: Generate package stubs
@@ -180,7 +249,7 @@ jobs:
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
python -m build -w
- when:
condition: << parameters.build_env >>
@@ -193,7 +262,7 @@ jobs:
- store_artifacts:
path: dist/
build_linux_test_release:
build_linux_release:
parameters:
python_version:
type: string
@@ -222,22 +291,28 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -255,8 +330,9 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
@@ -273,9 +349,17 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
@@ -290,7 +374,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -302,7 +386,7 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
@@ -313,17 +397,17 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
python_version: ["3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_test_release:
- build_linux_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -14,3 +14,7 @@ repos:
- id: isort
args:
- --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
@@ -18,6 +18,7 @@ MLX was developed with contributions from the following individuals:
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

CITATION.cff (new file, 24 lines)
View File

@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT

View File

@@ -24,35 +24,43 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.17.2)
set(MLX_VERSION 0.19.0)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
message(
FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -61,63 +69,59 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if (MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
elseif(MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
if(${MACOS_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
# Get the metal version
execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION
COMMAND_ERROR_IS_FATAL ANY)
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if (MLX_BUILD_CPU)
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -129,32 +133,29 @@ if (MLX_BUILD_CPU)
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
@@ -165,103 +166,95 @@ else()
endif()
find_package(MPI)
if (MPI_FOUND)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET
)
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif (MPI_VERSION STREQUAL "")
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING
"MPI found but mpirun is not available. Building without MPI."
)
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(
WARNING
"MPI which is not OpenMPI found. Building without MPI."
)
endif()
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
FetchContent_Declare(fmt
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if (MLX_BUILD_PYTHON_BINDINGS)
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
if (MLX_BUILD_TESTS)
if(MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if (MLX_BUILD_EXAMPLES)
if(MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if (MLX_BUILD_BENCHMARKS)
if(MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
# Install headers
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING PATTERN "*.h"
)
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING
PATTERN "*.h"
PATTERN "backend/metal/kernels.h" EXCLUDE)
# Install metal dependencies
if (MLX_BUILD_METAL)
if(MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source
)
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source)
endif()
@@ -273,31 +266,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION}
)
VERSION ${MLX_VERSION})
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
MLX_CMAKE_INSTALL_MODULE_DIR)
install(
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

View File

@@ -0,0 +1,127 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5

mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net
        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)

        end_time = time.perf_counter()
        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net
        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()
        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,129 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,110 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5

mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net
        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()
            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)

        end_time = time.perf_counter()
        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net
        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()
        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,116 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5

mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0
            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,81 @@
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn
L = 65536
H = 32
H_k = 32 // 4
D = 128
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
def quant_sdpa(q, k, v, bits=4):
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=bits
)
def quant_attention(q, k, v, bits=4):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]
q = q.reshape((B, Hk, Hq // Hk, L, D))
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
out = out.reshape((B, Hq, L, D))
return out
def time_self_attention_primitives(q, k, v):
time_fn(attention, q, k, v)
def time_self_attention_sdpa(q, k, v):
time_fn(sdpa, q, k, v)
def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v, bits)
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
bits = 4
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)

View File

@@ -1,56 +1,41 @@
include(CMakeParseArguments)
###############################################################################
# ##############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/${TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
# TARGET: Custom target to be added for the metal library
# TITLE: Name of the .metallib
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
#
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
COMMAND
xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM
)
VERBATIM)
# Add metallib custom target
add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
endmacro(mlx_build_metallib)
endmacro(mlx_build_metallib)

View File

@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs["out"]
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
The full function signature will be generated using:
* The keys and shapes/dtypes of ``inputs``
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and its name is ``inp``
(the first entry of ``input_names``), so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The keys and values of ``output_shapes`` and ``output_dtypes``
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
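Putting these rules together, here is a small sketch of a different elementwise kernel built the same way (the doubling kernel, its name, and its body are illustrative, not part of the docs):

.. code-block:: python

   import mlx.core as mx

   # Referencing ``thread_position_in_grid`` in ``source`` adds the Metal
   # attribute to the generated signature, and the single input named "inp"
   # becomes ``const device float16_t* inp``.
   source = """
       uint elem = thread_position_in_grid.x;
       out[elem] = static_cast<T>(2.0) * inp[elem];
   """
   kernel = mx.fast.metal_kernel(
       name="mydouble",
       input_names=["inp"],
       output_names=["out"],
       source=source,
   )
   a = mx.random.normal(shape=(16,)).astype(mx.float16)
   out = kernel(
       inputs=[a],
       template=[("T", mx.float32)],
       grid=(a.size, 1, 1),
       threadgroup=(16, 1, 1),
       output_shapes=[a.shape],
       output_dtypes=[a.dtype],
   )[0]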
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs["out"]
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs={"x": x, "grid": grid},
template={"T": x.dtype},
output_shapes={"out": out_shape},
output_dtypes={"out": x.dtype},
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs["out"]
return outputs[0]
For a reasonably sized input such as:
@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs={"x": x, "grid": grid, "cotangent": cotangent},
template={"T": x.dtype},
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs["x_grad"], outputs["grid_grad"]
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
@@ -74,20 +74,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with:
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again.
wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -13,5 +13,8 @@ Linear Algebra
norm
cholesky
cholesky_inv
cross
qr
svd
eigvalsh
eigh

View File

@@ -14,6 +14,7 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -13,6 +13,7 @@ simple functions.
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx

View File

@@ -13,13 +13,18 @@ Layers
AvgPool1d
AvgPool2d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
@@ -31,6 +36,8 @@ Layers
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
@@ -46,6 +53,7 @@ Layers
RoPE
SELU
Sequential
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin

View File

@@ -45,6 +45,9 @@ Operations
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
@@ -77,6 +80,7 @@ Operations
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
@@ -118,14 +122,17 @@ Operations
pad
power
prod
put_along_axis
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save

View File

@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
truncated_normal
uniform
laplace
permutation

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python
def gelu(x):
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
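A timing harness along these lines produces such numbers (a sketch; the shape, warmup, and iteration counts here are illustrative assumptions, not from the docs):

.. code-block:: python

   import math
   import time

   import mlx.core as mx

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

   x = mx.random.uniform(shape=(32, 1000, 4096))
   compiled_gelu = mx.compile(gelu)

   def timer(fn, x, iters=100):
       for _ in range(5):  # warmup so compilation is not measured
           mx.eval(fn(x))
       tic = time.perf_counter()
       for _ in range(iters):
           mx.eval(fn(x))  # force evaluation inside the timed region
       return 1e3 * (time.perf_counter() - tic) / iters

   print(f"regular: {timer(gelu, x):.1f} ms, compiled: {timer(compiled_gelu, x):.1f} ms")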
Debugging
---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0)))
Compiling Training Graphs
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
function. To get the second derivative you can do:
.. code-block:: shell
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
graphs.
.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop:
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
where the vectorized axes should be in the outputs.
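For example, a small sketch of how the axes map to shapes (the shapes are illustrative):

.. code-block:: python

   import mlx.core as mx

   x = mx.random.normal(shape=(3, 10))  # vectorize over axis 1 (size 10)
   y = mx.random.normal(shape=(10, 3))  # vectorize over axis 0 (size 10)

   vmap_add = mx.vmap(lambda a, b: a + b, in_axes=(1, 0), out_axes=0)
   # Each call adds a (3,) slice of x to a (3,) slice of y; out_axes=0
   # stacks the results along the first axis, giving shape (10, 3)
   print(vmap_add(x, y).shape)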
Let's time these two different versions:

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
@@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
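One workaround, sketched below, is to fall back to NumPy for the output-shape-dependent step and convert the result back (the arrays here are illustrative):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.array([0, 1, 0, 2])
   # Compute the data-dependent indices in NumPy, then return to MLX
   idx = mx.array(np.nonzero(np.array(a))[0])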
In Place Updates
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:

View File

@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -116,7 +116,7 @@ saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
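A minimal sketch of that ordering (the model and loss here are stand-ins, not from the docs):

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

   model = nn.Linear(4, 1)
   x, y = mx.random.normal(shape=(8, 4)), mx.random.normal(shape=(8, 1))

   def loss_fn(model, x, y):
       return nn.losses.mse_loss(model(x), y)

   losses = []
   loss = loss_fn(model, x, y)            # builds the graph, computes nothing
   losses.append(loss.item())             # forces a partial (forward-only) evaluation
   mx.eval(loss, model.parameters())      # evaluates anything still pending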

View File

@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
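A sketch of the advised route through NumPy:

.. code-block:: python

   import mlx.core as mx
   import numpy as np
   import torch

   a = mx.arange(4.0)
   # Going through NumPy avoids the experimental memoryview path
   t = torch.tensor(np.array(a))
   b = mx.array(t.numpy())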

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
gradient with respect to the function's input.
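For instance, a minimal sketch:

.. code-block:: python

   import mlx.core as mx

   def loss_fn(w):
       return mx.sum(mx.square(w))

   w = mx.array([1.0, 2.0])
   # One pass gives both the value and the gradient (5.0 and [2, 4])
   loss, grad = mx.value_and_grad(loss_fn)(w)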

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
- Multiple arrays
* - Safetensors
- ``.safetensors``
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
:func:`load` depends on the format.
Here's an example of saving a single array to a file:

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
without needing to move them from one memory location to another. For example:
.. code-block:: python

View File

@@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
@@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
@@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
TARGET
mlx_ext_metallib
)
TITLE
mlx_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib)
endif()
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
NB_STATIC
STABLE_ABI
LTO
NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)

View File

@@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.17.0",
"nanobind==2.1.0",
"mlx>=0.18.0",
"nanobind==2.2.0",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.17.0
nanobind==2.1.0
mlx>=0.18.1
nanobind==2.2.0

View File

@@ -1,26 +1,24 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
if (MLX_BUILD_CPU)
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
@@ -28,17 +26,15 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
if(MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
endif()
if (MLX_BUILD_METAL)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@@ -23,11 +23,22 @@ void free(Buffer buffer) {
}
Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)};
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.raw_ptr());
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {

View File

@@ -41,6 +41,7 @@ class Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;

View File

@@ -95,13 +95,29 @@ void array::detach() {
array_desc_->primitive = nullptr;
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
}
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::unscheduled) {
mlx::core::eval({*this});
} else {
wait();
}
}
@@ -162,8 +178,10 @@ void array::move_shared_buffer(
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
@@ -242,25 +260,35 @@ array::ArrayDesc::~ArrayDesc() {
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
if (inputs.empty()) {
return;
}
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
std::unordered_map<std::uintptr_t, array> input_map;
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
}
}
}
ad.inputs.clear();
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
append_deletable_inputs(*top);
}
}

View File

@@ -324,6 +324,10 @@ class array {
return array_desc_->data->buffer;
}
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
@@ -340,11 +344,33 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
}
enum Status { unscheduled, scheduled, available };
enum Status {
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
bool is_available() const {
return status() == Status::available;
}
// The output of a computation which has been scheduled but `eval_*` has
// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const {
return array_desc_->status;

View File

@@ -1,10 +1,8 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

View File

@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -18,49 +18,61 @@ void _qmm_t_4_64(
const float* biases,
int M,
int N,
int K) {
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
*result = simd_reduce_add(sum);
result++;
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
x += K;
}
}
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(),
M,
N,
K);
K,
B,
batched_w);
} else {
eval(inputs, out);
}

View File

@@ -33,8 +33,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x) {
x *= 1.442695; // multiply with log_2(e)
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x;
// Avoid suppressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

View File

@@ -1,5 +1,4 @@
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
@@ -7,72 +6,57 @@ else()
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${COMPILER}
${PROJECT_SOURCE_DIR}
${CLANG}
OUTPUT compiled_preamble.cpp
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h)
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if (IOS)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
endif()

View File

@@ -43,13 +43,15 @@ void set_binary_op_output_data(
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
@@ -64,7 +66,7 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
@@ -79,13 +81,13 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
} else if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
@@ -100,16 +102,14 @@ void set_binary_op_output_data(
}
break;
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
@@ -122,19 +122,7 @@ void set_binary_op_output_data(
}
}
struct UseDefaultBinaryOp {
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
struct UseDefaultBinaryOp {};
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
@@ -150,18 +138,6 @@ struct DefaultVectorScalar {
a++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *b;
while (size-- > 0) {
auto dst = op(*a, scalar);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
}
}
};
template <typename T, typename U, typename Op>
@@ -178,18 +154,6 @@ struct DefaultScalarVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *a;
while (size-- > 0) {
auto dst = op(scalar, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
b++;
}
}
};
template <typename T, typename U, typename Op>
@@ -206,204 +170,110 @@ struct DefaultVectorVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
while (size-- > 0) {
auto dst = op(*a, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out,
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst += stride;
}
}
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
out += stride_out;
a += stride_a;
b += stride_b;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
dst += stride;
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
}
}
@@ -450,29 +320,33 @@ void binary_op(
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto& strides = out.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = out.ndim();
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
@@ -494,27 +368,27 @@ void binary_op(
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
binary_op_dispatch_dims<T, U, true>(
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out, op);
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
}
@@ -531,9 +405,9 @@ void binary_op(
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
@@ -554,7 +428,8 @@ void binary_op(
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
@@ -569,7 +444,8 @@ void binary_op(
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
@@ -585,7 +461,8 @@ void binary_op(
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));

View File

@@ -9,168 +9,43 @@ namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out_a.size(); ++i) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[i] = dst.first;
dst_b[i] = dst.second;
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
template <typename T, typename U, typename Op, int D>
void binary_op_dims(
const T* a,
const T* b,
U* out_a,
U* out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst_a += stride;
dst_b += stride;
}
}
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[1];
b_idx += b.strides()[1];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1>(
a,
b,
out_a,
out_b,
op,
shape,
a_strides,
b_strides,
out_strides,
axis + 1);
} else {
std::tie(*out_a, *out_b) = op(*a, *b);
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst_a += stride;
dst_b += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
a += stride_a;
b += stride_b;
out_a += stride_out;
out_b += stride_out;
}
}
@@ -181,352 +56,160 @@ void binary_op_dispatch_dims(
array& out_a,
array& out_b,
Op op) {
switch (out_a.ndim()) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
dst_a += stride;
dst_b += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar-scalar so we call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return;
}
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
out_a.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out_a.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
} else { // VectorVector
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out_a.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
template <typename... Ops>
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Ops... ops) {
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, ops...);
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, ops...);
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, ops...);
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, ops...);
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, ops...);
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, ops...);
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, ops...);
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, ops...);
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, ops...);
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, ops...);
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, ops...);
binary_op<float>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, ops...);
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, ops...);
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
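The four hand-unrolled binary_op_dims{1..4} bodies deleted above all encode the same strided walk, which the replacement folds into one function with the loop depth as a template parameter, so the compiler still sees fixed-depth loops after instantiation. A minimal standalone sketch of the recursion (names and the doubling op are illustrative, not from the repo):

#include <cstddef>
#include <cstdio>
#include <vector>

// Applies a stand-in op (doubling) over a D-dimensional strided view.
template <typename T, int D>
void apply_dims(
    const T* src,
    T* dst,
    const std::vector<int>& shape,
    const std::vector<size_t>& src_strides,
    const std::vector<size_t>& dst_strides,
    int axis) {
  for (int i = 0; i < shape[axis]; ++i) {
    if constexpr (D > 1) {
      apply_dims<T, D - 1>(src, dst, shape, src_strides, dst_strides, axis + 1);
    } else {
      *dst = static_cast<T>(2) * (*src);
    }
    src += src_strides[axis];
    dst += dst_strides[axis];
  }
}

int main() {
  // A 2x3 logical matrix stored column-major (strides {1, 2}), written out
  // row-major (strides {3, 1}).
  std::vector<float> src = {0, 3, 1, 4, 2, 5};
  std::vector<float> dst(6, 0.f);
  apply_dims<float, 2>(src.data(), dst.data(), {2, 3}, {1, 2}, {3, 1}, 0);
  for (float v : dst) std::printf("%g ", v); // prints: 0 2 4 6 8 10
  std::printf("\n");
}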

View File

@@ -2,46 +2,12 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how

View File

@@ -156,8 +156,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
auto [shape, strides] = collapse_contiguous_dims(in);
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.

View File

@@ -4,6 +4,8 @@
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
@@ -12,22 +14,7 @@
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -44,15 +31,41 @@ void* compile(
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
std::list<DLib> libs;
std::unordered_map<std::string, void*> kernels;
std::shared_mutex mtx;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,8 +103,8 @@ void* compile(
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
@@ -103,10 +116,10 @@ void* compile(
}
// load library
libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -114,7 +127,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}
@@ -316,10 +329,7 @@ void Compiled::eval_cpu(
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +344,8 @@ void Compiled::eval_cpu(
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
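The compile() rewrite above does two things: it moves the file-scope statics into a CompilerCache guarded by a std::shared_mutex, and it takes a source-builder callback so the kernel source is only assembled on a cache miss. A standalone sketch of the same read-mostly, build-once pattern (names are illustrative, not MLX API):

#include <functional>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Cache {
  std::unordered_map<std::string, int> entries;
  std::shared_mutex mtx;
};

int get_or_build(
    Cache& c,
    const std::string& key,
    const std::function<int()>& build) {
  {
    std::shared_lock lock(c.mtx); // fast path: many readers in parallel
    if (auto it = c.entries.find(key); it != c.entries.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(c.mtx); // slow path: exclusive for insertion
  // Re-check: another thread may have built the entry while we waited.
  if (auto it = c.entries.find(key); it != c.entries.end()) {
    return it->second;
  }
  int value = build(); // expensive work happens at most once per key
  c.entries.emplace(key, value);
  return value;
}

int main() {
  Cache c;
  auto build = [] { std::cout << "building\n"; return 42; };
  std::cout << get_or_build(c, "k", build) << "\n"; // builds
  std::cout << get_or_build(c, "k", build) << "\n"; // served from cache
}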

View File

@@ -3,13 +3,8 @@
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -684,6 +679,32 @@ void dispatch_slow_conv_3D(
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
template <typename T>
void flip_spatial_dims_inplace(array& wt) {
T* x = wt.data<T>();
size_t out_channels = wt.shape(0);
size_t in_channels = wt.shape(-1);
// Calculate the total size of the spatial dimensions
int spatial_size = 1;
for (int d = 1; d < wt.ndim() - 1; ++d) {
spatial_size *= wt.shape(d);
}
for (size_t i = 0; i < out_channels; i++) {
T* top = x + i * spatial_size * in_channels;
T* bottom =
x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
for (size_t j = 0; j < spatial_size / 2; j++) {
for (size_t k = 0; k < in_channels; k++) {
std::swap(top[k], bottom[k]);
}
top += in_channels;
bottom -= in_channels;
}
}
}
void explicit_gemm_conv_1D_cpu(
const array& in,
const array& wt,
@@ -910,7 +931,8 @@ void explicit_gemm_conv_ND_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
@@ -1000,6 +1022,14 @@ void explicit_gemm_conv_ND_cpu(
copy(wt, gemm_wt, ctype);
}
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy(gemm_wt, gemm_wt_, CopyType::Vector);
flip_spatial_dims_inplace<float>(gemm_wt_);
gemm_wt = gemm_wt_;
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
@@ -1042,10 +1072,15 @@ void conv_1D_cpu(
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
@@ -1060,6 +1095,13 @@ void conv_2D_cpu(
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
@@ -1073,6 +1115,14 @@ void conv_3D_cpu(
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
@@ -1125,7 +1175,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
else {
std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
<< " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions";
throw std::invalid_argument(msg.str());
}
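The flip_spatial_dims_inplace helper added above reverses the weights along the spatial axes while keeping each in-channel group intact, which is what lets the explicit-GEMM path now serve flip=true (true convolution rather than cross-correlation). A standalone check of the 1-D case, assuming the same (out_channels, spatial, in_channels) layout:

#include <algorithm>
#include <cstdio>

int main() {
  // 1 output channel, spatial size 3, 2 input channels.
  float wt[6] = {1, 2, 3, 4, 5, 6};
  const int spatial = 3, in_ch = 2;
  float* top = wt;
  float* bottom = wt + (spatial - 1) * in_ch;
  for (int j = 0; j < spatial / 2; j++) {
    for (int k = 0; k < in_ch; k++) {
      std::swap(top[k], bottom[k]);
    }
    top += in_ch;
    bottom -= in_ch;
  }
  // Channel pairs reversed along the spatial axis:
  for (float v : wt) std::printf("%g ", v); // prints: 5 6 3 4 1 2
  std::printf("\n");
}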

View File

@@ -26,292 +26,117 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
}
}
template <typename SrcT, typename DstT, typename StrideT, int D>
inline void copy_dims(
const SrcT* src,
DstT* dst,
const std::vector<int>& shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int axis) {
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = shape[axis];
template <typename SrcT, typename DstT>
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src += stride_src;
dst += stride_dst;
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
if constexpr (D > 1) {
int axis = data_shape.size() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = data_shape.size() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
dst_ptr += stride_dst;
}
}
}
template <typename SrcT, typename DstT, typename stride_t>
template <typename SrcT, typename DstT, typename StrideT>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int64_t i_offset,
int64_t o_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
switch (new_shape.size()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
*dst_ptr = val;
return;
}
int size = std::accumulate(
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
src_offset,
dst_offset);
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, StrideT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename StrideT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
0,
0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) {
@@ -326,6 +151,7 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
}
}
@@ -426,7 +252,7 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype);
copy_inplace_dispatch(src, dst, ctype);
}
void copy(const array& src, array& dst, CopyType ctype) {
@@ -456,20 +282,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename stride_t>
template <typename StrideT>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
copy_inplace_dispatch(
src,
dst,
ctype,
@@ -478,10 +304,10 @@ void copy_inplace(
o_strides,
i_offset,
o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
copy_inplace_dispatch(src, dst, ctype);
}
}
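Several of these rewrites lean on ContiguousIterator (declared in mlx/backend/common/utils.h), which walks the outer axes of a strided array in logical order while keeping loc equal to the memory offset of the current position; the innermost two or three axes are then handled by the fixed-depth copy_dims/binary_op_dims kernels. A rough standalone model of just the step() behavior (the real class also provides seek() and reset(), used by the scatter and sort changes below):

#include <cstddef>
#include <cstdio>
#include <vector>

struct OuterIterator {
  std::vector<int> shape;
  std::vector<size_t> strides;
  std::vector<int> pos;
  size_t loc = 0;

  // Iterate only the first `dims` axes; inner axes are left to the caller.
  OuterIterator(std::vector<int> s, std::vector<size_t> st, int dims)
      : shape(s.begin(), s.begin() + dims),
        strides(st.begin(), st.begin() + dims),
        pos(dims, 0) {}

  void step() {
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      loc += strides[i];
      if (++pos[i] < shape[i]) {
        return;
      }
      // Carry: rewind this axis and advance the next-outer one.
      loc -= static_cast<size_t>(shape[i]) * strides[i];
      pos[i] = 0;
    }
  }
};

int main() {
  // 2x3x4 array with a permuted layout; iterate the two outer axes.
  OuterIterator it({2, 3, 4}, {4, 8, 1}, 2);
  for (int n = 0; n < 6; ++n) {
    std::printf("%zu ", it.loc); // prints: 0 8 16 4 12 20
    it.step();
  }
  std::printf("\n");
}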

View File

@@ -1,14 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {

mlx/backend/common/eigh.cpp (new file, 117 lines)
View File

@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core
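The lwork = -1 / liwork = -1 call above is the standard LAPACK workspace query: the routine factorizes nothing and instead reports its optimal workspace sizes through the first elements of work and iwork. A sketch of the idiom with a hypothetical stand-in for ssyevd (the sizes follow the documented ssyevd minimums for jobz='V', used here purely for illustration):

#include <vector>

// Hypothetical stand-in; not a real LAPACK binding.
void fake_syevd(
    float* a, int n, float* w,
    float* work, int lwork, int* iwork, int liwork, int* info) {
  if (lwork == -1 || liwork == -1) { // workspace query
    work[0] = static_cast<float>(1 + 6 * n + 2 * n * n);
    iwork[0] = 3 + 5 * n;
    *info = 0;
    return;
  }
  // ... a real routine would do the eigendecomposition here ...
  *info = 0;
}

void eigh_like(float* a, int n, float* w) {
  int info;
  float wkopt;
  int iwkopt;
  fake_syevd(a, n, w, &wkopt, -1, &iwkopt, -1, &info); // query sizes
  std::vector<float> work(static_cast<size_t>(wkopt));
  std::vector<int> iwork(iwkopt);
  fake_syevd(
      a, n, w, work.data(), static_cast<int>(work.size()),
      iwork.data(), static_cast<int>(iwork.size()), &info); // real call
}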

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -81,11 +80,18 @@ void gather(
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]);
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
@@ -99,9 +105,10 @@ void gather(
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
}
@@ -223,21 +230,29 @@ void scatter(
update_size *= us;
}
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]);
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates);
auto out_loc = elem_to_loc(j, update_shape, out.strides());
op(updates.data<InT>()[update_loc],
out.data<InT>() + out_offset + out_loc);
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
}

View File

@@ -2,39 +2,19 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif
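This consolidated header is what lets the per-call spotrf_/strtri_ wrappers above disappear. The repo-side MLX_LAPACK_FUNC definition is not shown in this diff; a plausible shape for it, assuming it only has to bridge the two conventions those deleted wrappers handled by hand, is:

#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// Newer LAPACKE-style headers: the LAPACK_<name> macros already account
// for the hidden Fortran character-length arguments.
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
// Older LAPACK / Accelerate: call the trailing-underscore symbol directly.
#define MLX_LAPACK_FUNC(f) f##_
#endif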

View File

@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -1,15 +1,10 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

View File

@@ -295,6 +295,13 @@ struct Floor {
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
@@ -337,6 +344,13 @@ struct Negative {
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -406,16 +414,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -612,11 +611,18 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * obytes / ibytes);
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General);
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;

View File

@@ -2,14 +2,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>

View File

@@ -201,55 +201,61 @@ void _qmm_dispatch(
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
scales.data<float16_t>(),
biases.data<float16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
scales.data<bfloat16_t>(),
biases.data<bfloat16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
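The batched loop above uses elem_to_loc to turn flat indices like i * w_els into strided offsets, so a broadcast batch axis (stride 0) keeps re-reading the same weight and scale matrices while x and out advance. A paraphrased standalone version of the helper (the in-tree one, shown near the end of this diff, expresses the same computation with ldiv):

#include <cstdio>
#include <vector>

size_t elem_to_loc(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A batch of 4 matrices broadcast from one stored 2x3 matrix:
  // stride 0 on the batch axis maps every batch onto the same storage.
  std::vector<int> shape = {4, 2, 3};
  std::vector<size_t> strides = {0, 3, 1};
  std::printf(
      "%zu %zu\n",
      elem_to_loc(6, shape, strides),  // batch 1, element (0,0) -> 0
      elem_to_loc(7, shape, strides)); // batch 1, element (0,1) -> 1
}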

View File

@@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
int axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out.data<T>() + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
}
@@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
int axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
@@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out.data<T>() + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
@@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator md(idx_ptr, axis_stride, kth);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
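The comparator rewritten above, v1 < v2 || (v1 == v2 && a < b), breaks value ties by original index, which makes argpartition deterministic even though std::nth_element makes no ordering promises for equal keys. A minimal demonstration on a plain vector:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> vals = {2.f, 1.f, 2.f, 0.f, 1.f};
  std::vector<int> idx(vals.size());
  std::iota(idx.begin(), idx.end(), 0);
  const int kth = 2;
  std::nth_element(
      idx.begin(), idx.begin() + kth, idx.end(), [&](int a, int b) {
        return vals[a] < vals[b] || (vals[a] == vals[b] && a < b);
      });
  // Sorted (value, index) order is (0,3) (1,1) (1,4) (2,0) (2,2),
  // so position 2 always holds index 4.
  std::printf("idx[kth] = %d (value %g)\n", idx[kth], vals[idx[kth]]);
}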

View File

@@ -2,7 +2,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {

View File

@@ -41,7 +41,7 @@ void set_ternary_op_output_data(
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
if (is_donatable(x, out)) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
@@ -71,128 +71,46 @@ void set_ternary_op_output_data(
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
a,
b,
c,
out,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
axis + 1);
} else {
*out = op(*a, *b, *c);
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
a += stride_a;
b += stride_b;
c += stride_c;
out += stride_out;
}
}
@@ -203,30 +121,69 @@ void ternary_op_dispatch_dims(
const array& c,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 2:
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
int c_idx = elem_to_loc(i, c.shape(), c.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
U* out_ptr = out.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1:
ternary_op_dims<T1, T2, T3, U, Op, 1>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
}
}
@@ -243,10 +200,21 @@ void ternary_op(
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return;
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
} // namespace

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
namespace {
void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
@@ -24,22 +24,36 @@ void set_unary_output_data(const array& in, array& out) {
}
}
template <typename T, typename Op>
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
for (size_t i = 0; i < out.size(); ++i) {
// TODO this is super inefficient, need to fix.
int a_idx = elem_to_loc(i, a.shape(), a.strides());
dst[i] = op(a_ptr[a_idx]);
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
it.step();
}
}
}
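The reworked unary path above writes the output densely while only the input walks a stride, one last-axis row at a time, with ContiguousIterator supplying the row offsets when ndim > 1. A standalone sketch of that inner loop:

#include <cstddef>
#include <cstdio>

template <typename T, typename U, typename Op>
void unary_row(const T* a, U* out, Op op, size_t n, size_t stride) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = op(*a);
    a += stride;
  }
}

int main() {
  // Negate one column of a 3x4 row-major matrix: start at element (0,1)
  // and step by the row stride of 4.
  float m[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  float col[3];
  unary_row(m + 1, col, [](float x) { return -x; }, 3, 4);
  std::printf("%g %g %g\n", col[0], col[1], col[2]); // prints: -1 -5 -9
}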

View File

@@ -0,0 +1,138 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
// Make a vector that has the axes separated with -1 markers. All axes that
// fall between two -1 markers get collapsed into one.
std::vector<int> to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
}
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<StrideT>& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
if (shape[i] != 1) {
to_collapse.push_back(i);
}
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
};
if (i == to_collapse.size()) {
break;
}
int current_shape = shape[to_collapse[i]];
int k = i;
while (to_collapse[++k] != -1) {
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
}
if (!shape.empty() && out_shape.empty()) {
out_shape.push_back(1);
for (auto& out_stride : out_strides) {
out_stride.push_back(0);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (shape[i] == 1) {
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core
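
The merge rule above joins adjacent axes i-1 and i exactly when strides[i] * shape[i] == strides[i-1], i.e. when shape[i] steps along axis i land where one step along axis i-1 would. A runnable reduction of the single-array variant (it omits the size_cap guard that the real code applies):

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    std::pair<std::vector<int>, std::vector<int64_t>> collapse(
        const std::vector<int>& shape, const std::vector<int64_t>& strides) {
      std::vector<int> cs;
      std::vector<int64_t> cst;
      if (!shape.empty()) {
        cs.push_back(shape[0]);
        cst.push_back(strides[0]);
        for (size_t i = 1; i < shape.size(); ++i) {
          if (shape[i] == 1) continue;                // singleton axes drop out
          if (strides[i] * shape[i] == cst.back()) {  // axes i-1 and i merge
            cs.back() *= shape[i];
            cst.back() = strides[i];
          } else {
            cs.push_back(shape[i]);
            cst.push_back(strides[i]);
          }
        }
      }
      return {cs, cst};
    }

    int main() {
      auto [s1, t1] = collapse({2, 3, 4}, {12, 4, 1});  // row-contiguous
      std::cout << s1.size() << " axis: " << s1[0] << " @ stride " << t1[0]
                << '\n';                                // 1 axis: 24 @ stride 1
      auto [s2, t2] = collapse({2, 3, 4}, {12, 1, 3});  // transposed inner axes
      std::cout << s2.size() << " axes\n";              // 3 axes: nothing merges
    }
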


@@ -8,12 +8,12 @@
namespace mlx::core {
template <typename stride_t>
inline stride_t elem_to_loc(
template <typename StrideT>
inline StrideT elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
stride_t loc = 0;
const std::vector<StrideT>& strides) {
StrideT loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -29,9 +29,9 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
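
As a worked example of the loop above, here is a toy elem_to_loc that maps flat row-major element indices through arbitrary strides (self-contained, not the MLX template):

    #include <cstdlib>
    #include <iostream>
    #include <vector>

    long elem_to_loc(long elem, const std::vector<int>& shape,
                     const std::vector<long>& strides) {
      long loc = 0;
      for (int i = shape.size() - 1; i >= 0; --i) {
        auto qr = std::ldiv(elem, (long)shape[i]);  // split off this axis's index
        loc += qr.rem * strides[i];
        elem = qr.quot;
      }
      return loc;
    }

    int main() {
      // A 2x3 array stored column-major: strides are {1, 2}.
      for (long e = 0; e < 6; ++e)
        std::cout << elem_to_loc(e, {2, 3}, {1, 2}) << ' ';
      // prints 0 2 4 1 3 5
    }
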
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1);
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<StrideT> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
@@ -105,36 +73,84 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
}
// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::vector<int> collapsed_shape;
std::vector<size_t> collapsed_strides;
const std::vector<int64_t>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
std::numeric_limits<int>::max()) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
if (dims == 0) {
return;
}
int i = dims - 1;
while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i];
i--;
}
pos_[i]++;
loc += strides_[i];
}
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
loc += q_and_r.rem * strides_[i];
pos_[i] = q_and_r.rem;
n = q_and_r.quot;
}
}
return std::make_tuple(collapsed_shape, collapsed_strides);
}
void reset() {
loc = 0;
std::fill(pos_.begin(), pos_.end(), 0);
}
template <typename stride_t>
ContiguousIterator() {};
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
StrideT loc{0};
private:
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
};
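
step() is an odometer: it bumps the innermost position and, on carry, rewinds that axis's contribution to the flat offset loc instead of recomputing loc with a division per element. A self-contained toy version of the same walk:

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> shape = {2, 3};
      std::vector<long> strides = {1, 2};  // column-major 2x3
      std::vector<int> pos = {0, 0};
      long loc = 0;
      for (int n = 0; n < 6; ++n) {
        std::cout << loc << ' ';  // prints 0 2 4 1 3 5
        int i = pos.size() - 1;
        while (i > 0 && pos[i] == shape[i] - 1) {
          loc -= (shape[i] - 1) * strides[i];  // rewind this axis
          pos[i] = 0;
          --i;
        }
        ++pos[i];
        loc += strides[i];
      }
    }
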
template <typename StrideT>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
const std::vector<StrideT>& strides) {
size_t no_broadcast_data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
@@ -155,4 +171,11 @@ inline auto check_contiguity(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
}
} // namespace mlx::core
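
Concretely, the donation_extra slack above permits reusing an input buffer that is a little larger than the output needs; a toy check with made-up byte counts:

    #include <cstddef>
    #include <cstdio>

    int main() {
      size_t out_nbytes = 1 << 20;           // output needs 1 MiB
      size_t donation_extra = 16384;         // allowed slack from above
      size_t in_buffer = out_nbytes + 8192;  // input buffer is 8 KiB larger
      std::printf("donatable: %d\n", in_buffer <= out_nbytes + donation_extra);
      // prints 1: reuse is allowed despite the small size mismatch
    }
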


@@ -1,99 +1,56 @@
function(make_jit_source SRC_FILE)
# This function takes a metal header file,
# runs the C preprocessor on it, and makes
# the processed contents available as a string in a C++ function
# This function takes a metal header file, runs the C preprocessor on it,
# and makes the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}()
#
# To use the function, declare it in jit/includes.h and
# include jit/includes.h.
# To use the function, declare it in jit/includes.h and include
# jit/includes.h.
#
# Additional arguments to this function are treated as dependencies
# in the CMake build system.
# Additional arguments to this function are treated as dependencies in the
# CMake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
)
OUTPUT jit/${SRC_NAME}.cpp
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source)
make_jit_source(
utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
if (MLX_METAL_JIT)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
if(MLX_METAL_JIT)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
make_jit_source(arange)
make_jit_source(copy)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(
reduce
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
kernels/reduction/reduce_init.h
)
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
@@ -104,63 +61,52 @@ if (MLX_METAL_JIT)
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
kernels/steel/conv/loaders/loader_channel_n.h)
make_jit_source(steel/conv/kernels/steel_conv)
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h)
make_jit_source(quantized)
make_jit_source(gemv_masked)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
endif()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
if (NOT MLX_METAL_PATH)
if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
target_compile_definitions(mlx
PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")


@@ -2,6 +2,7 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit;
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
if (size == 0) {
@@ -205,7 +215,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)};
}
@@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
@@ -241,16 +254,14 @@ void MetalAllocator::free(Buffer buffer) {
}
}
size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
@@ -261,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
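
A sketch of how the new wired-limit entry point might be used (the header path and the 2 GiB figure are assumptions; the function returns the previous limit, mirroring set_memory_limit):

    #include "mlx/backend/metal/metal.h"

    void pin_working_set() {
      namespace metal = mlx::core::metal;
      // Wire up to 2 GiB; throws if this exceeds the device's maximum
      // recommended working set size (checked above).
      size_t old_limit = metal::set_wired_limit(2UL << 30);
      // ... run the memory-hungry workload ...
      metal::set_wired_limit(old_limit);  // restore the previous limit
    }
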
size_t get_active_memory() {
return allocator().get_active_memory();
}


@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal {
@@ -56,6 +57,7 @@ class MetalAllocator : public allocator::Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
return active_memory_;
};
@@ -71,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
@@ -81,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator
BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
std::mutex mutex_;


@@ -19,14 +19,13 @@
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
int ndim,
int work_per_thread) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
@@ -43,14 +42,17 @@ std::string get_kernel_name(
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
if (ndim <= 3) {
kname << ndim;
} else {
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
}
break;
}
kname << op << type_to_name(a);
kname << "_" << op << type_to_name(a);
return kname.str();
}
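
To make the naming concrete: for op = "add" on float32 inputs, the General branch now produces names like the following (the type suffix comes from type_to_name; the exact strings are inferred from the code above):

    g2_add_float32   // General, ndim == 2: shape is implicit in the grid
    gn4_add_float32  // General, ndim > 3, with 4 elements per thread
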
@@ -69,52 +71,67 @@ void binary_op_gpu_inplace(
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
auto maybe_collapse = [bopt, &a, &b, &out]() {
if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto ndim = shape.size();
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
auto kernel =
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
// otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_output_array(outputs[0], arg_idx++);
if (outputs.size() == 2) {
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
@@ -125,9 +142,8 @@ void binary_op_gpu_inplace(
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -164,72 +180,8 @@ void binary_op_gpu_inplace(
array& out,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
std::vector<array> outputs = {out};
binary_op_gpu_inplace(inputs, outputs, op, s);
}
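
The dispatch math above in isolation: the x grid dimension shrinks by ceil-division, and the generated kernels guard the tail with (N_ * pos.x + i) < xshape, so the axis length need not be a multiple of the work factor. A tiny check:

    #include <cstdio>

    int main() {
      size_t dim0 = 1001;  // innermost axis length
      int work_per_thread = 4;
      size_t grid_x = (dim0 + work_per_thread - 1) / work_per_thread;
      std::printf("%zu threads handle %zu elements\n", grid_x, dim0);
      // 251 threads handle 1001 elements
    }
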
void binary_op_gpu(


@@ -13,6 +13,8 @@
namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
@@ -22,7 +24,9 @@ inline void build_kernel(
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim,
bool dynamic_dims) {
bool dynamic_dims,
bool use_big_index = false,
int work_per_thread = 1) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
@@ -37,8 +41,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments
for (auto& x : inputs) {
@@ -52,11 +56,11 @@ inline void build_kernel(
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]],\n";
}
}
@@ -68,52 +72,37 @@ inline void build_kernel(
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
<< ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {\n";
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
if (use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else {
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
}
// Read the inputs in tmps
int nc_in_count = 0;
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
@@ -123,56 +112,117 @@ inline void build_kernel(
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");" << std::endl;
os << ");\n";
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
<< xname << "[0];\n";
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
<< xname << "[index];\n";
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
nc_inputs.push_back(x);
}
}
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
if (ndim == 1) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
} else if (ndim == 2) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
} else if (ndim == 3) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os << " uint zpos = pos.z;\n";
if (dynamic_dims) {
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
}
os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os << " index_" << xname << " += ";
if (dynamic_dims) {
os << "l * in_strides[" << i << " * ndim + d];\n";
} else {
os << "l * in_strides[" << i * ndim << " + d];\n";
}
}
os << " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
<< namer.get_name(x.inputs()[0]) << ");\n";
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
<< ";\n";
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else {
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
}
}
os << " index++;\n }\n";
}
// Finish the kernel
os << "}" << std::endl;
os << "}\n";
if (cnt > 31) {
std::ostringstream msg;
@@ -195,10 +245,7 @@ void Compiled::eval_gpu(
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
auto lib = d.get_library(kernel_lib_, [&]() {
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
@@ -212,6 +259,17 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_big",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@@ -222,7 +280,9 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
}
build_kernel(
kernel,
@@ -233,10 +293,11 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
lib = d.get_library(kernel_lib_, kernel.str());
}
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ WORK_PER_THREAD);
return kernel.str();
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
@@ -285,7 +346,16 @@ void Compiled::eval_gpu(
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides);
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool use_2d = false;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
use_2d = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@@ -298,6 +368,8 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
} else if (use_2d) {
kernel_name += "_big";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +420,10 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].size();
MTL::Size grid_dims(nthreads, 1, 1);
size_t nthreads = outputs[0].data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
@@ -357,11 +431,18 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
int pow2;
if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
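
Putting the pieces together, for a one-op contiguous tape the string assembled by build_kernel compiles to a Metal kernel roughly like this. A hand-expanded illustration only: the kernel name, input/output names, and the Abs op are made up, and the real names come from the namer and the primitive's print():

    [[host_name("fused_abs_contiguous")]]
    [[kernel]] void fused_abs_contiguous(
        device const float* in0 [[buffer(0)]],
        device float* out0 [[buffer(1)]],
        uint3 pos [[thread_position_in_grid]],
        uint3 grid [[threads_per_grid]]) {
      size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);
      float tmp_in0 = in0[index];
      float tmp_out0 = Abs()(tmp_in0);
      out0[index] = tmp_out0;
    }
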


@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped};
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
}
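
For scale, the implicit GEMM sizes passed above follow the usual im2col bookkeeping. A sketch with made-up sizes (the variable names are illustrative, not MLX's):

    #include <cstdio>

    int main() {
      int N = 2, H = 8, W = 8;         // batch and output spatial size
      int C = 16, O = 32, groups = 4;  // input/output channels
      int kH = 3, kW = 3;
      int implicit_M = N * H * W;               // one GEMM row per output position
      int implicit_K = (C / groups) * kH * kW;  // unfolded patch size per group
      int implicit_N = O / groups;              // output channels per group
      std::printf("per group: M=%d K=%d N=%d, batched over %d groups\n",
                  implicit_M, implicit_K, implicit_N, groups);
      // per group: M=128 K=36 N=8, batched over 4 groups
    }
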
@@ -552,7 +557,7 @@ void winograd_conv_2D_gpu(
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
fill_gpu(zero_arr, in_padded, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
@@ -571,7 +576,6 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
copies_w.push_back(zero_arr);
MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0),
@@ -911,15 +915,11 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// Throw error
else {
throw std::invalid_argument(
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
}
// Clear copies
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
// Record copies
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core


@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <iostream>
#include <sstream>
#include "mlx/backend/metal/copy.h"
@@ -10,7 +11,7 @@
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
@@ -59,13 +60,25 @@ void copy_gpu_inplace(
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
auto maybe_collapse =
[ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
auto [shape, strides] = collapse_contiguous_dims(
data_shape,
std::vector{strides_in_pre, strides_out_pre},
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
{
std::ostringstream kname;
@@ -83,9 +96,13 @@ void copy_gpu_inplace(
kname << "gg";
break;
}
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
@@ -105,10 +122,8 @@ void copy_gpu_inplace(
compute_encoder.set_output_array(out, 1, out_offset);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
@@ -117,10 +132,6 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -129,6 +140,11 @@ void copy_gpu_inplace(
data_size *= s;
int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
@@ -174,4 +190,31 @@ void copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -37,4 +37,7 @@ void copy_gpu_inplace(
CopyType ctype,
const Stream& s);
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
} // namespace mlx::core


@@ -15,11 +15,11 @@ void CustomKernel::eval_gpu(
std::vector<array> copies;
for (auto& out : outputs) {
// Copy from previous kernel
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) {
array init = array(init_value_.value(), out.dtype());
copy_gpu(init, out, CopyType::Scalar, s);
copies.push_back(init);
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
}
}
@@ -33,24 +33,22 @@ void CustomKernel::eval_gpu(
return copies.back();
}
};
std::vector<const array> checked_inputs;
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i];
auto& shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index);
index++;
if (in.ndim() > 0) {
@@ -69,7 +67,7 @@ void CustomKernel::eval_gpu(
}
}
}
for (array out : outputs) {
for (auto& out : outputs) {
compute_encoder.set_output_array(out, index);
index++;
}
@@ -80,10 +78,7 @@ void CustomKernel::eval_gpu(
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast


@@ -20,7 +20,6 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
@@ -121,33 +120,34 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
enc_->endEncoding();
enc_->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
enc_->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
outputs_.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
enc_->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
@@ -156,39 +156,25 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs.insert(buf);
outputs_.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
enc_->dispatchThreads(grid_dims, group_dims);
}
Device::Device() {
@@ -199,12 +185,6 @@ Device::Device() {
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -219,69 +199,134 @@ void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
return get_stream_(index).buffer_ops;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
get_stream_(index).buffer_ops++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
stream.buffer->retain();
}
return bit->second.second;
return stream.buffer;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
auto& stream = get_stream_(index);
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
}
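
For orientation: judging from the accesses in this diff, the per-stream record that replaces queue_map_/buffer_map_/encoder_map_ holds roughly these fields. A reconstruction for reading convenience, not the actual declaration in device.h:

    struct DeviceStream {
      MTL::CommandQueue* queue;                 // created in new_queue
      MTL::CommandBuffer* buffer{nullptr};      // lazily created, retained
      int buffer_ops{0};                        // ops recorded in this buffer
      std::unique_ptr<CommandEncoder> encoder;  // open encoder, if any
      std::shared_ptr<Fence> fence;             // fence for the open encoder
      // Producers: buffer pointer -> fence of the encoder that last wrote it
      std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
      std::mutex fence_mtx;
      std::vector<array> temporaries;           // freed on buffer completion
    };
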
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
}
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
void Device::end_encoding(int index) {
encoder_map_.erase(index);
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoder's fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits.
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoder's inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
}
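
The bookkeeping end_encoding performs can be modeled in a few lines. In this toy version plain ints stand in for buffers and printouts stand in for waitForFence/updateFence; it shows how the outputs map makes each encoder wait only on the fences of encoders that produced its inputs:

    #include <cstdio>
    #include <memory>
    #include <set>
    #include <unordered_map>

    struct Fence { int id; };
    struct Encoder { std::set<int> inputs, outputs; };  // buffer ids
    struct Stream {
      std::unordered_map<int, std::shared_ptr<Fence>> producers;
      int next_fence = 0;
    };

    void end_encoding(Stream& st, const Encoder& enc) {
      auto fence = std::make_shared<Fence>(Fence{st.next_fence++});
      std::set<const Fence*> waited;  // wait on each producer fence once
      for (int in : enc.inputs) {
        if (auto it = st.producers.find(in); it != st.producers.end()) {
          if (waited.insert(it->second.get()).second) {
            std::printf("encoder %d waits on fence %d\n", fence->id,
                        it->second->id);
          }
        }
      }
      for (int out : enc.outputs) st.producers[out] = fence;
      std::printf("encoder %d signals fence %d\n", fence->id, fence->id);
    }

    int main() {
      Stream st;
      end_encoding(st, {{}, {1}});      // writes buffer 1
      end_encoding(st, {{1}, {2}});     // reads 1 -> waits on fence 0
      end_encoding(st, {{1, 2}, {3}});  // reads 1, 2 -> waits on 0 and 1
    }
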
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
return *(eit->second);
return *stream.encoder;
}
void Device::register_library(
@@ -293,20 +338,7 @@ void Device::register_library(
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::build_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
@@ -322,26 +354,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -465,68 +478,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
return kernel;
}
MTL::Library* Device::get_library_(const std::string& name) {
std::shared_lock lock(library_mtx_);
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::function<std::string(void)>& builder) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto mtl_lib = build_library_(builder());
library_map_.insert({name, mtl_lib});
return mtl_lib;
}
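
The read-then-write sequence above is double-checked locking over a std::shared_mutex: a shared lock for the hot path, then a unique lock with a re-check so two racing builders don't compile the same library twice. The pattern in generic form (illustrative helper, not part of the MLX API):

#include <functional>
#include <shared_mutex>
#include <string>
#include <unordered_map>

template <typename V>
V get_or_build(
    std::unordered_map<std::string, V>& cache,
    std::shared_mutex& mtx,
    const std::string& key,
    const std::function<V()>& build) {
  {
    std::shared_lock rlock(mtx); // many readers may probe concurrently
    if (auto it = cache.find(key); it != cache.end()) {
      return it->second;
    }
  }
  std::unique_lock wlock(mtx); // single writer
  if (auto it = cache.find(key); it != cache.end()) {
    return it->second; // built by another thread while we waited
  }
  V v = build(); // the build itself runs under the write lock
  cache.emplace(key, v);
  return v;
}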
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -547,34 +524,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Single writer allowed
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
auto pool = new_scoped_memory_pool();
// Pull kernel from library
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.empty() ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
MTL::ComputePipelineState* Device::get_kernel(
@@ -583,16 +581,34 @@ MTL::ComputePipelineState* Device::get_kernel(
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
}
Device& device(mlx::core::Device) {


@@ -7,6 +7,7 @@
#include <filesystem>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -44,13 +45,13 @@ struct CommandEncoder {
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent_ = true;
}
~ConcurrentContext() {
enc.concurrent_ = false;
enc.outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
private:
@@ -58,7 +59,7 @@ struct CommandEncoder {
};
MTL::ComputeCommandEncoder* operator->() {
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
@@ -69,18 +70,59 @@ struct CommandEncoder {
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder();

// Inputs to all kernels in the encoder including temporaries
std::unordered_set<const void*>& inputs() {
return all_inputs_;
};

// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*>& outputs() {
return all_outputs_;
};

private:
void maybe_split();

int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc_;
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
struct Fence {
Fence(MTL::Fence* fence) : fence(fence) {}
~Fence() {
fence->release();
}
MTL::Fence* fence;
};
struct DeviceStream {
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
~DeviceStream() {
queue->release();
if (buffer != nullptr) {
buffer->release();
}
};
MTL::CommandQueue* queue;
// A map of prior command encoder outputs to their corresponding fence
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
};
class Device {
@@ -114,29 +156,9 @@ class Device {
}
}
MTL::Library* get_library(
const std::string& name,
const std::function<std::string(void)>& builder);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
@@ -155,11 +177,20 @@ class Device {
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
// Record temporary arrays for the given stream index
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
void set_residency_set(const MTL::ResidencySet* residency_set);
private:
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
MTL::Library* get_library_(const std::string& name);
MTL::Library* build_library_(const std::string& source_string);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
@@ -181,13 +212,22 @@ class Device {
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::ComputePipelineState* get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::Device* device_;
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
const MTL::ResidencySet* residency_set_{nullptr};
};
Device& device(mlx::core::Device);


@@ -27,4 +27,9 @@ void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}
bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
value();
}
} // namespace mlx::core
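
is_signaled is a non-blocking probe: the shared event counts up and the check is simply signaledValue() >= value(). A caller-side polling helper one might write on top of it (hypothetical, assuming only the Event interface shown above):

#include <chrono>
#include <thread>

template <typename Event>
bool poll_event(const Event& e, std::chrono::milliseconds timeout) {
  auto deadline = std::chrono::steady_clock::now() + timeout;
  while (!e.is_signaled()) {
    if (std::chrono::steady_clock::now() >= deadline) {
      return false; // the event never reached the awaited value
    }
    std::this_thread::yield(); // busy-poll politely
  }
  return true;
}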


@@ -575,8 +575,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -741,8 +740,8 @@ void fft_op(
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
}
void fft_op(
@@ -785,8 +784,7 @@ void nd_fft_op(
}
auto& d = metal::device(s.device);
d.add_temporaries(std::move(temp_arrs), s.index);
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {


@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
n,
m,
read_width);
return kernel_source.str();
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
@@ -171,33 +164,17 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core
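
The two-upload split in Hadamard::eval_gpu above exploits the Kronecker structure of Hadamard matrices: up to index ordering, a length-nm transform factors into independent length-n transforms followed by length-m transforms. In symbols (standard Kronecker algebra, not code from this patch):

H_{nm} = H_m \otimes H_n = (H_m \otimes I_n)\,(I_m \otimes H_n)

so the "n" kernel applies I_m \otimes H_n and the "m" kernel applies H_m \otimes I_n, which is why the code launches them back to back with the scale folded into the second pass.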


@@ -64,8 +64,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
@@ -83,8 +82,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
@@ -236,8 +235,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
@@ -264,7 +262,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
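
Wrapping the pattern in fmt::runtime is what the C++20-compatibility change requires: newer {fmt} versions check format strings at compile time when possible, so a pattern held in a runtime std::string must be marked explicitly. A minimal illustration (standalone, assumes {fmt} is available):

#include <fmt/core.h>
#include <string>

int main() {
  std::string pattern = "{0} += {0};"; // chosen at runtime, not a literal
  // fmt::format(pattern, "x");        // rejected by compile-time checking
  auto s = fmt::format(fmt::runtime(pattern), "x"); // "x += x;"
  fmt::print("{}\n", s);
}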
@@ -277,8 +275,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
idx_args,
idx_arr);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);


@@ -1,100 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";


@@ -1,26 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";


@@ -1,12 +1,9 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
@@ -27,37 +24,38 @@ MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -67,7 +65,7 @@ void add_binary_kernels(
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@@ -78,31 +76,24 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
}};
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
}
MTL::ComputePipelineState* get_binary_kernel(
@@ -111,14 +102,13 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -128,15 +118,14 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -146,34 +135,26 @@ MTL::ComputePipelineState* get_ternary_kernel(
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -183,17 +164,33 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
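
For orientation, the instantiations emitted above follow the host-name convention prefix + "_" + lib_name, where the prefix selects the specialization; the lib_name in this listing ("copyff") is a hypothetical example, not taken from the patch:

// s_copyff    copy_s<float, float>       scalar broadcast into dst
// v_copyff    copy_v<float, float>       contiguous elementwise copy
// g1_copyff   copy_g_nd1<float, float>   1-D strided source
// gn4_copyff  copy_g<float, float, 4>    general N-D, 4 elements per thread
// gg1_copyff  copy_gg_nd1<float, float>  strided source and destination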
@@ -203,8 +200,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
@@ -212,8 +208,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -226,22 +222,29 @@ MTL::ComputePipelineState* get_scan_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
auto out_type = get_type_string(out.dtype());
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
op[3] = toupper(op[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan();
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
{"contig_", "contiguous_scan"},
{"strided_", "strided_scan"},
}};
for (auto& [prefix, kernel] : scan_kernels) {
kernel_source << get_template_definition(
prefix + lib_name,
kernel,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op,
in.itemsize() <= 4 ? 4 : 2,
inclusive,
reverse);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
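
The op string assembled above capitalizes the reduce name and parameterizes it on the output type, so reduce_type == "sum" with a float output becomes "CumSum<float>". A quick standalone check of that string building:

#include <cctype>
#include <iostream>
#include <string>

int main() {
  std::string reduce_type = "sum";
  std::string out_type = "float";
  std::string op = "Cum" + reduce_type + "<" + out_type + ">";
  op[3] = static_cast<char>(std::toupper(op[3])); // "Cumsum..." -> "CumSum..."
  std::cout << op << "\n"; // prints CumSum<float>
}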
@@ -253,8 +256,7 @@ MTL::ComputePipelineState* get_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
@@ -279,8 +281,8 @@ MTL::ComputePipelineState* get_sort_kernel(
bn,
tn);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -292,15 +294,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
{{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}}};
for (auto& [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
@@ -310,8 +311,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
bn,
tn);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -319,8 +320,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
@@ -329,8 +329,8 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -344,8 +344,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source;
@@ -363,8 +362,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
}
return kernel_source.str();
});
auto st = d.get_kernel(kernel_name, lib);
return st;
}
@@ -383,8 +382,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
@@ -399,8 +397,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -419,8 +417,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
@@ -438,8 +435,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -450,19 +447,19 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -482,8 +479,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -507,8 +503,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -527,8 +523,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -550,8 +545,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -567,8 +562,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
@@ -582,8 +576,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -597,8 +591,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
@@ -611,8 +604,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -623,13 +616,12 @@ MTL::ComputePipelineState* get_fft_kernel(
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
std::string kernel_string;
kernel_source << metal::fft() << template_def;
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -638,13 +630,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}


@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
@@ -208,10 +209,10 @@ get_template_definition(std::string name, std::string func, Args... args) {
};
(add_arg(args), ...);
s << ">";
std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
return fmt::format(base_string, name, s.str());
return fmt::format(
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
name,
s.str());
}
} // namespace mlx::core
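
Concretely, the helper joins the template arguments and wraps them in a [[host_name]] instantiation. A standalone mimic showing the output (the kernel and type names in the example call are illustrative):

#include <fmt/core.h>
#include <sstream>
#include <string>

template <typename... Args>
std::string template_def(std::string name, std::string func, Args... args) {
  std::ostringstream s;
  s << func << "<";
  bool first = true;
  auto add_arg = [&](const auto& arg) {
    if (!first) {
      s << ", ";
    }
    first = false;
    s << arg;
  };
  (add_arg(args), ...); // fold over the variadic pack, comma-separated
  s << ">";
  return fmt::format(
      "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
      name,
      s.str());
}

// template_def("v_copyff", "copy_v", "float", "float") yields:
// template [[host_name("v_copyff")]] [[kernel]] decltype(copy_v<float, float>) copy_v<float, float>;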


@@ -1,38 +1,26 @@
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
VERBATIM)
endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
set(KERNEL_AIR
${TARGET}.air ${KERNEL_AIR}
PARENT_SCOPE)
endfunction(build_kernel)
build_kernel(arg_reduce)
@@ -43,105 +31,66 @@ build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS
steel/defines.h
steel/utils.h
steel/conv/conv.h
steel/conv/loader.h
steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_general.h
steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv_general.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h)
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)
build_kernel(binary_two binary_two.h)
build_kernel(copy copy.h)
build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
build_kernel(
reduce
atomic.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(gemv_masked steel/utils.h)
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
VERBATIM)
add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)
add_dependencies(mlx mlx-metallib)
# Install metallib
include(GNUInstallDirs)
@@ -149,5 +98,4 @@ include(GNUInstallDirs)
install(
FILES ${MLX_METAL_PATH}/mlx.metallib
DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT metallib)


@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}
template <typename T, typename Op, int N_READS = 4>
[[kernel]] void arg_reduce_general(
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const constant int* shape [[buffer(2)]],
const constant size_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]],
const constant size_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
}
}
// clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
instantiate_kernel( \
"argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
instantiate_kernel( \
"argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)
instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)


@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@@ -109,27 +109,11 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -140,7 +124,16 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;
idx.y += b_xstride;
}
}
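
The reworked binary_g gives each thread N consecutive elements of the innermost axis (the host presumably shrinks the x grid by N), stepping each input by its own innermost stride so strided or broadcast operands stay correct. A plain-C++ reference of that loop for a 2-D case, with Add standing in for Op (illustrative, not the Metal kernel):

#include <cstddef>

template <int N>
void binary_g_rows(
    const float* a, const float* b, float* c,
    int rows, int xshape,
    std::size_t a_row_stride, std::size_t a_xstride,
    std::size_t b_row_stride, std::size_t b_xstride) {
  int threads_x = (xshape + N - 1) / N; // x grid is shrunk by the factor N
  for (int y = 0; y < rows; ++y) {
    for (int tx = 0; tx < threads_x; ++tx) {
      std::size_t ai = y * a_row_stride + std::size_t(N) * tx * a_xstride;
      std::size_t bi = y * b_row_stride + std::size_t(N) * tx * b_xstride;
      std::size_t out = std::size_t(y) * xshape + std::size_t(N) * tx;
      // Each "thread" handles up to N consecutive innermost elements.
      for (int i = 0; i < N && N * tx + i < xshape; ++i) {
        c[out++] = a[ai] + b[bi]; // Op = Add for illustration
        ai += a_xstride;
        bi += b_xstride;
      }
    }
  }
}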


@@ -9,20 +9,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \


@@ -217,7 +217,7 @@ struct Power {
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
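
The switch to atan2 matters because the phase of x must be the full four-quadrant argument; atan(imag/real) collapses Re(x) < 0 onto the wrong branch. The identity the kernel implements is

x^y = e^{y \log x}, \qquad \log x = \tfrac{1}{2}\ln(x_r^2 + x_i^2) + i\,\operatorname{atan2}(x_i, x_r)

so mag = exp(y_r \ln r - y_i \theta) and phase = y_i \ln r + y_r \theta, matching the two lines above.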


@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
@@ -137,32 +137,13 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -174,9 +155,18 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx++] = out[1];
idx.x += a_xstride;
idx.y += b_xstride;
}
}
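The rewrite above replaces the fixed-rank binary_g_nd kernels with one general kernel that handles N consecutive elements of the innermost axis per thread: elem_to_loc_2_nd is evaluated once per thread, the loop then walks the last dimension by each input's stride, and only the ragged tail is bounds-checked. A CPU sketch of the access pattern (assumed semantics, with the outer-dimension offset simplified away):

#include <cstdio>
#include <vector>

template <int N>
void binary_add_rows(const float* a, const float* b, float* out,
                     int xshape, long a_xstride, long b_xstride) {
  int num_threads = (xshape + N - 1) / N; // grid size along the innermost axis
  for (int x = 0; x < num_threads; ++x) { // simulated thread index
    long ai = (long)N * x * a_xstride;    // start offsets, computed once
    long bi = (long)N * x * b_xstride;
    int oi = N * x;
    // Walk the last dimension by stride; only the tail needs the bound check.
    for (int i = 0; i < N && N * x + i < xshape; ++i) {
      out[oi++] = a[ai] + b[bi];
      ai += a_xstride;
      bi += b_xstride;
    }
  }
}

int main() {
  std::vector<float> a(10, 1.f), b(10, 2.f), out(10);
  binary_add_rows<4>(a.data(), b.data(), out.data(), 10, 1, 1);
  std::printf("%f %f\n", out[0], out[9]); // 3 3
}

The intent, as with the copy kernels later in this diff, is to amortize the index arithmetic of the general strided path across N elements.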

View File

@@ -7,20 +7,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@@ -113,6 +113,7 @@ template <typename T, int N>
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
out += ws_ * kernel_stride;
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
@@ -126,7 +127,6 @@ template <typename T, int N>
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i];
}

View File

@@ -71,21 +71,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -94,10 +80,22 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
src_idx += src_xstride;
}
}
template <typename T, typename U>
@@ -136,20 +134,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -158,7 +143,22 @@ template <typename T, typename U>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);
idx.x += src_xstride;
idx.y += dst_xstride;
}
}
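Because N is a template parameter of copy_g, the if (N == 1) early return above is resolved at compile time: the single-element instantiation keeps the previous one-element-per-thread fast path with no loop or tail check, while the "gn4_copy" instantiation (N = 4, see the next file) takes the strided loop. copy_gg gets the same treatment via elem_to_loc_2_nd, which computes the source and destination offsets in a single pass so both pointers can advance together through the inner loop.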

View File

@@ -16,12 +16,8 @@
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -8,6 +8,7 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
@@ -371,6 +372,64 @@ struct QuantizedBlockLoader {
}
};
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = D / QUAD_SIZE;
constexpr int packs_per_thread = values_per_thread / pack_factor;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int results_per_quadgroup = 8;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
U s = sl[0];
U b = bl[0];
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
for (int row = 0; row < results_per_quadgroup; row++) {
result[row] = quad_sum(result[row]);
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
y[row * quads_per_simd] = static_cast<T>(result[row]);
}
}
}
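To make the constexpr sizing above concrete, here is the arithmetic for the D = 64, 4-bit, group-size-64 instantiation that appears later in this diff (a worked derivation, not additional kernel code):

constexpr int SIMD_SIZE = 32, QUAD_SIZE = 4;
constexpr int D = 64, bits = 4, group_size = 64;
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; // 8 quads per simdgroup
constexpr int pack_factor = 32 / bits; // 8 values packed per uint32_t
constexpr int values_per_thread = D / QUAD_SIZE; // 16: one quad spans a row of D
constexpr int packs_per_thread = values_per_thread / pack_factor; // 2
constexpr int scale_step_per_thread = group_size / values_per_thread; // 4
// With results_per_quadgroup = 8, each simdgroup covers
// quads_per_simd * results_per_quadgroup = 64 output rows.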
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
const device uint32_t* w,
@@ -586,10 +645,10 @@ METAL_FUNC void qmv_impl(
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
@@ -697,16 +756,16 @@ template <
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -818,16 +877,16 @@ template <
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -942,6 +1001,45 @@ METAL_FUNC void qmm_n_impl(
}
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
const device T*& biases,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
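Both adjust_matrix_offsets overloads lean on elem_to_loc to turn a flat batch index into a memory offset for an arbitrarily strided (possibly broadcast) operand. A CPU sketch of that helper's assumed semantics, peeling dimensions from the innermost out:

#include <cstddef>
#include <cstdint>

size_t elem_to_loc(uint32_t elem, const int* shape, const size_t* strides,
                   int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

elem_to_loc_broadcast, used for w/scales/biases, appears to do the same walk once over a shared shape while accumulating three offsets, which is why the three pointers advance together above.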
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
@@ -996,7 +1094,58 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_quad_impl<T, group_size, bits, D>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1005,9 +1154,35 @@ template <typename T, int group_size, int bits>
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_fast_impl<T, group_size, bits>(
w,
scales,
@@ -1021,7 +1196,7 @@ template <typename T, int group_size, int bits>
simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1030,9 +1205,35 @@ template <typename T, const int group_size, const int bits>
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_impl<T, group_size, bits>(
w,
scales,
@@ -1046,23 +1247,49 @@ template <typename T, const int group_size, const int bits>
simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qvm_impl<T, group_size, bits>(
x,
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
@@ -1076,18 +1303,27 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1099,26 +1335,53 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
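A caution that applies to qmm_t above and qmm_n below: the operand order changes to (w, scales, biases, x, y) and the dimension order to (K, N, M). All of these are positional parameters, and several share a type, so a call site that is not updated still compiles silently; the swapped qmm_t_impl call at the end of the kernel is the pattern every caller in this diff must follow.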
template <
typename T,
const int group_size,
const int bits,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1131,8 +1394,27 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
@@ -1141,23 +1423,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1202,23 +1484,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1259,27 +1541,27 @@ template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1306,10 +1588,10 @@ template <typename T, int group_size, int bits>
b_strides,
tid);
qvm_impl<T, group_size, bits>(
x,
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
@@ -1327,28 +1609,28 @@ template <
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1383,7 +1665,7 @@ template <
b_strides,
tid);
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
@@ -1394,28 +1676,28 @@ template <
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1451,17 +1733,17 @@ template <
b_strides,
tid);
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
METAL_FUNC void affine_quantize_impl(
const device T* w,
device uint8_t* out,
device T* scales,
device T* biases,
uint2 index,
uint2 grid_dim) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
@@ -1538,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
affine_quantize_impl<T, group_size, bits>(
w, out, scales, biases, index, grid_dim);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
@@ -1601,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
out[oindex + i] = scale * d + bias;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void kv_update(
const device T* new_keys [[buffer(0)]],
const device T* new_values [[buffer(1)]],
device uint8_t* keys [[buffer(2)]],
device T* key_scales [[buffer(3)]],
device T* key_biases [[buffer(4)]],
device uint8_t* values [[buffer(5)]],
device T* value_scales [[buffer(6)]],
device T* value_biases [[buffer(7)]],
const constant int& offset,
const constant int& batch_stride,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// Compute the offset of this batch row in the quantized cache.
// TODO: this hard-codes a head dimension of 128.
constexpr int pack_factor = 8 / bits;
uint batch_idx = index.y * batch_stride * 4 + offset;
new_keys += index.y * 128;
new_values += index.y * 128;
// Index to the correct slice of keys/values and their scales/biases.
uint group_idx = batch_idx * pack_factor / group_size;
keys += batch_idx;
key_scales += group_idx;
key_biases += group_idx;
values += batch_idx;
value_scales += group_idx;
value_biases += group_idx;
uint2 new_index = {index.x, 0};
affine_quantize_impl<T, group_size, bits>(
new_keys, keys, key_scales, key_biases, new_index, grid_dim);
affine_quantize_impl<T, group_size, bits>(
new_values, values, value_scales, value_biases, new_index, grid_dim);
}
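kv_update quantizes the incoming key and value rows in place by calling affine_quantize_impl at the computed cache offsets. For reference, a CPU sketch of the affine round trip this file implements (group statistics are simplified versus the kernel's simdgroup reductions; the eps guard mirrors the constexpr T eps = T(1e-7) above):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void quantize_group(const float* x, int group_size, int bits,
                    std::vector<uint8_t>& q, float& scale, float& bias) {
  int levels = (1 << bits) - 1;
  float lo = *std::min_element(x, x + group_size);
  float hi = *std::max_element(x, x + group_size);
  scale = std::max((hi - lo) / levels, 1e-7f); // eps guard against flat groups
  bias = lo;
  q.resize(group_size);
  for (int i = 0; i < group_size; ++i) {
    q[i] = (uint8_t)std::lround((x[i] - bias) / scale);
  }
}

// Matches the "scale * d + bias" reconstruction in affine_dequantize above.
float dequantize(uint8_t qi, float scale, float bias) {
  return scale * qi + bias;
}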

View File

@@ -5,67 +5,104 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_quantized_types(name, group_size, bits) \
instantiate_quantized(name, float, group_size, bits) \
instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_quantized(name, bfloat16_t, group_size, bits)
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched)
#define instantiate_quantized_groups(name, bits) \
instantiate_quantized_types(name, 128, bits) \
instantiate_quantized_types(name, 64, bits) \
instantiate_quantized_types(name, 32, bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)
#define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, float, group_size, bits, false) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched)
#define instantiate_quantized_groups_aligned(name, bits) \
instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_quantized_types_aligned(name, 32, bits)
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
D, \
batched)
#define instantiate_quantized_all_aligned(name) \
instantiate_quantized_groups_aligned(name, 2) \
instantiate_quantized_groups_aligned(name, 4) \
instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on
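The reorganization above folds every variant axis (type, group size, bits, batched, aligned, quad width) into the generated kernel name, so selecting a kernel on the host reduces to assembling the same string. A hypothetical dispatcher-side sketch (illustrative only; this is not MLX's actual host API):

#include <string>

std::string qmv_fast_kernel_name(const std::string& type, int group_size,
                                 int bits, bool batched) {
  // Mirrors #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched
  return "qmv_fast_" + type + "_gs_" + std::to_string(group_size) + "_b_" +
      std::to_string(bits) + "_batch_" + (batched ? "1" : "0");
}
// e.g. qmv_fast_kernel_name("float16_t", 64, 4, true)
//   -> "qmv_fast_float16_t_gs_64_b_4_batch_1"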

View File

@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
device const int& ndim,
device const int* key_shape,
device const size_t* key_strides,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
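The fix above is purely an address-space correction: in the Metal Shading Language, constant is the read-only address space intended for small arguments accessed uniformly by all threads, while device also permits writes. Declaring ndim, key_shape, and key_strides as constant documents that they are never written and lets the compiler take the faster constant-buffer path.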

View File

@@ -1,11 +1,11 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
using namespace mlx::steel;
@@ -886,6 +886,9 @@ template <
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
@@ -922,548 +925,42 @@ instantiate_fast_inference_self_attention_kernel(
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim)
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \
instantiate_sdpa_vector(type, 96) \
instantiate_sdpa_vector(type, 128)
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
instantiate_kernel( \
"quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector, type, head_dim, group_size, bits)
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)
#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector_group_size(type, 64) \
instantiate_quant_sdpa_vector_group_size(type, 96) \
instantiate_quant_sdpa_vector_group_size(type, 128)
instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)

template <
typename T,
typename T2,
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for (uint i = 0; i < 8; i++) {
smemFlush
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK =
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED =
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
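// Instantiate the partials kernel for every (dtype, tile size, simdgroup
// count) combination the host-side dispatch selects from.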
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
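// Second pass: one threadgroup per (batch, query head) combines the
// per-tile partials with numerically stable softmax weights,
//   w_i = exp(lse_i + m_i - max_j m_j),   O = sum_i w_i * O_i / sum_i w_i.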
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// Reserve a fixed number of registers for the per-tile statistics; this
// assumes params.KV_TILES never exceeds `reserve` (128) tiles.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
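// Each thread owns one of the DK output elements (indexed by lid.x) and
// reduces it as a weighted average across all KV tiles.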
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
}
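// For reference, a minimal scalar sketch of the same combine step (an
// illustration only; n_tiles, lse, m, o_partial, o, and d are hypothetical
// names, not part of this kernel):
//   float m_max = -INFINITY;
//   for (uint i = 0; i < n_tiles; i++)
//     m_max = fmax(m_max, m[i]);
//   float denom = 0.f;
//   for (uint i = 0; i < n_tiles; i++)
//     denom += exp(lse[i] + m[i] - m_max);
//   for (uint i = 0; i < n_tiles; i++)
//     o[d] += o_partial[i * DK + d] * exp(lse[i] + m[i] - m_max) / denom;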
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
// clang-format on