Compare commits


81 Commits

Author SHA1 Message Date
Alex Barron
f5b0f11968 add fast::quantized_kv_update 2024-10-26 00:24:49 -07:00
Alex Barron
b509c2ad76 update bench 2024-10-25 12:10:24 -07:00
Alex Barron
852336b8a2 clean 2024-10-25 12:10:24 -07:00
Alex Barron
6649244686 revert sdpa 2024-10-25 12:10:24 -07:00
Alex Barron
047a584e3d 8 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
ef14b1e9c3 4 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
5824626c0b start 2024-10-25 12:10:24 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
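A small sketch of what #1521 adds, assuming __format__ forwards standard format specs for scalar (0-dimensional) arrays and falls back to the usual repr otherwise:

import mlx.core as mx

x = mx.array(3.14159)
print(f"pi is roughly {x:.3f}")   # format spec applied to the scalar value
print(f"{mx.arange(3)}")          # empty spec: prints like repr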
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
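A minimal sketch of the symmetric eigendecomposition added in #1334, assuming the bindings are mx.linalg.eigvalsh and mx.linalg.eigh and that, like other linalg ops, they currently run on the CPU stream:

import mlx.core as mx

a = mx.array([[2.0, 1.0], [1.0, 2.0]])

w = mx.linalg.eigvalsh(a, stream=mx.cpu)   # eigenvalues only
w, v = mx.linalg.eigh(a, stream=mx.cpu)    # eigenvalues and eigenvectors

# Columns of v are eigenvectors, so a @ v should equal v scaled column-wise by w.
print(w)                                   # expected: [1.0, 3.0]
print(mx.allclose(a @ v, v * w, atol=1e-5))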
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
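The fused path in #1497 targets single-query (decode-time) attention through the existing fast op; a sketch with illustrative sizes:

import math
import mlx.core as mx

B, n_heads, head_dim, L = 1, 8, 64, 1024
q = mx.random.normal((B, n_heads, 1, head_dim))   # one query token
k = mx.random.normal((B, n_heads, L, head_dim))
v = mx.random.normal((B, n_heads, L, head_dim))

out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(head_dim), mask=None)
print(out.shape)  # (1, 8, 1, 64)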
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
c21331d47f version bump (#1477) 2024-10-10 13:05:17 -07:00
Awni Hannun
e1c9600da3 Add mx.random.permutation (#1471)
* random permutation

* comment
2024-10-08 19:42:19 -07:00
Awni Hannun
1fa0d20a30 consistently handle all -inf in softmax (#1470) 2024-10-08 09:54:02 -07:00
Awni Hannun
3274c6a087 Fix array is_available race cases (#1468) 2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739 Add the roll op (#1455) 2024-10-07 17:21:42 -07:00
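A sketch of the roll op from #1455, assuming NumPy-style semantics:

import mlx.core as mx

x = mx.arange(6)
print(mx.roll(x, 2))                # [4, 5, 0, 1, 2, 3]

m = mx.arange(9).reshape(3, 3)
print(mx.roll(m, shift=1, axis=0))  # rows rotated down by one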
Awni Hannun
f374b6ca4d Bump nanobind to 2.2 (#1461)
* bump nanobind

* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40 Fix deep recursion with siblings (#1462)
* fix recursion with siblings

* fix

* add test

* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
95d04805b3 Fix complex power on Metal (#1460) 2024-10-06 19:58:30 -07:00
Awni Hannun
e4534dac17 Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
fef3c4ec1d Fix mpi test in CI (#1456)
* Fix mpi test in CI

* Set bind to none
2024-10-06 06:09:17 -07:00
Awni Hannun
1bdc038bf9 fix argpartition + faster {arg} sorts / partitions (#1453) 2024-10-03 14:21:25 -07:00
Awni Hannun
5523d9c426 faster cpu indexing (#1450) 2024-10-03 13:53:47 -07:00
Angelos Katharopoulos
d878015228 Fix normalization check_input (#1452) 2024-10-03 13:26:56 -07:00
Cheng
5900e3249f Fix building on Linux (#1446) 2024-09-30 07:00:39 -07:00
Angelos Katharopoulos
bacced53d3 Fix row reduce with very few rows (#1447) 2024-09-29 20:00:35 -07:00
Lucas Newman
4a64d4bff1 Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
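A hedged sketch of what #1444 enables at the nn level; parameter names follow the existing Conv1d signature and the groups argument is the new part:

import mlx.core as mx
import mlx.nn as nn

# 64 input channels split into 4 groups; the weight becomes
# (out_channels, kernel_size, in_channels // groups).
conv = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3,
                 padding=1, groups=4)
x = mx.random.normal((2, 100, 64))  # MLX convolutions are channels-last (N, L, C)
print(conv(x).shape)                # (2, 100, 128)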
Awni Hannun
b1e2b53c2d bump (#1445) 2024-09-27 13:53:02 -07:00
Awni Hannun
11354d5bff Avoid io timeout for large arrays (#1442) 2024-09-27 13:32:14 -07:00
Awni Hannun
718aea3f1d allow take to work with integer index (#1440) 2024-09-26 15:58:03 -07:00
Awni Hannun
5b6f38df2b Faster cpu ops (#1434)
* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
2024-09-26 09:19:13 -07:00
Awni Hannun
0b4a58699e Some overhead reductions in mx.fast.metal_kernel (#1437)
* some overhead reductions

* fix

* use +=

* use more +=
2024-09-25 17:25:21 -07:00
Awni Hannun
4f9f9ebb6f Faster Metal unary and binary for general case (#1431)
* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
2024-09-25 12:07:43 -07:00
Awni Hannun
afc9c0ec1b dtype is copy assignable (#1436) 2024-09-25 12:07:13 -07:00
Awni Hannun
195b429d99 Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad

* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Luke Carlson
2b878e9dd7 Create CITATION.cff (#1425) 2024-09-20 11:39:46 -07:00
Awni Hannun
67b6bf530d Optimization for general ND copies (#1421) 2024-09-17 17:59:51 -07:00
Nripesh Niketan
6af5ca35b2 feat: add cross_product (#1252)
* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
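A sketch of the cross product from #1252; the assumption here is that it is exposed as mx.linalg.cross and, like NumPy, operates over the last axis:

import mlx.core as mx

a = mx.array([1.0, 0.0, 0.0])
b = mx.array([0.0, 1.0, 0.0])
print(mx.linalg.cross(a, b))   # [0.0, 0.0, 1.0]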
Awni Hannun
4f46e9c997 More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
2024-09-17 12:46:31 -07:00
Awni Hannun
c6739ba7f3 Faster RNN layers (#1419)
* faster rnn

* use admm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9 Data parallel helper (#1407) 2024-09-16 18:17:21 -07:00
jjuang-apple
8d68a3e805 remove fmt dependencies from MLX install (#1417) 2024-09-16 13:32:28 -07:00
jjuang-apple
6bbcc453ef avoid using find_library to make install truly portable (#1416) 2024-09-16 13:21:32 -07:00
Awni Hannun
d5ed4d7a71 override class function (#1418) 2024-09-16 13:21:04 -07:00
Nripesh Niketan
669c27140d Chore: add pre-commit hook for cmake (#1362)
* reset and lint

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-16 12:53:01 -07:00
Max-Heinrich Laves
adcc88e208 Conv cpu improvements (#1410) 2024-09-15 18:45:10 -07:00
Awni Hannun
d6492b0163 fix clip (#1415) 2024-09-14 16:09:09 -07:00
Awni Hannun
b3f52c9fbe ensure io/comm streams are active before eval (#1412) 2024-09-14 06:17:36 -07:00
c0g
bd8396fad8 Fix typo in transformer docs (#1414) 2024-09-14 06:05:15 -07:00
Angelos Katharopoulos
d0c58841d1 Patch bump (#1408) 2024-09-12 16:44:23 -07:00
Angelos Katharopoulos
881f09b2e2 Allow querying the allocator for the buffer size (#1404) 2024-09-11 21:02:16 -07:00
Awni Hannun
8b30acd7eb fix module attribute set, reset, set (#1403) 2024-09-11 16:30:42 -07:00
Awni Hannun
02efb310ca Xcode 160 (#1384)
* xcode 16.0 with debug tests

* limit nproc for builds

* vmap bug

* assert bug

* run python tests in debug mode

* fix view, bool copies preserve bits'

* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
e7e59c6f05 Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu

* Another copy scalar changed to fill

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-09-09 15:54:08 -07:00
Awni Hannun
3ae6aabe9f throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
xnorai
dc627dcb5e Replace the use of result_of_t with invoke_result_t (#1397)
* Fix C++20 incompatibility

* Fix C++20 incompatibility
2024-09-06 19:52:57 -07:00
Max-Heinrich Laves
efeb9c0f02 Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
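The transposed-convolution ops from #1245 are exercised by the new benchmark scripts further down in this diff; a minimal sketch of the op-level and layer-level entry points (shapes are channels-last, and the weight layout follows the benchmarks):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 32, 32, 8))   # (N, H, W, C_in)
w = mx.random.normal((16, 3, 3, 8))    # (C_out, kH, kW, C_in)

y = mx.conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1))
print(y.shape)  # (1, 63, 63, 16)

layer = nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3,
                           stride=2, padding=1)
print(layer(x).shape)  # (1, 63, 63, 16)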
Awni Hannun
ba3e913c7a Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
206 changed files with 10133 additions and 5761 deletions


@@ -13,8 +13,62 @@ parameters:
   test_release:
     type: boolean
     default: false
+  linux_release:
+    type: boolean
+    default: false
 
 jobs:
+  build_documentation:
+    parameters:
+      upload-docs:
+        type: boolean
+        default: false
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.medium.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install
+          command: |
+            brew install python@3.9
+            brew install doxygen
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install --upgrade cmake
+            pip install -r docs/requirements.txt
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+      - when:
+          condition:
+            not: << parameters.upload-docs >>
+          steps:
+            - run:
+                name: Build documentation
+                command: |
+                  source env/bin/activate
+                  cd docs && doxygen && make html O=-W
+      - when:
+          condition: << parameters.upload-docs >>
+          steps:
+            - add_ssh_keys:
+                fingerprints:
+                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
+            - run:
+                name: Upload documentation
+                command: |
+                  source env/bin/activate
+                  git config user.email "mlx@group.apple.com"
+                  git config user.name "CircleCI Docs"
+                  git checkout gh-pages
+                  git rebase main
+                  cd docs
+                  git rm -rf build/html
+                  doxygen && make html O=-W
+                  git add -f build/html
+                  git commit -m "rebase"
+                  git push -f origin gh-pages
+
   linux_build_and_test:
     docker:
       - image: cimg/python:3.9
@@ -31,15 +85,19 @@ jobs:
           name: Install dependencies
           command: |
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.2.0
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
           name: Install Python package
           command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              python3 setup.py build_ext --inplace
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              python3 setup.py develop
       - run:
           name: Generate package stubs
           command: |
@@ -53,7 +111,9 @@
       - run:
           name: Build CPP only
           command: |
-            mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
+            mkdir -p build && cd build
+            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
+            make -j `nproc`
       - run:
           name: Run CPP tests
           command: ./build/tests/tests
@@ -71,13 +131,13 @@
       - run:
          name: Install dependencies
          command: |
-            brew install python@3.8
+            brew install python@3.9
             brew install openmpi
-            python3.8 -m venv env
+            python3.9 -m venv env
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.2.0
             pip install numpy
             pip install torch
             pip install tensorflow
@@ -86,7 +146,7 @@
           name: Install Python package
           command: |
             source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
@@ -99,7 +159,7 @@
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
-            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
       - run:
           name: Build example extension
           command: |
@@ -113,7 +173,7 @@
           name: Build CPP only
           command: |
             source env/bin/activate
-            mkdir -p build && cd build && cmake .. && make -j
+            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
       - run:
           name: Run CPP tests
           command: |
@@ -123,14 +183,23 @@
           command: |
             source env/bin/activate
             cd build/
-            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
-            make -j
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
+              -DBUILD_SHARED_LIBS=ON \
+              -DMLX_BUILD_CPU=OFF \
+              -DMLX_BUILD_SAFETENSORS=OFF \
+              -DMLX_BUILD_GGUF=OFF \
+              -DMLX_METAL_JIT=ON
+            make -j `sysctl -n hw.ncpu`
       - run:
           name: Run Python tests with JIT
           command: |
             source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
-            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+              pip install -e . -v
+            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
+              METAL_DEBUG_ERROR_MODE=0 \
+              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
 
   build_release:
     parameters:
@@ -157,7 +226,7 @@
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install twine
@@ -167,7 +236,7 @@
           command: |
             source env/bin/activate
             DEV_RELEASE=1 \
-              CMAKE_BUILD_PARALLEL_LEVEL="" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               pip install . -v
       - run:
           name: Generate package stubs
@@ -180,7 +249,7 @@
           command: |
             source env/bin/activate
             << parameters.build_env >> \
-              CMAKE_BUILD_PARALLEL_LEVEL="" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               python -m build -w
       - when:
           condition: << parameters.build_env >>
@@ -193,7 +262,7 @@
       - store_artifacts:
           path: dist/
 
-  build_linux_test_release:
+  build_linux_release:
    parameters:
      python_version:
        type: string
@@ -222,22 +291,28 @@
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install auditwheel
             pip install patchelf
             pip install build
+            pip install twine
             << parameters.extra_env >> \
-              CMAKE_BUILD_PARALLEL_LEVEL="" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
               pip install . -v
             pip install typing_extensions
             python setup.py generate_stubs
             << parameters.extra_env >> \
-              CMAKE_BUILD_PARALLEL_LEVEL="" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              python -m build --wheel
             auditwheel show dist/*
             auditwheel repair dist/* --plat manylinux_2_31_x86_64
+      - run:
+          name: Upload package
+          command: |
+            source env/bin/activate
+            twine upload wheelhouse/*
       - store_artifacts:
          path: wheelhouse/
@@ -255,8 +330,9 @@ workflows:
       - mac_build_and_test:
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
       - linux_build_and_test
+      - build_documentation
 
   build_pypi_release:
     when:
@@ -273,9 +349,17 @@
               ignore: /.*/
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12"]
               xcode_version: ["15.0.0", "15.2.0"]
               build_env: ["PYPI_RELEASE=1"]
+      - build_documentation:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          upload-docs: true
 
   prb:
     when:
       matches:
@@ -290,7 +374,7 @@
           requires: [ hold ]
           matrix:
            parameters:
-              xcode_version: ["15.0.0", "15.2.0"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
       - linux_build_and_test:
           requires: [ hold ]
 
   nightly_build:
@@ -302,7 +386,7 @@
       - build_release:
           matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12"]
               xcode_version: ["15.0.0", "15.2.0"]
 
   weekly_build:
     when:
@@ -313,17 +397,17 @@
       - build_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              python_version: ["3.9", "3.10", "3.11", "3.12"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
               build_env: ["DEV_RELEASE=1"]
 
   linux_test_release:
     when:
       and:
         - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.test_release >>
+        - << pipeline.parameters.linux_release >>
     jobs:
-      - build_linux_test_release:
+      - build_linux_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12"]
              extra_env: ["PYPI_RELEASE=1"]


@@ -14,3 +14,7 @@ repos:
       - id: isort
         args:
           - --profile=black
+  - repo: https://github.com/cheshirekow/cmake-format-precommit
+    rev: v0.6.13
+    hooks:
+      - id: cmake-format


@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
@@ -18,6 +18,7 @@ MLX was developed with contributions from the following individuals:
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
+- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

CITATION.cff (new file, 24 lines)

@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Awni
    family-names: Hannun
    affiliation: Apple
  - given-names: Jagrit
    family-names: Digani
    affiliation: Apple
  - given-names: Angelos
    family-names: Katharopoulos
    affiliation: Apple
  - given-names: Ronan
    family-names: Collobert
    affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
  MLX: efficient and flexible machine learning on Apple
  silicon
license: MIT


@@ -24,35 +24,43 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.17.2)
+  set(MLX_VERSION 0.19.0)
 endif()
 
 # --------------------- Processor tests -------------------------
 
-message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
+message(
+  STATUS
+    "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
+)
 
 set(MLX_BUILD_ARM OFF)
 
-if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     if(NOT MLX_ENABLE_X64_MAC)
-      message(FATAL_ERROR
-        "Building for x86_64 on macOS is not supported."
-        " If you are on an Apple silicon system, check the build"
-        " documentation for possible fixes: "
-        "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
+      message(
+        FATAL_ERROR
+          "Building for x86_64 on macOS is not supported."
+          " If you are on an Apple silicon system, check the build"
+          " documentation for possible fixes: "
+          "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
+      )
     else()
-      set(MLX_BUILD_METAL OFF)
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
-  elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
-    set(MLX_BUILD_ARM ON)
+    set(MLX_BUILD_METAL OFF)
   endif()
 else()
+  set(MLX_BUILD_METAL OFF)
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
+if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
+  set(MLX_BUILD_ARM ON)
+endif()
+
 # ----------------------------- Lib -----------------------------
 
 include(FetchContent)
@@ -61,63 +69,59 @@ cmake_policy(SET CMP0135 NEW)
 add_library(mlx)
 
-if (MLX_BUILD_METAL)
-  find_library(METAL_LIB Metal)
-  find_library(FOUNDATION_LIB Foundation)
-  find_library(QUARTZ_LIB QuartzCore)
+if(MLX_BUILD_METAL)
+  set(METAL_LIB "-framework Metal")
+  set(FOUNDATION_LIB "-framework Foundation")
+  set(QUARTZ_LIB "-framework QuartzCore")
 endif()
 
-if (MLX_BUILD_METAL AND NOT METAL_LIB)
+if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)
   set(MLX_METAL_DEBUG OFF)
-elseif (MLX_BUILD_METAL)
+elseif(MLX_BUILD_METAL)
   message(STATUS "Building METAL sources")
 
-  if (MLX_METAL_DEBUG)
+  if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
   endif()
 
   # Throw an error if xcrun not found
-  execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-                  OUTPUT_VARIABLE MACOS_VERSION
-                  COMMAND_ERROR_IS_FATAL ANY)
+  execute_process(
+    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
 
-  if (${MACOS_VERSION} LESS 14.0)
-    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
+  if(${MACOS_VERSION} LESS 14.0)
+    message(
+      FATAL_ERROR
+        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
   message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
 
-  set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
+  set(METAL_CPP_URL
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
+  )
 
   # Get the metal version
   execute_process(
-    COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
-    OUTPUT_VARIABLE MLX_METAL_VERSION
-    COMMAND_ERROR_IS_FATAL ANY)
+    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
+    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
 
-  FetchContent_Declare(
-    metal_cpp
-    URL ${METAL_CPP_URL}
-  )
+  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 
   FetchContent_MakeAvailable(metal_cpp)
   target_include_directories(
-    mlx PUBLIC
-    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
-    $<INSTALL_INTERFACE:include/metal_cpp>
-  )
-  target_link_libraries(
-    mlx PUBLIC
-    ${METAL_LIB}
-    ${FOUNDATION_LIB}
-    ${QUARTZ_LIB})
+    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
+  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 
   add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
 endif()
 
-if (MLX_BUILD_CPU)
+if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
-  if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+  if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
     target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -129,32 +133,29 @@ if (MLX_BUILD_CPU)
       # The blas shipped in macOS SDK is not supported, search homebrew for
       # openblas instead.
       set(BLA_VENDOR OpenBLAS)
-      set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
+      set(LAPACK_ROOT
+          "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
     endif()
 
     # Search and link with lapack.
     find_package(LAPACK REQUIRED)
-    if (NOT LAPACK_FOUND)
+    if(NOT LAPACK_FOUND)
      message(FATAL_ERROR "Must have LAPACK installed")
    endif()
-    find_path(LAPACK_INCLUDE_DIRS lapacke.h
-      /usr/include
-      /usr/local/include
-      /usr/local/opt/openblas/include)
+    find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
              /usr/local/opt/openblas/include)
    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
-    # List blas after lapack otherwise we may accidentally incldue an old version
-    # of lapack.h from the include dirs of blas.
+    # List blas after lapack otherwise we may accidentally incldue an old
+    # version of lapack.h from the include dirs of blas.
    find_package(BLAS REQUIRED)
-    if (NOT BLAS_FOUND)
+    if(NOT BLAS_FOUND)
      message(FATAL_ERROR "Must have BLAS installed")
    endif()
    # TODO find a cleaner way to do this
-    find_path(BLAS_INCLUDE_DIRS cblas.h
-      /usr/include
-      /usr/local/include
-      $ENV{BLAS_HOME}/include)
+    find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
              $ENV{BLAS_HOME}/include)
    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
@@ -165,103 +166,95 @@ else()
 endif()
 
 find_package(MPI)
-if (MPI_FOUND)
+if(MPI_FOUND)
   execute_process(
     COMMAND zsh "-c" "mpirun --version"
     OUTPUT_VARIABLE MPI_VERSION
-    ERROR_QUIET
-  )
-  if (${MPI_VERSION} MATCHES ".*Open MPI.*")
+    ERROR_QUIET)
+  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
     target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-  elseif (MPI_VERSION STREQUAL "")
+  elseif(MPI_VERSION STREQUAL "")
     set(MPI_FOUND FALSE)
     message(
-      WARNING
-      "MPI found but mpirun is not available. Building without MPI."
-    )
+      WARNING "MPI found but mpirun is not available. Building without MPI.")
   else()
     set(MPI_FOUND FALSE)
-    message(
-      WARNING
-      "MPI which is not OpenMPI found. Building without MPI."
-    )
+    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
   endif()
 endif()
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 
 target_include_directories(
-  mlx
-  PUBLIC
-  $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
-  $<INSTALL_INTERFACE:include>
-)
+  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
             $<INSTALL_INTERFACE:include>)
 
-FetchContent_Declare(fmt
+FetchContent_Declare(
+  fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
   GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL
-)
+  EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(fmt)
-target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
+target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 
-if (MLX_BUILD_PYTHON_BINDINGS)
+if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
-  find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+  find_package(
+    Python 3.8
+    COMPONENTS Interpreter Development.Module
+    REQUIRED)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
-    OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE NB_DIR)
   list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
   find_package(nanobind CONFIG REQUIRED)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
 endif()
 
-if (MLX_BUILD_TESTS)
+if(MLX_BUILD_TESTS)
   include(CTest)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
 endif()
 
-if (MLX_BUILD_EXAMPLES)
+if(MLX_BUILD_EXAMPLES)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
 endif()
 
-if (MLX_BUILD_BENCHMARKS)
+if(MLX_BUILD_BENCHMARKS)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
 endif()
 
 # ----------------------------- Installation -----------------------------
 include(GNUInstallDirs)
 
 # Install library
 install(
   TARGETS mlx
   EXPORT MLXTargets
   LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
   ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
   RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-  INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
-)
+  INCLUDES
+  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
 
 # Install headers
 install(
   DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
   DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
   COMPONENT headers
-  FILES_MATCHING PATTERN "*.h"
-)
+  FILES_MATCHING
+  PATTERN "*.h"
+  PATTERN "backend/metal/kernels.h" EXCLUDE)
 
 # Install metal dependencies
-if (MLX_BUILD_METAL)
+if(MLX_BUILD_METAL)
 
   # Install metal cpp
   install(
     DIRECTORY ${metal_cpp_SOURCE_DIR}/
     DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
-    COMPONENT metal_cpp_source
-  )
+    COMPONENT metal_cpp_source)
 
 endif()
@@ -273,31 +266,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
 install(
   EXPORT MLXTargets
   FILE MLXTargets.cmake
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
 
 include(CMakePackageConfigHelpers)
 write_basic_package_version_file(
   ${MLX_CMAKE_BUILD_VERSION_CONFIG}
   COMPATIBILITY SameMajorVersion
-  VERSION ${MLX_VERSION}
-)
+  VERSION ${MLX_VERSION})
 
 configure_package_config_file(
-  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
-  ${MLX_CMAKE_BUILD_CONFIG}
+  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
   INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
   NO_CHECK_REQUIRED_COMPONENTS_MACRO
-  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
-)
+  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
            MLX_CMAKE_INSTALL_MODULE_DIR)
 
-install(
-  FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
 
-install(
-  DIRECTORY ${CMAKE_MODULE_PATH}/
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(DIRECTORY ${CMAKE_MODULE_PATH}/
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})


@@ -0,0 +1,127 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()


@@ -0,0 +1,129 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,110 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()


@@ -0,0 +1,116 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,81 @@
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn
L = 65536
H = 32
H_k = 32 // 4
D = 128
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
def quant_sdpa(q, k, v, bits=4):
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=bits
)
def quant_attention(q, k, v, bits=4):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]
q = q.reshape((B, Hk, Hq // Hk, L, D))
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
out = out.reshape((B, Hq, L, D))
return out
def time_self_attention_primitives(q, k, v):
time_fn(attention, q, k, v)
def time_self_attention_sdpa(q, k, v):
time_fn(sdpa, q, k, v)
def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v, bits)
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
bits = 4
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)

View File

@@ -1,56 +1,41 @@
include(CMakeParseArguments) include(CMakeParseArguments)
############################################################################### # ##############################################################################
# Build metal library # Build metal library
# #
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
# #
# Args: # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# TARGET: Custom target to be added for the metal library # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# TITLE: Name of the .metallib # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib # files (like headers)
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
# #
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments( cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output # Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET} OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal COMMAND
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>" xcrun -sdk macosx metal
${MTLLIB_COMPILE_OPTIONS} "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_SOURCES} ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib" COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM VERBATIM)
)
# Add metallib custom target # Add metallib custom target
add_custom_target( add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib) endmacro(mlx_build_metallib)

View File

@@ -19,17 +19,19 @@ Let's write a custom kernel that computes ``exp`` elementwise:
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
input_names=["inp"],
output_names=["out"],
source=source, source=source,
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
) )
return outputs["out"] return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a) b = exp_elementwise(a)
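For reference, here is the updated example as one self-contained sketch; the ``source`` body below is assumed (it sits outside this hunk) and is not part of the diff itself:

.. code-block:: python

    import mlx.core as mx

    def exp_elementwise(a: mx.array):
        # Kernel body only; the full signature is generated from
        # input_names, output_names and the template parameters.
        # (Assumed body, written to match the documented example.)
        source = """
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        """
        kernel = mx.fast.metal_kernel(
            name="myexp",
            input_names=["inp"],
            output_names=["out"],
            source=source,
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        # Outputs are now returned as a list, indexed by position.
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))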
@@ -40,16 +42,16 @@ Let's write a custom kernel that computes ``exp`` elementwise:
The full function signature will be generated using: The full function signature will be generated using:
* The keys and shapes/dtypes of ``inputs`` * The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp`` In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature. so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``. in ``source``.
* The keys and values of ``output_shapes`` and ``output_dtypes`` * The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``. so we add ``device float16_t* out``.
* Template parameters passed using ``template`` * Template parameters passed using ``template``
In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``. and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``. Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]`` * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
@@ -104,18 +106,20 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp_strided", name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source source=source
) )
outputs = kernel( outputs = kernel(
inputs={"inp": a}, inputs=[a],
template={"T": mx.float32}, template=[("T", mx.float32)],
grid=(a.size, 1, 1), grid=(a.size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
output_shapes={"out": a.shape}, output_shapes=[a.shape],
output_dtypes={"out": a.dtype}, output_dtypes=[a.dtype],
ensure_row_contiguous=False, ensure_row_contiguous=False,
) )
return outputs["out"] return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16) a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous # make non-contiguous
@@ -243,17 +247,19 @@ First we'll implement the forward pass as a fused kernel:
""" """
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="grid_sample", name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source, source=source,
) )
outputs = kernel( outputs = kernel(
inputs={"x": x, "grid": grid}, inputs=[x, grid],
template={"T": x.dtype}, template=[("T", x.dtype)],
output_shapes={"out": out_shape}, output_shapes=[out_shape],
output_dtypes={"out": x.dtype}, output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1), grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
) )
return outputs["out"] return outputs[0]
For a reasonably sized input such as: For a reasonably sized input such as:
@@ -389,6 +395,8 @@ We can then implement the backwards pass as follows:
""" """
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="grid_sample_grad", name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source, source=source,
atomic_outputs=True, atomic_outputs=True,
) )
@@ -398,15 +406,15 @@ We can then implement the backwards pass as follows:
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded grid_size = B * gN * gM * C_padded
outputs = kernel( outputs = kernel(
inputs={"x": x, "grid": grid, "cotangent": cotangent}, inputs=[x, grid, cotangent],
template={"T": x.dtype}, template=[("T", x.dtype)],
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape}, output_shapes=[x.shape, grid.shape],
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype}, output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1), grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1), threadgroup=(256, 1, 1),
init_value=0, init_value=0,
) )
return outputs["x_grad"], outputs["grid_grad"] return outputs[0], outputs[1]
There's an even larger speed up for the vjp: There's an even larger speed up for the vjp:

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements: To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.8 - Using a native Python >= 3.9
- macOS >= 13.5 - macOS >= 13.5
.. note:: .. note::
@@ -74,20 +74,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install . CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing, install the package with development dependencies, and use an For developing, install the package with development dependencies, and use an
editable install: editable install:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]" CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with: Once the development dependencies are installed, you can build faster with:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with: Run the tests with:
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell: .. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively. Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm, To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"`` If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported." but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again. wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -13,5 +13,8 @@ Linear Algebra
norm norm
cholesky cholesky
cholesky_inv cholesky_inv
cross
qr qr
svd svd
eigvalsh
eigh
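``eigvalsh`` and ``eigh`` are the new entries from the eigendecomposition change. A minimal sketch, assuming the API mirrors ``numpy.linalg`` (``eigh`` returning an ``(eigenvalues, eigenvectors)`` pair) and runs on the CPU stream:

.. code-block:: python

    import mlx.core as mx

    # A symmetric matrix (symmetry is assumed to be required, as in NumPy).
    a = mx.array([[2.0, 1.0], [1.0, 2.0]])

    w = mx.linalg.eigvalsh(a, stream=mx.cpu)   # eigenvalues only
    w, v = mx.linalg.eigh(a, stream=mx.cpu)    # eigenvalues and eigenvectors
    mx.eval(w, v)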

View File

@@ -14,6 +14,7 @@ Metal
get_cache_memory get_cache_memory
set_memory_limit set_memory_limit
set_cache_limit set_cache_limit
set_wired_limit
clear_cache clear_cache
start_capture start_capture
stop_capture stop_capture
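``set_wired_limit`` is the new entry coming from the residency-set ("Wired") change. A hedged sketch of its intended use; the byte-valued argument and the returned previous limit are assumptions, and a recent macOS is presumably required:

.. code-block:: python

    import mlx.core as mx

    # Allow roughly 2 GB of MLX buffers to stay wired (resident) in memory.
    # The return value is assumed to be the previous limit in bytes.
    old_limit = mx.metal.set_wired_limit(2 * 1024**3)

    a = mx.random.normal(shape=(4096, 4096))
    mx.eval(a @ a)

    # Restore the previous limit when done.
    mx.metal.set_wired_limit(old_limit)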

View File

@@ -13,6 +13,7 @@ simple functions.
:template: nn-module-template.rst :template: nn-module-template.rst
elu elu
celu
gelu gelu
gelu_approx gelu_approx
gelu_fast_approx gelu_fast_approx

View File

@@ -13,13 +13,18 @@ Layers
AvgPool1d AvgPool1d
AvgPool2d AvgPool2d
BatchNorm BatchNorm
CELU
Conv1d Conv1d
Conv2d Conv2d
Conv3d Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout Dropout
Dropout2d Dropout2d
Dropout3d Dropout3d
Embedding Embedding
ELU
GELU GELU
GLU GLU
GroupNorm GroupNorm
@@ -31,6 +36,8 @@ Layers
LayerNorm LayerNorm
LeakyReLU LeakyReLU
Linear Linear
LogSigmoid
LogSoftmax
LSTM LSTM
MaxPool1d MaxPool1d
MaxPool2d MaxPool2d
@@ -46,6 +53,7 @@ Layers
RoPE RoPE
SELU SELU
Sequential Sequential
Sigmoid
SiLU SiLU
SinusoidalPositionalEncoding SinusoidalPositionalEncoding
Softmin Softmin
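Several of the newly listed layers (the ``ConvTranspose*d`` family in particular) are what the benchmarks in this comparison exercise. A small sketch, assuming MLX's channels-last ``(N, H, W, C)`` convention:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    # Channels-last input: (N, H, W, C)
    x = mx.random.normal(shape=(4, 32, 32, 16))

    up = nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1)
    act = nn.CELU(alpha=1.0)  # one of the newly documented activations

    y = act(up(x))
    mx.eval(y)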

View File

@@ -45,6 +45,9 @@ Operations
conv1d conv1d
conv2d conv2d
conv3d conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general conv_general
cos cos
cosh cosh
@@ -77,6 +80,7 @@ Operations
greater_equal greater_equal
hadamard_transform hadamard_transform
identity identity
imag
inner inner
isfinite isfinite
isclose isclose
@@ -118,14 +122,17 @@ Operations
pad pad
power power
prod prod
put_along_axis
quantize quantize
quantized_matmul quantized_matmul
radians radians
real
reciprocal reciprocal
remainder remainder
repeat repeat
reshape reshape
right_shift right_shift
roll
round round
rsqrt rsqrt
save save
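A brief sketch of a few of the newly listed operations; the weight layout for ``conv_transpose2d`` follows the benchmark scripts in this comparison, and conversion of Python complex numbers to ``complex64`` is assumed:

.. code-block:: python

    import mlx.core as mx

    # conv_transpose2d: input (N, H, W, C_in), weight (C_out, kH, kW, C_in)
    x = mx.random.normal(shape=(1, 8, 8, 4))
    w = mx.random.normal(shape=(3, 5, 5, 4))
    y = mx.conv_transpose2d(x, w, stride=(1, 1), padding=(2, 2))

    # real / imag on a complex array
    z = mx.array([1 + 2j, 3 - 4j])
    re, im = mx.real(z), mx.imag(z)

    # roll elements along the first axis
    r = mx.roll(mx.arange(6), shift=2)
    mx.eval(y, re, im, r)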

View File

@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
truncated_normal truncated_normal
uniform uniform
laplace laplace
permutation
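``permutation`` is the newly added entry. A small sketch, assuming NumPy-style semantics (an integer argument permutes ``arange(n)``, an array argument is shuffled along its first axis):

.. code-block:: python

    import mlx.core as mx

    mx.random.seed(0)
    p = mx.random.permutation(10)          # a random ordering of 0..9

    x = mx.random.normal(shape=(4, 3))
    shuffled = mx.random.permutation(x)    # rows of x in random order
    mx.eval(p, shuffled)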

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function # Compile the function
compiled_fun = mx.compile(fun) compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32) # Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y)) print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same The output of both the regular function and the compiled function is the same
up to numerical precision. up to numerical precision.
The first time you call a compiled function, MLX will build the compute The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python .. code-block:: python
def gelu(x): def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2 return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging Debugging
--------- ---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0))) print(fun(mx.array(1.0)))
Compiling Training Graphs Compiling Training Graphs
------------------------- -------------------------
This section will step through how to use :func:`compile` with a simple example This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation: To start, here is the simple example without any compilation:
.. code-block:: python .. code-block:: python
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled: appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python .. code-block:: python
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output # The state that will be captured as input and output
state = [model.state, optimizer.state] state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state) @partial(mx.compile, inputs=state, outputs=state)
def step(x, y): def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`. function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile` good practice is to compile the outer most function to give :func:`compile`

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do: function. To get the second derivative you can do:
.. code-block:: shell .. code-block:: shell
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff: .. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit Automatic differentiation in MLX works on functions rather than on implicit
graphs. graphs.
.. note:: .. note::
@@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y): def loss_fn(params, x, y):
w, b = params["weight"], params["bias"] w, b = params["weight"], params["bias"]
h = w * x + b h = w * x + b
return mx.mean(mx.square(h - y)) return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)} params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients. Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that. part of the function. You can use the :func:`stop_gradient` for that.
@@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop:
Instead you can use :func:`vmap` to automatically vectorize the addition: Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python .. code-block:: python
# Vectorize over the second dimension of x and the # Vectorize over the second dimension of x and the
# first dimension of y # first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0)) vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs. where the vectorized axes should be in the outputs.
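Tying the pieces of this page together, a compact sketch of gradients over a parameter tree and ``vmap`` with explicit ``in_axes``:

.. code-block:: python

    import mlx.core as mx

    def loss_fn(params, x, y):
        w, b = params["weight"], params["bias"]
        h = w * x + b
        return mx.mean(mx.square(h - y))

    params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
    x, y = mx.array([1.0, 2.0]), mx.array([1.5, 2.5])
    grads = mx.grad(loss_fn)(params, x, y)   # same tree structure as params

    # Vectorize over the second dimension of x and the first dimension of y
    xs = mx.random.normal(shape=(3, 4))
    ys = mx.random.normal(shape=(4, 3))
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
    out = vmap_add(xs, ys)                   # shape (4, 3)
    mx.eval(grads, out)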
Let's time these two different versions: Let's time these two different versions:

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell .. code-block:: shell
>>> arr = mx.arange(10) >>> arr = mx.arange(10)
>>> idx = mx.array([5, 7]) >>> idx = mx.array([5, 7])
>>> arr[idx] >>> arr[idx]
array([5, 7], dtype=int32) array([5, 7], dtype=int32)
@@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs
operations which MLX does not yet support include :func:`numpy.nonzero` and the operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`. single input version of :func:`numpy.where`.
In Place Updates In Place Updates
---------------- ----------------
In place updates to indexed arrays are possible in MLX. For example: In place updates to indexed arrays are possible in MLX. For example:

View File

@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed. :func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we MLX uses lazy evaluation because it has some nice features, some of which we
describe below. describe below.
Transforming Compute Graphs Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -116,7 +116,7 @@ saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass. will be a partial evaluation, computing only the forward pass.
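A minimal illustration of the behaviour described here, where nothing is computed until an evaluation is forced:

.. code-block:: python

    import mlx.core as mx

    a = mx.random.normal(shape=(1024, 1024))
    b = a @ a + 1.0      # only the compute graph is recorded here

    mx.eval(b)           # explicit evaluation
    s = b.sum().item()   # item() on a scalar array also evaluates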

View File

@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks Conversion to NumPy and Other Frameworks
======================================== ========================================
MLX array supports conversion between other frameworks with either: MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_. * The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_. * `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back. Let's convert an array to NumPy and back.
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch PyTorch
------- -------
.. warning:: .. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now. multi-dimensional arrays. Casting to NumPy first is advised for now.
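A short round-trip sketch of the conversion described above, going through NumPy as the warning recommends for PyTorch:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.arange(3.0)
    np_a = np.array(a)       # MLX -> NumPy via the buffer protocol
    back = mx.array(np_a)    # NumPy -> MLX

    # For PyTorch, casting through NumPy first is the safer path for now:
    # torch_a = torch.from_numpy(np.array(a))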

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products. and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input. gradient with respect to the function's input.
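For example, a minimal sketch:

.. code-block:: python

    import mlx.core as mx

    def loss_fn(w):
        return mx.sum(mx.square(w))

    w = mx.array([1.0, 2.0])
    loss, grad = mx.value_and_grad(loss_fn)(w)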

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats. MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats .. list-table:: Serialization Formats
:widths: 20 8 25 25 :widths: 20 8 25 25
:header-rows: 1 :header-rows: 1
* - Format * - Format
- Extension - Extension
- Function - Function
- Notes - Notes
* - NumPy * - NumPy
- ``.npy`` - ``.npy``
- :func:`save` - :func:`save`
- Single arrays only - Single arrays only
* - NumPy archive * - NumPy archive
- ``.npz`` - ``.npz``
- :func:`savez` and :func:`savez_compressed` - :func:`savez` and :func:`savez_compressed`
- Multiple arrays - Multiple arrays
* - Safetensors * - Safetensors
- ``.safetensors`` - ``.safetensors``
- :func:`save_safetensors` - :func:`save_safetensors`
- Multiple arrays - Multiple arrays
* - GGUF * - GGUF
- ``.gguf`` - ``.gguf``
- :func:`save_gguf` - :func:`save_gguf`
- Multiple arrays - Multiple arrays
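The formats in the table map onto calls roughly like the following (a brief sketch; ``save_gguf`` follows the same pattern):

.. code-block:: python

    import mlx.core as mx

    a = mx.arange(4.0)

    mx.save("single", a)                      # -> single.npy
    mx.savez("several", a=a, b=a * 2)         # -> several.npz
    mx.save_safetensors("weights", {"a": a})  # -> weights.safetensors

    loaded = mx.load("several.npz")           # a dict of arrays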
The :func:`load` function will load any of the supported serialization The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of formats. It determines the format from the extensions. The output of
:func:`load` depends on the format. :func:`load` depends on the format.
Here's an example of saving a single array to a file: Here's an example of saving a single array to a file:

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b`` run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example: without needing to move them from one memory location to another. For example:
.. code-block:: python .. code-block:: python

View File

@@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies ----------------------------- # ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
@@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext) add_library(mlx_ext)
# Add sources # Add sources
target_sources( target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers # Add include headers
target_include_directories( target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx # Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx) target_link_libraries(mlx_ext PUBLIC mlx)
@@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib # Build metallib
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
mlx_build_metallib( mlx_build_metallib(
TARGET mlx_ext_metallib TARGET
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib mlx_ext_metallib
) TITLE
mlx_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib)
endif() endif()
# ----------------------------- Python Bindings ----------------------------- # ----------------------------- Python Bindings -----------------------------
nanobind_add_module( nanobind_add_module(
_ext _ext
NB_STATIC STABLE_ABI LTO NOMINSIZE NB_STATIC
NB_DOMAIN mlx STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp LTO
) NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext) target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)

View File

@@ -2,7 +2,7 @@
requires = [ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.24", "cmake>=3.24",
"mlx>=0.17.0", "mlx>=0.18.0",
"nanobind==2.1.0", "nanobind==2.2.0",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.24 cmake>=3.24
mlx>=0.17.0 mlx>=0.18.1
nanobind==2.1.0 nanobind==2.2.0

View File

@@ -1,26 +1,24 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if (MLX_BUILD_CPU) if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
@@ -28,17 +26,15 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU) elseif(MLX_BUILD_CPU)
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
endif() endif()
if (MLX_BUILD_METAL) if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@@ -23,11 +23,22 @@ void free(Buffer buffer) {
} }
Buffer CommonAllocator::malloc(size_t size, bool) { Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)}; void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
} }
void CommonAllocator::free(Buffer buffer) { void CommonAllocator::free(Buffer buffer) {
std::free(buffer.raw_ptr()); std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
} }
Buffer malloc_or_wait(size_t size) { Buffer malloc_or_wait(size_t size) {

View File

@@ -41,6 +41,7 @@ class Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private: private:
CommonAllocator() = default; CommonAllocator() = default;

View File

@@ -95,13 +95,29 @@ void array::detach() {
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
void array::eval() { bool array::is_available() const {
// Ensure the array is ready to be read if (status() == Status::available) {
if (status() == Status::scheduled) { return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait(); event().wait();
set_status(Status::available); set_status(Status::available);
} else if (status() == Status::unscheduled) { }
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::unscheduled) {
mlx::core::eval({*this}); mlx::core::eval({*this});
} else {
wait();
} }
} }
@@ -162,8 +178,10 @@ void array::move_shared_buffer(
array_desc_->flags = flags; array_desc_->flags = flags;
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset; auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>( auto data_ptr = other.array_desc_->data_ptr;
static_cast<char*>(other.array_desc_->data_ptr) + char_offset); other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
} }
void array::move_shared_buffer(array other) { void array::move_shared_buffer(array other) {
@@ -242,25 +260,35 @@ array::ArrayDesc::~ArrayDesc() {
// This calls recursively the destructor and can result in stack overflow, we // This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a // instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2. // max stack depth of 2.
if (inputs.empty()) {
return;
}
std::vector<std::shared_ptr<ArrayDesc>> for_deletion; std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) { auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
if (a.array_desc_.use_count() == 1) { std::unordered_map<std::uintptr_t, array> input_map;
for_deletion.push_back(std::move(a.array_desc_)); for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
}
} }
} ad.inputs.clear();
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) { while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays // top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector // with inputs have been moved into the vector
auto top = std::move(for_deletion.back()); auto top = std::move(for_deletion.back());
for_deletion.pop_back(); for_deletion.pop_back();
append_deletable_inputs(*top);
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
} }
} }

View File

@@ -324,6 +324,10 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
} }
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer // Return a copy of the shared pointer
// to the array::Data struct // to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const { std::shared_ptr<Data> data_shared_ptr() const {
@@ -340,11 +344,33 @@ class array {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
} }
enum Status { unscheduled, scheduled, available }; enum Status {
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
bool is_available() const { // The output of a computation which has been scheduled but `eval_*` has
return status() == Status::available; // not yet been called on the array's primitive. A possible
} // status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const { Status status() const {
return array_desc_->status; return array_desc_->status;

View File

@@ -1,10 +1,8 @@
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT(Inverse) DEFAULT(Inverse)
DEFAULT(Cholesky) DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) { void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);

View File

@@ -18,49 +18,61 @@ void _qmm_t_4_64(
const float* biases, const float* biases,
int M, int M,
int N, int N,
int K) { int K,
int B,
bool batched_w) {
constexpr int bits = 4; constexpr int bits = 4;
constexpr int group_size = 64; constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) { int w_els = N * K / pack_factor;
const uint32_t* w_local = w; int g_els = w_els * pack_factor / group_size;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) { for (int i = 0; i < B; i++) {
const simd_float16* x_local = (simd_float16*)x; for (int m = 0; m < M; m++) {
simd_float16 sum = 0; const uint32_t* w_local = w;
for (int k = 0; k < K; k += group_size) { const float* scales_local = scales;
float scale = *scales_local++; const float* biases_local = biases;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) { for (int n = 0; n < N; n++) {
// TODO: vectorize this properly const simd_float16* x_local = (simd_float16*)x;
simd_uint16 wi; simd_float16 sum = 0;
for (int e = 0; e < 2; e++) { for (int k = 0; k < K; k += group_size) {
uint32_t wii = *w_local++; float scale = *scales_local++;
for (int p = 0; p < 8; p++) { float bias = *biases_local++;
wi[e * 8 + p] = wii & bitmask;
wii >>= bits; for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
} }
} simd_float16 wf = simd_float(wi);
simd_float16 wf = simd_float(wi); wf *= scale;
wf *= scale; wf += bias;
wf += bias;
sum += (*x_local) * wf; sum += (*x_local) * wf;
x_local++; x_local++;
}
} }
*result = simd_reduce_add(sum);
result++;
} }
*result = simd_reduce_add(sum); x += K;
result++; }
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
} }
x += K;
} }
} }
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) { if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1); int K = x.shape(-1);
int M = x.size() / K; int M = x.shape(-2);
int N = out.shape(-1); int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64( _qmm_t_4_64(
out.data<float>(), out.data<float>(),
x.data<float>(), x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(), biases.data<float>(),
M, M,
N, N,
K); K,
B,
batched_w);
} else { } else {
eval(inputs, out); eval(inputs, out);
} }

View File

@@ -33,8 +33,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster * Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0. * implementations for numbers strictly < 0.
*/ */
inline simd_float16 simd_fast_exp(simd_float16 x) { inline simd_float16 simd_fast_exp(simd_float16 x_init) {
x *= 1.442695; // multiply with log_2(e) auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart; simd_float16 ipart, fpart;
simd_int16 epart; simd_int16 epart;
x = simd_clamp(x, -80, 80); x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
// bitshifting // bitshifting
epart = (simd_int(ipart) + 127) << 23; epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x; // Avoid suppressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
} }
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

View File

@@ -1,5 +1,4 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER}) set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE) set(CLANG TRUE)
else() else()
@@ -7,72 +6,57 @@ else()
endif() endif()
add_custom_command( add_custom_command(
OUTPUT compiled_preamble.cpp OUTPUT compiled_preamble.cpp
COMMAND /bin/bash COMMAND
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${COMPILER} ${PROJECT_SOURCE_DIR} ${CLANG}
${PROJECT_SOURCE_DIR} DEPENDS make_compiled_preamble.sh
${CLANG} compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h)
DEPENDS make_compiled_preamble.sh add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble) add_dependencies(mlx cpu_compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
) ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if (IOS) if(IOS)
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif() endif()

View File

@@ -43,13 +43,15 @@ void set_binary_op_output_data(
array& out, array& out,
BinaryOpType bopt, BinaryOpType bopt,
bool donate_with_move = false) { bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) { if (b_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -64,7 +66,7 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (a_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
@@ -79,13 +81,13 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (a_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) { } else if (b_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -100,16 +102,14 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::General: case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous && if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if ( } else if (
b.is_donatable() && b.flags().row_contiguous && b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -122,19 +122,7 @@ void set_binary_op_output_data(
} }
} }
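For context on the hunk above: the repeated `x.is_donatable() && x.itemsize() == out.itemsize()` checks are hoisted into two booleans computed once via an is_donatable(x, out) helper. A minimal sketch of such a predicate, consistent with the checks it replaces (the real helper in the common backend may add further conditions, for example on buffer sizes):

#include "mlx/array.h"

namespace mlx::core {
// Sketch only: donation is allowed when the input can give up its buffer
// and its element size matches the output's.
inline bool is_donatable_sketch(const array& in, const array& out) {
  return in.is_donatable() && in.itemsize() == out.itemsize();
}
} // namespace mlx::core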
struct UseDefaultBinaryOp { struct UseDefaultBinaryOp {};
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
struct DefaultVectorScalar { struct DefaultVectorScalar {
@@ -150,18 +138,6 @@ struct DefaultVectorScalar {
a++; a++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *b;
while (size-- > 0) {
auto dst = op(*a, scalar);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
}
}
}; };
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
@@ -178,18 +154,6 @@ struct DefaultScalarVector {
b++; b++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *a;
while (size-- > 0) {
auto dst = op(scalar, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
b++;
}
}
}; };
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
@@ -206,204 +170,110 @@ struct DefaultVectorVector {
b++; b++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
while (size-- > 0) {
auto dst = op(*a, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
b++;
}
}
}; };
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) { void binary_op_dims(
const T* a_ptr = a.data<T>(); const T* a,
const T* b_ptr = b.data<T>(); const T* b,
U* dst = out.data<U>(); U* out,
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out,
Op op, Op op,
int stride) { const std::vector<int>& shape,
const T* a_ptr = a.data<T>(); const std::vector<size_t>& a_strides,
const T* b_ptr = b.data<T>(); const std::vector<size_t>& b_strides,
U* dst = out.data<U>(); const std::vector<size_t>& out_strides,
size_t a_idx = 0; int axis) {
size_t b_idx = 0; auto stride_a = a_strides[axis];
for (size_t i = 0; i < a.shape()[0]; i++) { auto stride_b = b_strides[axis];
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); auto stride_out = out_strides[axis];
a_idx += a.strides()[0]; auto N = shape[axis];
b_idx += b.strides()[0];
dst += stride;
}
}
template <typename T, typename U, typename Op> for (int i = 0; i < N; i++) {
void binary_op_dims2(const array& a, const array& b, array& out, Op op) { if constexpr (D > 1) {
const T* a_ptr = a.data<T>(); binary_op_dims<T, U, Op, D - 1, Strided>(
const T* b_ptr = b.data<T>(); a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
U* dst = out.data<U>(); } else {
size_t a_idx = 0; if constexpr (Strided) {
size_t b_idx = 0; op(a, b, out, stride_out);
size_t out_idx = 0; } else {
for (size_t i = 0; i < a.shape()[0]; ++i) { *out = op(*a, *b);
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
} }
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
} }
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; out += stride_out;
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; a += stride_a;
b += stride_b;
} }
} }
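The recursive binary_op_dims above replaces the hand-written dims1/dims2/dims3/dims4 loops: the D template parameter unrolls one loop level per dimension at compile time, and the innermost level either applies the op element-wise or, when Strided is set, hands a contiguous block of out_strides[axis] elements to a vectorized op. A self-contained sketch of the same idea without the MLX array machinery, assuming a plain elementwise op:

#include <cstddef>
#include <vector>

// Compile-time-unrolled N-D elementwise loop (sketch).
template <typename T, typename Op, int D>
void elementwise_dims(
    const T* a,
    const T* b,
    T* out,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  for (int i = 0; i < shape[axis]; ++i) {
    if constexpr (D > 1) {
      // Recurse into the next, faster-moving axis.
      elementwise_dims<T, Op, D - 1>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      *out = op(*a, *b);
    }
    a += a_strides[axis];
    b += b_strides[axis];
    out += out_strides[axis];
  }
}

Instantiating D = 1, 2, 3 covers the cases in the dispatch switch below; arrays with more dimensions iterate their leading axes and call the 3-D instantiation once per block.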
template <typename T, typename U, typename Op> template <typename T, typename U, bool Strided, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims( void binary_op_dispatch_dims(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
Op op, Op op,
int dim, int dim,
int stride) { const std::vector<int>& shape,
// Number of dimensions to loop over for vectorized ops const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) { switch (dim) {
case 1: case 1:
binary_op_dims1<T, U, Op>(a, b, out, op, stride); binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 2: case 2:
binary_op_dims2<T, U, Op>(a, b, out, op, stride); binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
} }
const T* a_ptr = a.data<T>(); ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
const T* b_ptr = b.data<T>(); ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
U* dst = out.data<U>(); size_t stride = out_strides[dim - 4];
for (size_t i = 0; i < out.size(); i += stride) { for (size_t elem = 0; elem < a.size(); elem += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides()); binary_op_dims<T, U, Op, 3, Strided>(
int b_idx = elem_to_loc(i, b.shape(), b.strides()); a_ptr + a_it.loc,
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride); b_ptr + b_it.loc,
dst += stride; out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
} }
} }
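The ContiguousIterator used in the fallback path above walks the leading (outer) axes of a strided layout, exposing the flat offset of the current block as .loc and advancing with .step(), so the inner 3-D template always starts from base + loc. A rough sketch of the idea; the actual class in the common backend utils likely differs in its details:

#include <cstddef>
#include <vector>

// Odometer-style iterator over the first `dims` axes of a strided array
// (sketch). loc is the linear offset of the current position.
template <typename StrideT>
struct StridedIteratorSketch {
  StridedIteratorSketch(
      const std::vector<int>& shape,
      const std::vector<StrideT>& strides,
      int dims)
      : shape_(shape.begin(), shape.begin() + dims),
        strides_(strides.begin(), strides.begin() + dims),
        pos_(dims, 0) {}

  void step() {
    // Increment the fastest-moving of the iterated axes, carrying outward.
    for (int i = static_cast<int>(pos_.size()) - 1; i >= 0; --i) {
      loc += strides_[i];
      if (++pos_[i] < shape_[i]) {
        return;
      }
      loc -= static_cast<StrideT>(shape_[i]) * strides_[i];
      pos_[i] = 0;
    }
  }

  StrideT loc{0};

 private:
  std::vector<int> shape_;
  std::vector<StrideT> strides_;
  std::vector<int> pos_;
};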
@@ -450,29 +320,33 @@ void binary_op(
} }
// General computation so let's try to optimize // General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after // Get the left-most dim such that the array is row contiguous after
auto& strides = out.strides(); auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
auto leftmost_rc_dim = [&strides](const array& arr) { int d = arr_strides.size() - 1;
int d = arr.ndim() - 1; for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
} }
return d + 1; return d + 1;
}; };
auto a_rc_dim = leftmost_rc_dim(a); auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b); auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after // Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) { auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
int d = arr.ndim() - 1; int d = arr_strides.size() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) { for (; d >= 0 && arr_strides[d] == 0; d--) {
} }
return d + 1; return d + 1;
}; };
auto a_s_dim = leftmost_s_dim(a); auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b); auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = out.ndim(); auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim; int dim = ndim;
@@ -494,27 +368,27 @@ void binary_op(
// Can be sure dim > 0 since otherwise we would have used one of the fully // Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not // contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity. // correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) { if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General; bopt = BinaryOpType::General;
dim = ndim; dim = ndim;
} else {
stride = strides[dim - 1];
} }
switch (bopt) { switch (bopt) {
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride); binary_op_dispatch_dims<T, U, true>(
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride); binary_op_dispatch_dims<T, U, true>(
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride); binary_op_dispatch_dims<T, U, true>(
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
break; break;
default: default:
binary_op_dispatch_dims<T, U>(a, b, out, op); binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break; break;
} }
} }
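Note that the general path now calls collapse_contiguous_dims on the shape together with all three stride vectors, merging any adjacent axes that are jointly contiguous before picking a dispatch strategy. A hand-worked illustration of what that buys (not code from the diff):

// a:   shape {2, 3, 4}, strides {12, 4, 1}   (row contiguous)
// b:   shape {2, 3, 4}, strides { 0, 4, 1}   (broadcast along axis 0)
// out: shape {2, 3, 4}, strides {12, 4, 1}
//
// Axes 1 and 2 are collapsible for all three arrays (4 == 1 * 4), but axis 0
// is not for b (0 != 1 * 12), so the call returns
//   new_shape   = {2, 12}
//   a_strides   = {12, 1}
//   b_strides   = { 0, 1}
//   out_strides = {12, 1}
// and the original 3-D problem is handled by the 2-D instantiation above.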
@@ -531,9 +405,9 @@ void binary_op(
// TODO: The following mess of constexpr evaluations can probably be achieved // TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler? // with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) { if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) { if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?) // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>( binary_op<T, T>(
a, a,
@@ -554,7 +428,8 @@ void binary_op(
DefaultVectorScalar<T, T, Op>(op), DefaultVectorScalar<T, T, Op>(op),
opvv); opvv);
} }
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opsv and opvv were UseDefaultBinaryOp // opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
a, a,
@@ -569,7 +444,8 @@ void binary_op(
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv); a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
} }
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) { } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp // opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
@@ -585,7 +461,8 @@ void binary_op(
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv); a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
} }
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opvv was UseDefaultBinaryOp // opvv was UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op)); a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
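A note on the if constexpr changes in this file: with a plain if, both branches must compile even though only one is taken, so passing the now-empty UseDefaultBinaryOp tag where a callable is expected could force an invalid instantiation; if constexpr discards the untaken branch at compile time, which is what lets the tag struct shrink to an empty body above. A minimal, self-contained illustration (the names here are illustrative, not from this file):

#include <type_traits>

struct UseDefault {}; // stand-in for the UseDefaultBinaryOp tag

template <typename Op>
int apply(Op op, int x) {
  if constexpr (std::is_same_v<Op, UseDefault>) {
    // Taken for the tag type; op(x) below is never instantiated for it.
    return x;
  } else {
    return op(x);
  }
}

int demo() {
  // Compiles because each call instantiates only its own branch.
  return apply(UseDefault{}, 3) + apply([](int v) { return v * 2; }, 3);
}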


@@ -9,168 +9,43 @@ namespace mlx::core {
namespace { namespace {
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, int D>
void binary_op_dims1( void binary_op_dims(
const array& a, const T* a,
const array& b, const T* b,
array& out_a, U* out_a,
array& out_b, U* out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out_a.size(); ++i) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[i] = dst.first;
dst_b[i] = dst.second;
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op, Op op,
int stride) { const std::vector<int>& shape,
const T* a_ptr = a.data<T>(); const std::vector<size_t>& a_strides,
const T* b_ptr = b.data<T>(); const std::vector<size_t>& b_strides,
U* dst_a = out_a.data<U>(); const std::vector<size_t>& out_strides,
U* dst_b = out_b.data<U>(); int axis) {
size_t a_idx = 0; auto stride_a = a_strides[axis];
size_t b_idx = 0; auto stride_b = b_strides[axis];
for (size_t i = 0; i < a.shape()[0]; i++) { auto stride_out = out_strides[axis];
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride); auto N = shape[axis];
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst_a += stride;
dst_b += stride;
}
}
template <typename T, typename U, typename Op> for (int i = 0; i < N; i++) {
void binary_op_dims2( if constexpr (D > 1) {
const array& a, binary_op_dims<T, U, Op, D - 1>(
const array& b, a,
array& out_a, b,
array& out_b, out_a,
Op op) { out_b,
const T* a_ptr = a.data<T>(); op,
const T* b_ptr = b.data<T>(); shape,
U* dst_a = out_a.data<U>(); a_strides,
U* dst_b = out_b.data<U>(); b_strides,
size_t a_idx = 0; out_strides,
size_t b_idx = 0; axis + 1);
size_t out_idx = 0; } else {
for (size_t i = 0; i < a.shape()[0]; ++i) { std::tie(*out_a, *out_b) = op(*a, *b);
for (size_t j = 0; j < a.shape()[1]; ++j) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[1];
b_idx += b.strides()[1];
} }
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; a += stride_a;
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; b += stride_b;
} out_a += stride_out;
} out_b += stride_out;
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst_a += stride;
dst_b += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
} }
} }
@@ -181,352 +56,160 @@ void binary_op_dispatch_dims(
array& out_a, array& out_a,
array& out_b, array& out_b,
Op op) { Op op) {
switch (out_a.ndim()) { auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1: case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op); binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 2: case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op); binary_op_dims<T, U, Op, 2>(
return; a_ptr,
case 3: b_ptr,
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op); out_a_ptr,
return; out_b_ptr,
case 4: op,
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op); shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
} }
const T* a_ptr = a.data<T>(); ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
const T* b_ptr = b.data<T>(); ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
U* dst_a = out_a.data<U>(); size_t stride = out_strides[ndim - 3];
U* dst_b = out_b.data<U>(); for (size_t elem = 0; elem < a.size(); elem += stride) {
for (size_t i = 0; i < out_a.size(); i++) { binary_op_dims<T, U, Op, 2>(
int a_idx = elem_to_loc(i, a.shape(), a.strides()); a_ptr + a_it.loc,
int b_idx = elem_to_loc(i, b.shape(), b.strides()); b_ptr + b_it.loc,
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]); out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
} }
} }
template <typename T, typename U, typename Op> template <typename T, typename U = T, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
dst_a += stride;
dst_b += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op( void binary_op(
const array& a, const array& a,
const array& b, const array& b,
array& out_a, std::vector<array>& outputs,
array& out_b, Op op) {
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt); set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt); set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once // The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return;
}
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) { if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) = std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
op(*a.data<T>(), *b.data<T>()); } else if (bopt == BinaryOpType::ScalarVector) {
return; for (size_t i = 0; i < b.size(); ++i) {
} std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
// The full computation is scalar vector so delegate to the op out_b_ptr++;
if (bopt == BinaryOpType::ScalarVector) { b_ptr++;
opsv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
out_a.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out_a.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
} }
return d + 1; } else if (bopt == BinaryOpType::VectorScalar) {
}; for (size_t i = 0; i < a.size(); ++i) {
auto a_rc_dim = leftmost_rc_dim(a); std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
auto b_rc_dim = leftmost_rc_dim(b); out_a_ptr++;
out_b_ptr++;
// Get the left-most dim such that the array is a broadcasted "scalar" after a_ptr++;
auto leftmost_s_dim = [](const array& arr) { }
int d = arr.ndim() - 1; } else { // VectorVector
for (; d >= 0 && arr.strides()[d] == 0; d--) { for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
} }
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out_a.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
break;
} }
} }
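The rewritten binary_op above drops the per-op vectorized paths entirely for the two-output case and just loops, letting std::tie scatter the pair returned by the op into both outputs. A sketch of the kind of functor it expects; the real pair-returning op (a DivMod-style primitive) lives elsewhere in the codebase and may differ:

#include <cmath>
#include <type_traits>
#include <utility>

// Illustrative pair-returning op: quotient and remainder in one pass.
struct DivModSketch {
  template <typename T>
  std::pair<T, T> operator()(T num, T den) const {
    if constexpr (std::is_integral_v<T>) {
      return {static_cast<T>(num / den), static_cast<T>(num % den)};
    } else {
      return {std::trunc(num / den), std::fmod(num, den)};
    }
  }
};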
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV> template <typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary( void binary(
const array& a, const array& a,
const array& b, const array& b,
std::vector<array>& outputs, std::vector<array>& outputs,
Ops... ops) { Op op) {
switch (outputs[0].dtype()) { switch (outputs[0].dtype()) {
case bool_: case bool_:
binary_op<bool>(a, b, outputs, ops...); binary_op<bool>(a, b, outputs, op);
break; break;
case uint8: case uint8:
binary_op<uint8_t>(a, b, outputs, ops...); binary_op<uint8_t>(a, b, outputs, op);
break; break;
case uint16: case uint16:
binary_op<uint16_t>(a, b, outputs, ops...); binary_op<uint16_t>(a, b, outputs, op);
break; break;
case uint32: case uint32:
binary_op<uint32_t>(a, b, outputs, ops...); binary_op<uint32_t>(a, b, outputs, op);
break; break;
case uint64: case uint64:
binary_op<uint64_t>(a, b, outputs, ops...); binary_op<uint64_t>(a, b, outputs, op);
break; break;
case int8: case int8:
binary_op<int8_t>(a, b, outputs, ops...); binary_op<int8_t>(a, b, outputs, op);
break; break;
case int16: case int16:
binary_op<int16_t>(a, b, outputs, ops...); binary_op<int16_t>(a, b, outputs, op);
break; break;
case int32: case int32:
binary_op<int32_t>(a, b, outputs, ops...); binary_op<int32_t>(a, b, outputs, op);
break; break;
case int64: case int64:
binary_op<int64_t>(a, b, outputs, ops...); binary_op<int64_t>(a, b, outputs, op);
break; break;
case float16: case float16:
binary_op<float16_t>(a, b, outputs, ops...); binary_op<float16_t>(a, b, outputs, op);
break; break;
case float32: case float32:
binary_op<float>(a, b, outputs, ops...); binary_op<float>(a, b, outputs, op);
break; break;
case bfloat16: case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, ops...); binary_op<bfloat16_t>(a, b, outputs, op);
break; break;
case complex64: case complex64:
binary_op<complex64_t>(a, b, outputs, ops...); binary_op<complex64_t>(a, b, outputs, op);
break; break;
} }
} }


@@ -2,46 +2,12 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) { void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that // Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric: // the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization. // Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N); int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite // TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how // because throwing an error would result in a crash. If we figure out how
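Relating to the TODO above: spotrf reports problems through its info argument rather than by throwing, so the result could be inspected once it is safe to surface errors from this code path. A small sketch (an assumption, not part of this change) of what that check could look like, using the standard LAPACK meaning of info:

#include <sstream>
#include <stdexcept>

// info < 0: argument -info was invalid.
// info > 0: the leading minor of order info is not positive definite,
//           so the factorization could not be completed.
inline void check_spotrf_info(int info) {
  if (info < 0) {
    std::ostringstream msg;
    msg << "[cholesky] argument " << -info << " to spotrf was invalid";
    throw std::invalid_argument(msg.str());
  } else if (info > 0) {
    std::ostringstream msg;
    msg << "[cholesky] matrix is not positive definite; factorization "
        << "failed at leading minor of order " << info;
    throw std::runtime_error(msg.str());
  }
}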


@@ -156,8 +156,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
} }
// Firstly let's collapse all the contiguous dimensions of the input // Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in); auto [shape, strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so // If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check. // let's check.


@@ -4,6 +4,8 @@
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h" #include "mlx/backend/common/compiled_preamble.h"
@@ -12,22 +14,7 @@
namespace mlx::core { namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in struct CompilerCache {
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib { struct DLib {
DLib(const std::string& libname) { DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW); lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -44,15 +31,41 @@ void* compile(
void* lib; void* lib;
}; };
// Statics to cache compiled libraries and functions // Statics to cache compiled libraries and functions
static std::list<DLib> libs; std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels; std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) { std::shared_mutex mtx;
return it->second; };
}
if (source_code.empty()) { static CompilerCache cache{};
return nullptr;
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
} }
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name; std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255 // Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,8 +103,8 @@ void* compile(
source_file.close(); source_file.close();
std::ostringstream build_command; std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << " -o " << shared_lib_path; << source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str(); std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str()); auto return_code = system(build_command_str.c_str());
if (return_code) { if (return_code) {
@@ -103,10 +116,10 @@ void* compile(
} }
// load library // load library
libs.emplace_back(shared_lib_path); cache.libs.emplace_back(shared_lib_path);
// Load function // Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str()); void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) { if (!fun) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function " msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -114,7 +127,7 @@ void* compile(
<< dlerror(); << dlerror();
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
kernels.insert({kernel_name, fun}); cache.kernels.insert({kernel_name, fun});
return fun; return fun;
} }
@@ -316,10 +329,7 @@ void Compiled::eval_cpu(
} }
// Get the function // Get the function
auto fn_ptr = compile(kernel_name); auto fn_ptr = compile(kernel_name, [&]() {
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel; std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl; kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl; kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +344,8 @@ void Compiled::eval_cpu(
ndim); ndim);
// Close extern "C" // Close extern "C"
kernel << "}" << std::endl; kernel << "}" << std::endl;
return kernel.str();
// Compile and get function pointer });
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs( compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false); inputs, outputs, inputs_, constant_ids_, contiguous, false);
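The compile() changes in this file replace the function-local statics with a CompilerCache guarded by a std::shared_mutex: readers take a shared_lock for the fast cache-hit path, a miss upgrades to a unique_lock and re-checks before building, and the source is only generated (via the source_builder callback) when a build is actually needed, so two threads racing on the same kernel compile it once. The same pattern in isolation, as a generic key/value sketch rather than kernel names and dlopen handles:

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Double-checked read/write cache (sketch of the pattern used by compile()).
class OnceCache {
 public:
  int get_or_build(const std::string& key,
                   const std::function<int()>& build) {
    {
      std::shared_lock lock(mtx_); // cheap, concurrent fast path
      if (auto it = map_.find(key); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock lock(mtx_); // exclusive slow path
    if (auto it = map_.find(key); it != map_.end()) {
      return it->second; // another thread built it in the meantime
    }
    int value = build(); // expensive work done exactly once per key
    map_.emplace(key, value);
    return value;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, int> map_;
};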


@@ -3,13 +3,8 @@
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -684,6 +679,32 @@ void dispatch_slow_conv_3D(
// Explicit gemm conv // Explicit gemm conv
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T>
void flip_spatial_dims_inplace(array& wt) {
T* x = wt.data<T>();
size_t out_channels = wt.shape(0);
size_t in_channels = wt.shape(-1);
// Calculate the total size of the spatial dimensions
int spatial_size = 1;
for (int d = 1; d < wt.ndim() - 1; ++d) {
spatial_size *= wt.shape(d);
}
for (size_t i = 0; i < out_channels; i++) {
T* top = x + i * spatial_size * in_channels;
T* bottom =
x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
for (size_t j = 0; j < spatial_size / 2; j++) {
for (size_t k = 0; k < in_channels; k++) {
std::swap(top[k], bottom[k]);
}
top += in_channels;
bottom -= in_channels;
}
}
}
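flip_spatial_dims_inplace above reverses the order of the spatial taps of each output-channel block in place while keeping the input-channel values of each tap together; that is what turns the cross-correlation computed by the GEMM path below into a flipped (true) convolution when flip is requested. A hand-worked illustration for a weight of shape (out_channels = 1, taps = 3, in_channels = 2):

// before: w = { t0c0, t0c1,  t1c0, t1c1,  t2c0, t2c1 }
// after : w = { t2c0, t2c1,  t1c0, t1c1,  t0c0, t0c1 }
// Taps t0..t2 swap end-for-end (spatial_size / 2 swaps of in_channels
// values each); the channel pair inside every tap keeps its order.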
void explicit_gemm_conv_1D_cpu( void explicit_gemm_conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
@@ -910,7 +931,8 @@ void explicit_gemm_conv_ND_cpu(
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) { const std::vector<int>& wt_dilation,
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>( const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
@@ -1000,6 +1022,14 @@ void explicit_gemm_conv_ND_cpu(
copy(wt, gemm_wt, ctype); copy(wt, gemm_wt, ctype);
} }
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy(gemm_wt, gemm_wt_, CopyType::Vector);
flip_spatial_dims_inplace<float>(gemm_wt_);
gemm_wt = gemm_wt_;
}
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
@@ -1042,10 +1072,15 @@ void conv_1D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu( return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation); in, wt, out, padding, wt_strides, wt_dilation);
} }
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_1D( return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
@@ -1060,6 +1095,13 @@ void conv_2D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_2D( return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} }
@@ -1073,6 +1115,14 @@ void conv_3D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_3D( return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} }
@@ -1125,7 +1175,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
else { else {
std::ostringstream msg; std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports" msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2 << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions"; << " spatial dimensions";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }


@@ -26,292 +26,117 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
} }
template <typename SrcT, typename DstT, typename stride_t> template <typename SrcT, typename DstT, typename StrideT, int D>
void copy_general_dim1( inline void copy_dims(
const array& src, const SrcT* src,
array& dst, DstT* dst,
const std::vector<int>& data_shape, const std::vector<int>& shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
int64_t i_offset) { const std::vector<StrideT>& o_strides,
const SrcT* src_ptr = src.data<SrcT>(); int axis) {
DstT* dst_ptr = dst.data<DstT>(); auto stride_src = i_strides[axis];
stride_t src_idx = i_offset; auto stride_dst = o_strides[axis];
stride_t dst_idx = 0; auto N = shape[axis];
for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
}
}
template <typename SrcT, typename DstT> for (int i = 0; i < N; i++) {
inline void copy_general_dim1(const array& src, array& dst) { if constexpr (D > 1) {
return copy_general_dim1<SrcT, DstT, size_t>( copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, src.shape(), src.strides(), 0); src, dst, shape, i_strides, o_strides, axis + 1);
} } else {
*dst = static_cast<DstT>(*src);
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
} }
src_idx += i_strides[0] - i_strides[1] * data_shape[1]; src += stride_src;
dst += stride_dst;
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename StrideT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
if constexpr (D > 1) {
int axis = data_shape.size() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = data_shape.size() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
dst_ptr += stride_dst;
}
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<StrideT>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset) { int64_t o_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims( if (data_shape.empty()) {
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides}); auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
switch (new_shape.size()) { auto dst_ptr = dst.data<DstT>() + o_offset;
case 1: *dst_ptr = val;
copy_general_general_dims<SrcT, DstT, stride_t, 1>( return;
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
} }
auto [shape, strides] = collapse_contiguous_dims(
int size = std::accumulate( data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>()); auto src_ptr = src.data<SrcT>() + i_offset;
for (int i = 0; i < src.size(); i += size) { auto dst_ptr = dst.data<DstT>() + o_offset;
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]); int ndim = shape.size();
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]); if (ndim == 1) {
copy_general_general_dims<SrcT, DstT, stride_t, 5>( copy_dims<SrcT, DstT, StrideT, 1>(
src, src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
dst, return;
new_shape, } else if (ndim == 2) {
new_strides[0], copy_dims<SrcT, DstT, StrideT, 2>(
new_strides[1], src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
src_offset, return;
dst_offset); } else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>( copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
} }
template <typename SrcT, typename DstT, typename StrideT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
0,
0);
}
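copy_general above now delegates to copy_general_general by synthesizing contiguous output strides with make_contiguous_strides instead of carrying its own loop nest. That helper is presumably the usual reverse cumulative product of the shape; a sketch under that assumption:

#include <vector>

// Row-major (C-order) strides for a shape (sketch of a
// make_contiguous_strides-style helper).
template <typename StrideT>
std::vector<StrideT> contiguous_strides_sketch(const std::vector<int>& shape) {
  std::vector<StrideT> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * static_cast<StrideT>(shape[i + 1]);
  }
  return strides;
}

// e.g. shape {2, 3, 4} -> strides {12, 4, 1}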
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) { switch (ctype) {
@@ -326,6 +151,7 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...); copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
} }
} }
@@ -426,7 +252,7 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) { void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype); copy_inplace_dispatch(src, dst, ctype);
} }
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype) {
@@ -456,20 +282,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype);
} }
template <typename stride_t> template <typename StrideT>
void copy_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<StrideT>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype) {
switch (ctype) { switch (ctype) {
case CopyType::General: case CopyType::General:
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
return copy_inplace_dispatch( copy_inplace_dispatch(
src, src,
dst, dst,
ctype, ctype,
@@ -478,10 +304,10 @@ void copy_inplace(
o_strides, o_strides,
i_offset, i_offset,
o_offset); o_offset);
break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::Vector: case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype); copy_inplace_dispatch(src, dst, ctype);
} }
} }


@@ -1,14 +1,10 @@
 // Copyright © 2023-2024 Apple Inc.
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
 #include <cstring>
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Eigh)
 namespace {

mlx/backend/common/eigh.cpp (new file, +117 lines)

@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core
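The eval above uses LAPACK's usual two-call pattern: a workspace query with lwork = liwork = -1, followed by the real ssyevd call with the reported buffer sizes. A hedged standalone sketch of that idiom follows; symmetric_eig and the explicit ssyevd_ prototype are illustrative, not part of MLX.

#include <vector>

// Fortran LAPACK symbol. The file above reaches it through MLX_LAPACK_FUNC;
// it is declared here only to keep the sketch self-contained (some LAPACK
// builds also append hidden string-length arguments, which this omits).
extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

// Eigendecomposition of one symmetric N x N float matrix stored in 'a'.
// On return, 'w' holds the eigenvalues and 'a' the eigenvectors (jobz = 'V').
void symmetric_eig(float* a, float* w, int N) {
  char jobz = 'V';
  char uplo = 'L';
  int info = 0;
  // 1) Workspace query: lwork = liwork = -1 asks LAPACK for the sizes it needs.
  int lwork = -1, liwork = -1;
  float work_query = 0.0f;
  int iwork_query = 0;
  ssyevd_(&jobz, &uplo, &N, a, &N, w, &work_query, &lwork, &iwork_query,
          &liwork, &info);
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  // 2) Actual decomposition with the requested buffers (error handling omitted).
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &N, a, &N, w, work.data(), &lwork, iwork.data(),
          &liwork, &info);
}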


@@ -1,5 +1,4 @@
 // Copyright © 2023 Apple Inc.
 #include <algorithm>
 #include <cassert>
 #include <cmath>
@@ -81,11 +80,18 @@ void gather(
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0; size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii]; auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]); auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax)); offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]); src_idx += (idx_val * src.strides()[ax]);
@@ -99,9 +105,10 @@ void gather(
out_idx += slice_size; out_idx += slice_size;
} else { } else {
for (int jj = 0; jj < slice_size; jj++) { for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; src_it.step();
} }
src_it.reset();
} }
} }
} }
@@ -223,21 +230,29 @@ void scatter(
update_size *= us; update_size *= us;
} }
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < nind; ++j) { for (int j = 0; j < nind; ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]); auto idx_loc = its[j].loc;
its[j].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax)); offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]); out_offset += (idx_val * out.strides()[ax]);
} }
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates); op(updates.data<InT>()[update_it.loc],
auto out_loc = elem_to_loc(j, update_shape, out.strides()); out.data<InT>() + out_offset + out_it.loc);
op(updates.data<InT>()[update_loc], update_it.step();
out.data<InT>() + out_offset + out_loc); out_it.step();
} }
out_it.reset();
update_it.reset();
} }
} }
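The ContiguousIterator used in gather and scatter above replaces per-element elem_to_loc calls: it keeps a multi-dimensional position and adjusts the running offset by just the stride of the axis that advanced. A rough standalone sketch of that stepping logic (StridedWalker is a stand-in name, not the MLX class):

#include <cstddef>
#include <utility>
#include <vector>

// Row-major walk over an array with arbitrary strides. 'loc' is always the
// linear memory offset of the current element; step() only touches the axes
// that actually change, instead of recomputing the offset from scratch.
struct StridedWalker {
  std::vector<int> shape;
  std::vector<size_t> strides;
  std::vector<int> pos;
  size_t loc{0};

  StridedWalker(std::vector<int> s, std::vector<size_t> st)
      : shape(std::move(s)), strides(std::move(st)), pos(shape.size(), 0) {}

  void step() {
    if (shape.empty()) {
      return;
    }
    int i = static_cast<int>(shape.size()) - 1;
    // Carry: axes that wrapped go back to 0 and their contribution is undone.
    while (i > 0 && pos[i] == shape[i] - 1) {
      pos[i] = 0;
      loc -= static_cast<size_t>(shape[i] - 1) * strides[i];
      --i;
    }
    ++pos[i];
    loc += strides[i];
  }
};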


@@ -2,39 +2,19 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) { int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info; int info;
MLX_LAPACK_FUNC(strtri)
#ifdef LAPACK_FORTRAN_STRLEN_END (
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
/* uplo = */ &uplo, /* uplo = */ &uplo,
/* diag = */ &diag, /* diag = */ &diag,
/* N = */ &N, /* N = */ &N,
/* a = */ matrix, /* a = */ matrix,
/* lda = */ &N, /* lda = */ &N,
/* info = */ &info); /* info = */ &info);
#endif
return info; return info;
} }


@@ -1,10 +1,11 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 #pragma once
 #ifdef ACCELERATE_NEW_LAPACK
 #include <Accelerate/Accelerate.h>
 #else
+#include <cblas.h>
 #include <lapack.h>
 #endif


@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
 #include <cstdint>
 #include <vector>
 EOM
+  CC_FLAGS=""
+else
+  CC_FLAGS="-std=c++17"
 fi
-CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
+CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
 cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() { const char* get_kernel_preamble() {


@@ -1,15 +1,10 @@
 // Copyright © 2024 Apple Inc.
-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
 #include <cstring>
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"


@@ -295,6 +295,13 @@ struct Floor {
   }
 };
+struct Imag {
+  template <typename T>
+  T operator()(T x) {
+    return std::imag(x);
+  }
+};
 struct Log {
   template <typename T>
   T operator()(T x) {
@@ -337,6 +344,13 @@ struct Negative {
   }
 };
+struct Real {
+  template <typename T>
+  T operator()(T x) {
+    return std::real(x);
+  }
+};
 struct Round {
   template <typename T>
   T operator()(T x) {


@@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
   copy(in, out, ctype);
 }
+void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
+}
 void Log::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
   }
 }
+void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
+  unary_op<complex64_t, float>(inputs[0], out, detail::Real());
+}
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -406,16 +414,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape()); copy_inplace(in, out, CopyType::General);
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
@@ -612,11 +611,18 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
strides[i] /= obytes; strides[i] /= obytes;
} }
out.copy_shared_buffer( out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * obytes / ibytes); in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else { } else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {}); auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General); if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags(); auto flags = out.flags();
flags.contiguous = true; flags.contiguous = true;


@@ -2,14 +2,9 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>


@@ -201,55 +201,61 @@ void _qmm_dispatch(
int group_size, int group_size,
bool transposed_w) { bool transposed_w) {
int K = x.shape(-1); int K = x.shape(-1);
int M = x.size() / K; int M = x.shape(-2);
int N = out.shape(-1); int N = out.shape(-1);
switch (x.dtype()) { int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
case float32: int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
_qmm_dispatch_typed<float>(
out.data<float>(), int batch_size = x.size() / x.shape(-1) / x.shape(-2);
x.data<float>(), for (int i = 0; i < batch_size; i++) {
w.data<uint32_t>(), switch (x.dtype()) {
scales.data<float>(), case float32:
biases.data<float>(), _qmm_dispatch_typed<float>(
M, out.data<float>() + i * M * N,
N, x.data<float>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<float>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<float>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
case float16: K,
_qmm_dispatch_typed<float16_t>( bits,
out.data<float16_t>(), group_size,
x.data<float16_t>(), transposed_w);
w.data<uint32_t>(), break;
scales.data<float16_t>(), case float16:
biases.data<float16_t>(), _qmm_dispatch_typed<float16_t>(
M, out.data<float16_t>() + i * M * N,
N, x.data<float16_t>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
case bfloat16: K,
_qmm_dispatch_typed<bfloat16_t>( bits,
out.data<bfloat16_t>(), group_size,
x.data<bfloat16_t>(), transposed_w);
w.data<uint32_t>(), break;
scales.data<bfloat16_t>(), case bfloat16:
biases.data<bfloat16_t>(), _qmm_dispatch_typed<bfloat16_t>(
M, out.data<bfloat16_t>() + i * M * N,
N, x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
default: K,
throw std::invalid_argument( bits,
"[quantized_matmul] only floating types are supported"); group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
} }
} }
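For intuition about the new batching, here is a plain floating-point analogue of the loop structure above (a hedged sketch, not the quantized kernel): the output advances by M * N per batch element, while a 2-D weight shared across the batch keeps a zero per-batch offset, which is what w_els == 0 and g_els == 0 express.

#include <cstddef>

// x: (B, M, K) row-contiguous, w: a single (K, N) matrix shared by the whole
// batch, out: (B, M, N).
void batched_matmul(
    const float* x, const float* w, float* out, int B, int M, int N, int K) {
  for (int b = 0; b < B; ++b) {
    const float* xb = x + static_cast<size_t>(b) * M * K; // per-batch input
    float* ob = out + static_cast<size_t>(b) * M * N;     // per-batch output
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += xb[m * K + k] * w[k * N + n];
        }
        ob[m * N + n] = acc;
      }
    }
  }
}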


@@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
@@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
int axis_size = out.shape(axis); int axis_size = out.shape(axis);
// Perform sorting in place // Perform sorting in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); T* data_ptr = out.data<T>() + src_it.loc;
T* data_ptr = out.data<T>() + loc;
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed); std::stable_sort(st, ed);
src_it.step();
} }
} }
@@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
// Perform sorting // Perform sorting
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides); const T* data_ptr = in.data<T>() + in_it.loc;
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides); IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
const T* data_ptr = in.data<T>() + in_loc; in_it.step();
IdxT* idx_ptr = out.data<IdxT>() + out_loc; out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size); StridedIterator ed_(idx_ptr, out_stride, axis_size);
@@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape(); auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
@@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place // Perform partition in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); T* data_ptr = out.data<T>() + src_it.loc;
T* data_ptr = out.data<T>() + loc; src_it.step();
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
@@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape(); auto in_remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis); in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides(); auto in_remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis); in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis]; auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition // Perform partition
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); const T* data_ptr = in.data<T>() + in_it.loc;
const T* data_ptr = in.data<T>() + loc; IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + loc; in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, axis_stride, 0); StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size); StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota // Initialize with iota
std::iota(st_, ed_, IdxT(0)); std::iota(st_, ed_, IdxT(0));
// Sort according to vals // Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0); StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, axis_stride, kth); StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, axis_stride, axis_size); StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * axis_stride]; auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }


@@ -2,7 +2,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h" #include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {


@@ -41,7 +41,7 @@ void set_ternary_op_output_data(
     TernaryOpType topt,
     bool donate_with_move = false) {
   auto maybe_donate = [&out, donate_with_move](const array& x) {
-    if (x.is_donatable() && x.itemsize() == out.itemsize()) {
+    if (is_donatable(x, out)) {
       if (donate_with_move) {
         out.move_shared_buffer(x);
       } else {
@@ -71,128 +71,46 @@ void set_ternary_op_output_data(
       break;
   }
 }
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
template <typename T1, typename T2, typename T3, typename U, typename Op> for (int i = 0; i < N; i++) {
void ternary_op_dims1( if constexpr (D > 1) {
const array& a, ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
const array& b, a,
const array& c, b,
array& out, c,
Op op) { out,
const T1* a_ptr = a.data<T1>(); op,
const T2* b_ptr = b.data<T2>(); shape,
const T3* c_ptr = c.data<T3>(); a_strides,
b_strides,
U* dst = out.data<U>(); c_strides,
size_t a_idx = 0; out_strides,
size_t b_idx = 0; axis + 1);
size_t c_idx = 0; } else {
for (size_t i = 0; i < out.size(); ++i) { *out = op(*a, *b, *c);
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
} }
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; a += stride_a;
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; b += stride_b;
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; c += stride_c;
} out += stride_out;
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
} }
} }
@@ -203,30 +121,69 @@ void ternary_op_dispatch_dims(
const array& c, const array& c,
array& out, array& out,
Op op) { Op op) {
switch (out.ndim()) { auto [shape, strides] = collapse_contiguous_dims(
case 1: a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op); const auto& a_strides = strides[0];
return; const auto& b_strides = strides[1];
case 2: const auto& c_strides = strides[2];
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op); const auto& out_strides = strides[3];
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
const T1* a_ptr = a.data<T1>(); const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>(); const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>(); const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>(); U* out_ptr = out.data<T3>();
for (size_t i = 0; i < out.size(); i++) { int ndim = shape.size();
int a_idx = elem_to_loc(i, a.shape(), a.strides()); switch (ndim) {
int b_idx = elem_to_loc(i, b.shape(), b.strides()); case 1:
int c_idx = elem_to_loc(i, c.shape(), c.strides()); ternary_op_dims<T1, T2, T3, U, Op, 1>(
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
} }
} }
@@ -243,10 +200,21 @@ void ternary_op(
// The full computation is scalar-scalar-scalar so we call the base op once. // The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) { if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>()); *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return; } else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
} }
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
} }
} // namespace } // namespace
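The ternary_op_dims template above unrolls the trailing loops through a compile-time depth parameter D and advances raw pointers by per-axis strides instead of recomputing locations. The same pattern reduced to two inputs, as an illustrative sketch (binary_op_dims here is not taken from the diff):

#include <cstddef>
#include <vector>

// D counts how many trailing dimensions remain, starting at 'axis'. Each level
// loops over one axis and advances raw pointers by that axis' stride; the
// innermost level (D == 1) applies the elementwise op.
template <typename T, typename U, typename Op, int D>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];
  for (int i = 0; i < N; ++i) {
    if constexpr (D > 1) {
      binary_op_dims<T, U, Op, D - 1>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      *out = op(*a, *b);
    }
    a += stride_a;
    b += stride_b;
    out += stride_out;
  }
}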


@@ -12,7 +12,7 @@ namespace mlx::core {
 namespace {
 void set_unary_output_data(const array& in, array& out) {
-  if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+  if (is_donatable(in, out)) {
     out.copy_shared_buffer(in);
   } else {
     auto size = in.data_size();
@@ -24,22 +24,36 @@ void set_unary_output_data(const array& in, array& out) {
   }
 }
-template <typename T, typename Op>
+template <typename T, typename U = T, typename Op>
+void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
+  for (size_t i = 0; i < shape; i += 1) {
+    out[i] = op(*a);
+    a += stride;
+  }
+}
+
+template <typename T, typename U = T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
   if (a.flags().contiguous) {
     set_unary_output_data(a, out);
-    T* dst = out.data<T>();
+    U* dst = out.data<U>();
     for (size_t i = 0; i < a.data_size(); ++i) {
       dst[i] = op(a_ptr[i]);
     }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    T* dst = out.data<T>();
-    for (size_t i = 0; i < out.size(); ++i) {
-      // TODO this is super inefficient, need to fix.
-      int a_idx = elem_to_loc(i, a.shape(), a.strides());
-      dst[i] = op(a_ptr[a_idx]);
+    U* dst = out.data<U>();
+    size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
+    size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
+    if (a.ndim() <= 1) {
+      unary_op(a_ptr, dst, op, shape, stride);
+      return;
+    }
+    ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
+    for (size_t elem = 0; elem < a.size(); elem += shape) {
+      unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
+      it.step();
     }
   }
 }
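A concrete illustration of the new non-contiguous path (the array below is hypothetical):

// Example: a float array viewed with shape (4, 8) and strides (1, 4), i.e. the
// transpose of a row-contiguous (8, 4) buffer. Here shape = a.shape(-1) = 8 and
// stride = a.strides(-1) = 4, so the loop makes 4 calls to
// unary_op(a_ptr + it.loc, dst + elem, op, 8, 4), one per outer row, with the
// ContiguousIterator supplying row offsets it.loc = 0, 1, 2, 3 while the output
// is written contiguously at dst + 0, 8, 16, 24.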


@@ -0,0 +1,138 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
}
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<StrideT>& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
if (shape[i] != 1) {
to_collapse.push_back(i);
}
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
};
if (i == to_collapse.size()) {
break;
}
int current_shape = shape[to_collapse[i]];
int k = i;
while (to_collapse[++k] != -1) {
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
}
if (!shape.empty() && out_shape.empty()) {
out_shape.push_back(1);
for (auto& out_stride : out_strides) {
out_stride.push_back(0);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (shape[i] == 1) {
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core
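A couple of worked examples of the single-array collapse (the shapes and strides are illustrative):

// Worked examples for the single-array overload:
//   shape {2, 3, 4}, strides {12, 4, 1}  ->  shape {24},   strides {1}
//     (row-contiguous: every axis satisfies strides[i] * shape[i] == strides[i - 1])
//   shape {4, 2, 3}, strides {1, 12, 4}  ->  shape {4, 6}, strides {1, 4}
//     (only the last two axes can be merged)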


@@ -8,12 +8,12 @@
namespace mlx::core { namespace mlx::core {
template <typename stride_t> template <typename StrideT>
inline stride_t elem_to_loc( inline StrideT elem_to_loc(
int elem, int elem,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<stride_t>& strides) { const std::vector<StrideT>& strides) {
stride_t loc = 0; StrideT loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) { for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]); auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i]; loc += q_and_r.rem * strides[i];
@@ -29,9 +29,9 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides()); return elem_to_loc(elem, a.shape(), a.strides());
} }
template <typename stride_t> template <typename StrideT>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) { std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1); std::vector<StrideT> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) { for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i]; strides[i - 1] = strides[i] * shape[i];
} }
@@ -44,58 +44,26 @@ std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
// //
// When multiple arrays are passed they should all have the same shape. The // When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned. // collapsed axes are also the same so one shape is returned.
template <typename stride_t> std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims( collapse_contiguous_dims(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) { const std::vector<std::vector<int64_t>>& strides,
// Make a vector that has axes separated with -1. Collapse all axes between int64_t size_cap = std::numeric_limits<int32_t>::max());
// -1. std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
std::vector<int> to_collapse; collapse_contiguous_dims(
if (shape.size() > 0) { const std::vector<int>& shape,
to_collapse.push_back(0); const std::vector<std::vector<size_t>>& strides,
for (int i = 1; i < shape.size(); i++) { size_t size_cap = std::numeric_limits<int32_t>::max());
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>> inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) { collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides; std::vector<std::vector<size_t>> strides;
for (auto& x : xs) { for (auto& x : xs) {
strides.emplace_back(x.strides()); strides.emplace_back(x.strides());
} }
return collapse_contiguous_dims(xs[0].shape(), strides); return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
} }
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>> template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
@@ -105,36 +73,84 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
} }
// The single array version of the above. // The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>> std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
collapse_contiguous_dims(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<size_t>& strides) { const std::vector<int64_t>& strides,
std::vector<int> collapsed_shape; int64_t size_cap = std::numeric_limits<int32_t>::max());
std::vector<size_t> collapsed_strides; std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());
if (shape.size() > 0) { template <typename StrideT>
collapsed_shape.push_back(shape[0]); struct ContiguousIterator {
collapsed_strides.push_back(strides[0]); inline void step() {
for (int i = 1; i < shape.size(); i++) { int dims = shape_.size();
if (strides[i] * shape[i] != collapsed_strides.back() || if (dims == 0) {
collapsed_shape.back() * static_cast<size_t>(shape[i]) > return;
std::numeric_limits<int>::max()) { }
collapsed_shape.push_back(shape[i]); int i = dims - 1;
collapsed_strides.push_back(strides[i]); while (pos_[i] == (shape_[i] - 1) && i > 0) {
} else { pos_[i] = 0;
collapsed_shape.back() *= shape[i]; loc -= (shape_[i] - 1) * strides_[i];
collapsed_strides.back() = strides[i]; i--;
} }
pos_[i]++;
loc += strides_[i];
}
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
loc += q_and_r.rem * strides_[i];
pos_[i] = q_and_r.rem;
n = q_and_r.quot;
} }
} }
return std::make_tuple(collapsed_shape, collapsed_strides); void reset() {
} loc = 0;
std::fill(pos_.begin(), pos_.end(), 0);
}
template <typename stride_t> ContiguousIterator() {};
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
StrideT loc{0};
private:
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
};
template <typename StrideT>
inline auto check_contiguity( inline auto check_contiguity(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<stride_t>& strides) { const std::vector<StrideT>& strides) {
size_t no_broadcast_data_size = 1; size_t no_broadcast_data_size = 1;
size_t f_stride = 1; size_t f_stride = 1;
size_t b_stride = 1; size_t b_stride = 1;
@@ -155,4 +171,11 @@ inline auto check_contiguity(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous); no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
} }
inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
}
} // namespace mlx::core } // namespace mlx::core
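A short illustration of ContiguousIterator::seek, which the scatter path above uses to jump to the start of each update (the shape and strides are made up for the example):

// seek() maps a row-major element index to a memory offset by peeling off one
// axis at a time. Illustrative values (strides chosen so that the constructor
// does not collapse any axes):
//   shape {2, 3, 4}, strides {30, 9, 2}, seek(7)
//     axis 2: 7 = 1 * 4 + 3  ->  loc += 3 * 2 = 6
//     axis 1: 1 = 0 * 3 + 1  ->  loc += 1 * 9 = 9
//     axis 0: 0              ->  loc += 0
//   => pos_ = {0, 1, 3}, loc = 15 (the 8th element in logical order)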


@@ -1,99 +1,56 @@
 function(make_jit_source SRC_FILE)
-  # This function takes a metal header file,
-  # runs the C preprocessor on it, and makes
-  # the processed contents available as a string in a C++ function
+  # This function takes a metal header file, runs the C preprocessor on it,
+  # and makes the processed contents available as a string in a C++ function
   # mlx::core::metal::${SRC_NAME}()
   #
-  # To use the function, declare it in jit/includes.h and
-  # include jit/includes.h.
+  # To use the function, declare it in jit/includes.h and include
+  # jit/includes.h.
   #
-  # Additional arguments to this function are treated as dependencies
-  # in the CMake build system.
+  # Additional arguments to this function are treated as dependencies in the
+  # CMake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME) get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command( add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash COMMAND
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${CMAKE_C_COMPILER} ${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
${PROJECT_SOURCE_DIR} DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
${SRC_FILE}
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME}) add_dependencies(mlx ${SRC_NAME})
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source( make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
utils make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
make_jit_source( make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(scatter) make_jit_source(scatter)
make_jit_source(gather) make_jit_source(gather)
make_jit_source(hadamard) make_jit_source(hadamard)
if (MLX_METAL_JIT) if(MLX_METAL_JIT)
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange) make_jit_source(arange)
make_jit_source(copy) make_jit_source(copy)
make_jit_source(unary) make_jit_source(unary)
make_jit_source(binary) make_jit_source(binary)
make_jit_source(binary_two) make_jit_source(binary_two)
make_jit_source( make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary) make_jit_source(ternary)
make_jit_source(softmax) make_jit_source(softmax)
make_jit_source(scan) make_jit_source(scan)
make_jit_source(sort) make_jit_source(sort)
make_jit_source( make_jit_source(
reduce reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
kernels/reduction/reduce_all.h kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
kernels/reduction/reduce_init.h
)
make_jit_source( make_jit_source(
steel/gemm/gemm steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h
kernels/steel/utils.h kernels/steel/gemm/mma.h kernels/steel/gemm/params.h
kernels/steel/gemm/loader.h kernels/steel/gemm/transforms.h)
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source( make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
@@ -104,63 +61,52 @@ if (MLX_METAL_JIT)
kernels/steel/conv/params.h kernels/steel/conv/params.h
kernels/steel/conv/loader.h kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h kernels/steel/conv/loaders/loader_channel_n.h)
) make_jit_source(steel/conv/kernels/steel_conv)
make_jit_source( make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
steel/conv/kernels/steel_conv kernels/steel/conv/loaders/loader_general.h)
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized) make_jit_source(quantized)
make_jit_source(gemv_masked) make_jit_source(gemv_masked)
else() else()
target_sources( target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif() endif()
target_sources( target_sources(
mlx mlx
PRIVATE PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
)
if (NOT MLX_METAL_PATH) if(NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif() endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions( target_compile_definitions(mlx
mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")


@@ -2,6 +2,7 @@
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]); auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ = block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()), static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_); block_limit_);
max_pool_size_ = block_limit_; max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
} }
size_t MetalAllocator::set_cache_limit(size_t limit) { size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit; return limit;
}; };
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers // Metal doesn't like empty buffers
if (size == 0) { if (size == 0) {
@@ -205,7 +215,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed // Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared; size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked; res_opt |= MTL::ResourceHazardTrackingModeUntracked;
lk.unlock(); lk.unlock();
buf = device_->newBuffer(size, res_opt); buf = device_->newBuffer(size, res_opt);
lk.lock(); lk.lock();
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
@@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr()); auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length(); active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
@@ -241,16 +254,14 @@ void MetalAllocator::free(Buffer buffer) {
} }
} }
size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
 MetalAllocator& allocator() {
-  // By creating the |allocator_| on heap, the destructor of MetalAllocator will
-  // not be called on exit and all the buffers will be leaked. This is necessary
-  // because releasing buffers can take more than 30sec when the program holds a
-  // lot of RAM (for example inferencing a LLM), and it would feel frozen to
-  // users when exiting.
-  // TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
-  // when applying this pattern to more places, or when introducing sanitizers
-  // to MLX.
-  // https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
+  // By creating the |allocator_| on heap, the destructor of MetalAllocator
+  // will not be called on exit and buffers in the cache will be leaked. This
+  // can save some time at program exit.
   static MetalAllocator* allocator_ = new MetalAllocator;
   return *allocator_;
 }
@@ -261,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed); return allocator().set_memory_limit(limit, relaxed);
} }
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() { size_t get_active_memory() {
return allocator().get_active_memory(); return allocator().get_active_memory();
} }


@@ -8,6 +8,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal { namespace mlx::core::metal {
@@ -56,6 +57,7 @@ class MetalAllocator : public allocator::Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };
@@ -71,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
}; };
size_t set_cache_limit(size_t limit); size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed); size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache(); void clear_cache();
private: private:
@@ -81,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator // Caching allocator
BufferCache buffer_cache_; BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats // Allocation stats
size_t block_limit_; size_t block_limit_;
size_t gc_limit_; size_t gc_limit_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
size_t max_pool_size_; size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true}; bool relaxed_{true};
std::mutex mutex_; std::mutex mutex_;

View File

@@ -19,14 +19,13 @@
namespace mlx::core { namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name( std::string get_kernel_name(
BinaryOpType bopt, BinaryOpType bopt,
const std::string& op, const std::string& op,
const array& a, const array& a,
bool use_2d, bool use_2d,
int ndim) { int ndim,
int work_per_thread) {
std::ostringstream kname; std::ostringstream kname;
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
@@ -43,14 +42,17 @@ std::string get_kernel_name(
break; break;
case BinaryOpType::General: case BinaryOpType::General:
kname << "g"; kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) { if (ndim <= 3) {
kname << ndim; kname << ndim;
} else { } else {
kname << "n"; kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
} }
break; break;
} }
kname << op << type_to_name(a); kname << "_" << op << type_to_name(a);
return kname.str(); return kname.str();
} }
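With the work_per_thread suffix and the new underscore separator, the general-case kernel names now encode how much work each thread does. Roughly, taking "add" for the op and assuming type_to_name(a) yields a dtype suffix such as "float32":

// ndim = 2 (specialized)          ->  "g2_addfloat32"
// ndim = 5, work_per_thread = 4   ->  "gn4_addfloat32"
// ndim = 5, work_per_thread = 1   ->  "gn_addfloat32"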
@@ -69,52 +71,67 @@ void binary_op_gpu_inplace(
} }
// Try to collapse contiguous dims // Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto maybe_collapse = [bopt, &a, &b, &out]() {
auto& strides_a = strides[0]; if (bopt == BinaryOpType::General) {
auto& strides_b = strides[1]; auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_out = strides[2]; return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else {
std::vector<size_t> e;
return std::make_tuple(std::vector<int>{}, e, e, e);
}
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX; bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size()); auto ndim = shape.size();
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = auto kernel = outputs.size() == 2
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op); ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output // - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated // - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output // otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr; bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr; bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0); int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array( compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1); donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_output_array(outputs[0], 2); compute_encoder.set_output_array(outputs[0], arg_idx++);
compute_encoder.set_output_array(outputs[1], 3); if (outputs.size() == 2) {
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads // Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
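The grid shrink above is a plain ceiling division: with work_per_thread = 4 and an innermost extent of, say, dim0 = 10, the launch uses (10 + 4 - 1) / 4 = 3 threads along x, starting at elements 0, 4, and 8, and the kernel's own bounds check keeps the last thread from stepping past element 9. The same adjustment reappears below for copies and compiled kernels.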
@@ -125,9 +142,8 @@ void binary_op_gpu_inplace(
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) : MTL::Size(nthreads, 1, 1);
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
@@ -164,72 +180,8 @@ void binary_op_gpu_inplace(
array& out, array& out,
const std::string& op, const std::string& op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; std::vector<array> outputs = {out};
auto& b = inputs[1]; binary_op_gpu_inplace(inputs, outputs, op, s);
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} }
void binary_op_gpu( void binary_op_gpu(

View File

@@ -13,6 +13,8 @@
namespace mlx::core { namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel( inline void build_kernel(
std::ostream& os, std::ostream& os,
const std::string& kernel_name, const std::string& kernel_name,
@@ -22,7 +24,9 @@ inline void build_kernel(
const std::unordered_set<uintptr_t>& constant_ids, const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous, bool contiguous,
int ndim, int ndim,
bool dynamic_dims) { bool dynamic_dims,
bool use_big_index = false,
int work_per_thread = 1) {
// All outputs should have the exact same shape and will be row contiguous // All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape(); auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides(); auto output_strides = outputs[0].strides();
@@ -37,8 +41,8 @@ inline void build_kernel(
int cnt = 0; int cnt = 0;
// Start the kernel // Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(" << std::endl; << "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments // Add the input arguments
for (auto& x : inputs) { for (auto& x : inputs) {
@@ -52,11 +56,11 @@ inline void build_kernel(
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) { if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl; << " [[buffer(" << cnt++ << ")]],\n";
} else { } else {
add_indices = true; add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl; << " [[buffer(" << cnt++ << ")]],\n";
} }
} }
@@ -68,52 +72,37 @@ inline void build_kernel(
// Add the output arguments // Add the output arguments
for (auto& x : outputs) { for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* " os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl; << namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
} }
// Add output strides and shape to extract the indices. // Add output strides and shape to extract the indices.
if (!contiguous) { if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++ os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl << ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]]," << " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
<< std::endl;
} }
if (dynamic_dims) { if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]]," os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
<< std::endl;
} }
// The thread index in the whole grid // The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {" << std::endl << " uint3 grid [[threads_per_grid]]) {\n";
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
// Extract the indices per axis to individual uints if we have arrays that if (use_big_index) {
// are broadcasted or transposed // This is only used for contiguous kernels which don't have
if (add_indices) { // a third grid dimension
if (!dynamic_dims) { os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
if (ndim == 1) { } else if (work_per_thread > 1) {
os << " uint index_0 = pos.x;" << std::endl; os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
} else if (ndim == 2) { << " int xshape = output_shape["
os << " uint index_0 = pos.y;" << std::endl << (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " uint index_1 = pos.x;" << std::endl; << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else if (ndim == 3) { } else {
os << " uint index_0 = pos.z;" << std::endl os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
} }
// Read the inputs in tmps // Read constant / contiguous inputs in tmps
int nc_in_count = 0; std::vector<array> nc_inputs;
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i]; auto& x = inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
@@ -123,56 +112,117 @@ inline void build_kernel(
os << " auto tmp_" << xname << " = static_cast<" os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">("; << get_type_string(x.dtype()) << ">(";
print_constant(os, x); print_constant(os, x);
os << ");" << std::endl; os << ");\n";
} else if (is_scalar(x)) { } else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl; << xname << "[0];\n";
} else if (contiguous) { } else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl; << xname << "[index];\n";
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
} else { } else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " nc_inputs.push_back(x);
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
} }
} }
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
if (ndim == 1) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
} else if (ndim == 2) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
} else if (ndim == 3) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os << " uint zpos = pos.z;\n";
if (dynamic_dims) {
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
}
os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os << " index_" << xname << " += ";
if (dynamic_dims) {
os << "l * in_strides[" << i << " * ndim + d];\n";
} else {
os << "l * in_strides[" << i * ndim << " + d];\n";
}
}
os << " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
}
// Actually write the computation // Actually write the computation
for (auto& x : tape) { for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = "; << " = ";
if (is_static_cast(x.primitive())) { if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl; << namer.get_name(x.inputs()[0]) << ");\n";
} else { } else {
x.primitive().print(os); x.primitive().print(os);
os << "()("; os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
} }
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl; os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
} }
} }
// Write the outputs from tmps // Write the outputs from tmps
for (auto& x : outputs) { for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl; << ";\n";
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else {
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
}
}
os << " index++;\n }\n";
} }
// Finish the kernel // Finish the kernel
os << "}" << std::endl; os << "}\n";
if (cnt > 31) { if (cnt > 31) {
std::ostringstream msg; std::ostringstream msg;
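To make the generated control flow concrete, this is roughly the body emitted for a non-contiguous kernel with ndim = 5, work_per_thread = 4, and a single strided input (buffer declarations omitted; "a", "out", and SomeOp stand in for whatever the name map and tape actually produce):

constexpr int N_ = 4;
int xshape = output_shape[4];
size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);
// starting offset of the strided input from the two innermost grid axes
size_t index_a = N_ * pos.x * in_strides[4] + pos.y * in_strides[3];
// fold the remaining outer dimensions out of pos.z
uint zpos = pos.z;
for (int d = 2; d >= 0; --d) {
  uint l = zpos % output_shape[d];
  index_a += l * in_strides[0 + d];
  zpos /= output_shape[d];
}
// each thread covers up to N_ consecutive elements of the last axis
for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {
  float tmp_a = a[index_a];
  float tmp_out = SomeOp()(tmp_a);  // tape ops expand here
  out[index] = tmp_out;
  index_a += in_strides[4];
  index++;
}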
@@ -195,10 +245,7 @@ void Compiled::eval_gpu(
// Get the kernel if someone else built it already // Get the kernel if someone else built it already
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_); auto lib = d.get_library(kernel_lib_, [&]() {
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel; std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops() kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops(); << metal::ternary_ops();
@@ -212,6 +259,17 @@ void Compiled::eval_gpu(
/* contiguous = */ true, /* contiguous = */ true,
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ false); /* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_big",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true);
for (int i = 1; i < 8; i++) { for (int i = 1; i < 8; i++) {
build_kernel( build_kernel(
kernel, kernel,
@@ -222,7 +280,9 @@ void Compiled::eval_gpu(
constant_ids_, constant_ids_,
/* contiguous = */ false, /* contiguous = */ false,
/* ndim = */ i, /* ndim = */ i,
/* dynamic_dims = */ false); /* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
} }
build_kernel( build_kernel(
kernel, kernel,
@@ -233,10 +293,11 @@ void Compiled::eval_gpu(
constant_ids_, constant_ids_,
/* contiguous = */ false, /* contiguous = */ false,
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ true); /* dynamic_dims = */ true,
/* use_big_index = */ false,
lib = d.get_library(kernel_lib_, kernel.str()); /* work_per_thread = */ WORK_PER_THREAD);
} return kernel.str();
});
// Figure out which kernel we are using // Figure out which kernel we are using
auto& output_shape = outputs[0].shape(); auto& output_shape = outputs[0].shape();
@@ -285,7 +346,16 @@ void Compiled::eval_gpu(
initial_strides.push_back(std::move(xstrides)); initial_strides.push_back(std::move(xstrides));
} }
std::tie(shape, strides) = std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides); collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool use_2d = false;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
use_2d = (max_size > UINT32_MAX);
} }
// Get the kernel from the lib // Get the kernel from the lib
@@ -298,6 +368,8 @@ void Compiled::eval_gpu(
} else { } else {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(shape.size());
} }
} else if (use_2d) {
kernel_name += "_big";
} }
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -348,8 +420,10 @@ void Compiled::eval_gpu(
// Launch the kernel // Launch the kernel
if (contiguous) { if (contiguous) {
size_t nthreads = outputs[0].size(); size_t nthreads = outputs[0].data_size();
MTL::Size grid_dims(nthreads, 1, 1); MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size group_dims( MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
@@ -357,11 +431,18 @@ void Compiled::eval_gpu(
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1); size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) { int pow2;
throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
} }
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
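A quick note on the new pow2 argument: assuming get_block_dims caps the threadgroup at 2^pow2 threads, which is what the mapping above suggests, a device reporting maxTotalThreadsPerThreadgroup() of 1024 keeps the usual 2^10 = 1024-thread blocks, one reporting 576 drops to 2^9 = 512, and anything at or below 512 throws.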

View File

@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size()); wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped}; std::vector<array> copies = {in_unfolded};
return steel_matmul( return steel_matmul(
s, s,
d, d,
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
copy_gpu(wt_view, wt_transpose, CopyType::General, s); copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose}; std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_conv_groups( return steel_matmul_regular(
s, s,
d, d,
/*a = */ in_unfolded, /* a = */ in_unfolded,
/*b = */ wt_transpose, /* b = */ wt_transpose,
/*c = */ out, /* c = */ out,
/*M = */ implicit_M, /* M = */ implicit_M,
/*N = */ implicit_N, /* N = */ implicit_N,
/*K = */ implicit_K, /* K = */ implicit_K,
/*a_cols = */ implicit_K * groups, /* batch_size_out = */ groups,
/*b_cols = */ implicit_K, /* a_cols = */ implicit_K * groups,
/*out_cols = */ implicit_N * groups, /* b_cols = */ implicit_K,
/*a_transposed = */ false, /* out_cols = */ implicit_N * groups,
/*b_transposed = */ true, /* a_transposed = */ false,
/* groups = */ groups, /* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies); /*copies = */ copies);
} }
@@ -552,7 +557,7 @@ void winograd_conv_2D_gpu(
// Fill with zeros // Fill with zeros
array zero_arr = array(0, in.dtype()); array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s); fill_gpu(zero_arr, in_padded, s);
copies_w.push_back(zero_arr); copies_w.push_back(zero_arr);
// Pick input slice from padded // Pick input slice from padded
@@ -571,7 +576,6 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded_slice); copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded); copies_w.push_back(in_padded);
copies_w.push_back(zero_arr);
MLXConvParams<2> conv_params_updated{ MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0), /* const int N = */ in_padded.shape(0),
@@ -911,15 +915,11 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// Throw error // Throw error
else { else {
throw std::invalid_argument( throw std::invalid_argument(
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions."); "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
} }
// Clear copies // Record copies
if (copies.size() > 0) { d.add_temporaries(std::move(copies), s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} }
} // namespace mlx::core } // namespace mlx::core
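Instead of every op installing its own addCompletedHandler just to keep copies alive, temporaries are now handed to the device and released by the single completion handler installed in Device::end_encoding (see device.cpp further down), so their buffers survive until the command buffer that reads them has completed. The call pattern is simply:

// copies holds the intermediate arrays the dispatched kernels still reference
d.add_temporaries(std::move(copies), s.index);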

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <iostream>
#include <sstream> #include <sstream>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
@@ -10,7 +11,7 @@
namespace mlx::core { namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5; constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
@@ -59,13 +60,25 @@ void copy_gpu_inplace(
} }
// Try to collapse contiguous dims // Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims( auto maybe_collapse =
data_shape, std::vector{strides_in_pre, strides_out_pre}); [ctype, &data_shape, &strides_in_pre, &strides_out_pre]() {
auto& strides_in_ = strides[0]; if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
auto& strides_out_ = strides[1]; auto [shape, strides] = collapse_contiguous_dims(
data_shape,
std::vector{strides_in_pre, strides_out_pre},
/* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]);
} else {
std::vector<stride_t> e;
return std::make_tuple(std::vector<int>{}, e, e);
}
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool use_2d = out.data_size() > UINT32_MAX; bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name; std::string kernel_name;
{ {
std::ostringstream kname; std::ostringstream kname;
@@ -83,9 +96,13 @@ void copy_gpu_inplace(
kname << "gg"; kname << "gg";
break; break;
} }
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) && if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size(); kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
} }
kname << "_copy"; kname << "_copy";
kname << type_to_name(in) << type_to_name(out); kname << type_to_name(in) << type_to_name(out);
@@ -105,10 +122,8 @@ void copy_gpu_inplace(
compute_encoder.set_output_array(out, 1, out_offset); compute_encoder.set_output_array(out, 1, out_offset);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()}; std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()}; std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) { if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2); set_vector_bytes(compute_encoder, shape, ndim, 2);
} }
@@ -117,10 +132,6 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4); set_vector_bytes(compute_encoder, strides_out, ndim, 4);
} }
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1; int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1; int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -129,6 +140,11 @@ void copy_gpu_inplace(
data_size *= s; data_size *= s;
int rest = data_size / (dim0 * dim1); int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32 // NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
@@ -174,4 +190,31 @@ void copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s); in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
} }
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core } // namespace mlx::core
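fill_gpu gives the zero-padding and initialization paths a direct scalar-broadcast kernel rather than routing through copy_gpu with CopyType::Scalar. As used in the convolution change above:

array zero_arr = array(0, in.dtype());  // scalar to broadcast
fill_gpu(zero_arr, in_padded, s);       // allocates in_padded's buffer and fills it
copies_w.push_back(zero_arr);           // keep the scalar alive until the kernel runs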

View File

@@ -37,4 +37,7 @@ void copy_gpu_inplace(
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s);
// Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -15,11 +15,11 @@ void CustomKernel::eval_gpu(
std::vector<array> copies; std::vector<array> copies;
for (auto& out : outputs) { for (auto& out : outputs) {
// Copy from previous kernel
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) { if (init_value_) {
array init = array(init_value_.value(), out.dtype()); copies.emplace_back(init_value_.value(), out.dtype());
copy_gpu(init, out, CopyType::Scalar, s); fill_gpu(copies.back(), out, s);
copies.push_back(init);
} }
} }
@@ -33,24 +33,22 @@ void CustomKernel::eval_gpu(
return copies.back(); return copies.back();
} }
}; };
std::vector<const array> checked_inputs; std::vector<array> checked_inputs;
for (const array& in : inputs) { for (const array& in : inputs) {
checked_inputs.push_back(check_input(in)); checked_inputs.push_back(check_input(in));
} }
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
const auto& lib_name = name_; const auto& lib_name = name_;
auto lib = d.get_library(lib_name); auto lib =
if (lib == nullptr) { d.get_library(lib_name, [this] { return metal::utils() + source_; });
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto kernel = d.get_kernel(name_, lib); auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
int index = 0; int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) { for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i]; const array& in = checked_inputs[i];
auto shape_info = shape_infos_[i]; auto& shape_info = shape_infos_[i];
compute_encoder.set_input_array(in, index); compute_encoder.set_input_array(in, index);
index++; index++;
if (in.ndim() > 0) { if (in.ndim() > 0) {
@@ -69,7 +67,7 @@ void CustomKernel::eval_gpu(
} }
} }
} }
for (array out : outputs) { for (auto& out : outputs) {
compute_encoder.set_output_array(out, index); compute_encoder.set_output_array(out, index);
index++; index++;
} }
@@ -80,10 +78,7 @@ void CustomKernel::eval_gpu(
MTL::Size grid_dims = MTL::Size(gx, gy, gz); MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@@ -20,7 +20,6 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable // TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12; constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
@@ -121,33 +120,34 @@ MTL::Library* load_library(
} // namespace } // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) { CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain(); enc_->retain();
} }
CommandEncoder::~CommandEncoder() { CommandEncoder::~CommandEncoder() {
enc->endEncoding(); enc_->endEncoding();
enc->release(); enc_->release();
} }
void CommandEncoder::set_input_array( void CommandEncoder::set_input_array(
const array& a, const array& a,
int idx, int idx,
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr())); auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) { if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier // Insert a barrier
enc->memoryBarrier(&r_buf, 1); enc_->memoryBarrier(&r_buf, 1);
// Remove the output // Remove the output
outputs.erase(it); outputs_.erase(it);
} }
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr()); auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() - auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents()); static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset; base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx); enc_->setBuffer(a_buf, base_offset, idx);
} }
void CommandEncoder::set_output_array( void CommandEncoder::set_output_array(
@@ -156,39 +156,25 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set // Add barriers before adding the output to the output set
set_input_array(a, idx, offset); set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr()); auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) { if (concurrent_) {
concurrent_outputs.insert(buf); concurrent_outputs_.insert(buf);
} else { } else {
outputs.insert(buf); outputs_.insert(buf);
} }
} }
void CommandEncoder::dispatchThreadgroups( void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
num_dispatches++; enc_->dispatchThreadgroups(grid_dims, group_dims);
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
} }
void CommandEncoder::dispatchThreads( void CommandEncoder::dispatchThreads(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
num_dispatches++; enc_->dispatchThreads(grid_dims, group_dims);
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
} }
Device::Device() { Device::Device() {
@@ -199,12 +185,6 @@ Device::Device() {
Device::~Device() { Device::~Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& k : kernel_map_) { for (auto& k : kernel_map_) {
k.second->release(); k.second->release();
} }
@@ -219,69 +199,134 @@ void Device::new_queue(int index) {
// Multiple threads can ask the device for queues // Multiple threads can ask the device for queues
// We lock this as a critical section for safety // We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index); debug_set_stream_queue_label(q, index);
if (!q) { if (!q) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Failed to make new command queue."); "[metal::Device] Failed to make new command queue.");
} }
queue_map_.insert({index, q}); stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
} }
int Device::get_command_buffer_ops(int index) { int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index); return get_stream_(index).buffer_ops;
return bit->second.first;
} }
void Device::increment_command_buffer_ops(int index) { void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index); get_stream_(index).buffer_ops++;
bit->second.first++;
} }
MTL::CommandBuffer* Device::get_command_buffer(int index) { MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index); auto& stream = get_stream_(index);
if (bit == buffer_map_.end()) { if (stream.buffer == nullptr) {
auto qit = queue_map_.find(index); stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (qit == queue_map_.end()) { if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Unable to create new command buffer"); "[metal::Device] Unable to create new command buffer");
} }
// Increment ref count so the buffer is not garbage collected // Increment ref count so the buffer is not garbage collected
cb->retain(); stream.buffer->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
} }
return bit->second.second; return stream.buffer;
} }
void Device::commit_command_buffer(int index) { void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index); auto& stream = get_stream_(index);
bit->second.second->commit(); stream.buffer->commit();
bit->second.second->release(); stream.buffer->release();
buffer_map_.erase(bit); stream.buffer = nullptr;
stream.buffer_ops = 0;
}
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
}
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
} }
void Device::end_encoding(int index) { void Device::end_encoding(int index) {
encoder_map_.erase(index); auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoder's fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoder's inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
} }
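A concrete walk-through of the protocol above: if encoder A writes buffer X, X is recorded in stream.outputs pointing at A's fence. When a later encoder B lists X among its inputs, it waits on that fence once (the waiting_on set deduplicates repeated hits on the same fence) and then records its own outputs, overwriting X's entry with B's fence. When A's command buffer completes, its handler erases only the entries that still point at A's fence, which keeps the map from growing without bound.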
CommandEncoder& Device::get_command_encoder(int index) { CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index); auto& stream = get_stream_(index);
if (eit == encoder_map_.end()) { if (stream.encoder == nullptr) {
auto cb = get_command_buffer(index); stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
eit = stream.fence = std::make_shared<Fence>(device_->newFence());
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
} }
return *(eit->second); return *stream.encoder;
} }
void Device::register_library( void Device::register_library(
@@ -293,20 +338,7 @@ void Device::register_library(
} }
} }
MTL::Library* Device::get_library_cache_(const std::string& lib_name) { MTL::Library* Device::build_library_(const std::string& source_string) {
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
auto ns_code = auto ns_code =
@@ -322,26 +354,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library // Throw error if unable to compile library
if (!mtl_lib) { if (!mtl_lib) {
std::ostringstream msg; std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n"; msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
if (error) { if (error) {
msg << error->localizedDescription()->utf8String() << "\n"; msg << error->localizedDescription()->utf8String() << "\n";
} }
@@ -465,68 +478,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
return kernel; return kernel;
} }
MTL::Library* Device::get_library(const std::string& name) { MTL::Library* Device::get_library_(const std::string& name) {
std::shared_lock lock(library_mtx_);
auto it = library_map_.find(name); auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr; return (it != library_map_.end()) ? it->second : nullptr;
} }
MTL::Library* Device::get_library( MTL::Library* Device::get_library(
const std::string& name, const std::string& name,
const std::string& source, const std::function<std::string(void)>& builder) {
bool cache /* = true */) { {
if (cache) { std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) { if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second; return it->second;
} }
} }
auto mtl_lib = get_library_(source); std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
if (cache) { return it->second;
library_map_.insert({name, mtl_lib});
} }
auto mtl_lib = build_library_(builder());
library_map_.insert({name, mtl_lib});
return mtl_lib; return mtl_lib;
} }
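The builder-callback overload is a textbook double-checked read/write split on a std::shared_mutex: readers share the lock, and the single writer re-checks the map before invoking the possibly expensive builder. A self-contained sketch of the same pattern, independent of the Metal types:

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Generic build-on-miss cache mirroring the locking pattern above.
template <typename Value>
class BuilderCache {
 public:
  Value get(const std::string& name, const std::function<Value()>& builder) {
    {
      std::shared_lock rlock(mtx_);  // many concurrent readers
      if (auto it = map_.find(name); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock wlock(mtx_);  // single writer
    if (auto it = map_.find(name); it != map_.end()) {
      return it->second;  // someone else built it while we waited
    }
    auto value = builder();  // only runs on a genuine miss
    map_.emplace(name, value);
    return value;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, Value> map_;
};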
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_( MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) { const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) { if (funcs.empty()) {
@@ -547,34 +524,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
return lfuncs; return lfuncs;
} }
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Single writer allowed
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
auto pool = new_scoped_memory_pool();
// Pull kernel from library
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name, const std::string& base_name,
MTL::Library* mtl_lib, MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */, const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */, const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) { const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name; const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { {
return it->second; // Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
} }
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
} }
MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel(
@@ -583,16 +581,34 @@ MTL::ComputePipelineState* Device::get_kernel(
const std::string& hash_name /* = "" */, const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */, const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) { const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name; const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) { {
return it->second; // Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
} }
// Search for cached metal lib // Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name); MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions); void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
} }
Device& device(mlx::core::Device) { Device& device(mlx::core::Device) {

View File

@@ -7,6 +7,7 @@
#include <filesystem> #include <filesystem>
#include <functional> #include <functional>
#include <mutex> #include <mutex>
#include <shared_mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@@ -44,13 +45,13 @@ struct CommandEncoder {
struct ConcurrentContext { struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) { ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true; enc.concurrent_ = true;
} }
~ConcurrentContext() { ~ConcurrentContext() {
enc.concurrent = false; enc.concurrent_ = false;
enc.outputs.insert( enc.outputs_.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end()); enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs.clear(); enc.concurrent_outputs_.clear();
} }
private: private:
@@ -58,7 +59,7 @@ struct CommandEncoder {
}; };
MTL::ComputeCommandEncoder* operator->() { MTL::ComputeCommandEncoder* operator->() {
return enc; return enc_;
} }
void set_input_array(const array& a, int idx, int64_t offset = 0); void set_input_array(const array& a, int idx, int64_t offset = 0);
@@ -69,18 +70,59 @@ struct CommandEncoder {
ConcurrentContext start_concurrent() { ConcurrentContext start_concurrent() {
return ConcurrentContext(*this); return ConcurrentContext(*this);
} }
~CommandEncoder(); ~CommandEncoder();
private: // Inputs to all kernels in the encoder including temporaries
void maybe_split(); std::unordered_set<const void*>& inputs() {
return all_inputs_;
};
int num_dispatches{0}; // Outputs of all kernels in the encoder including temporaries
MTL::CommandBuffer* cbuf; std::unordered_set<const void*> outputs() {
MTL::ComputeCommandEncoder* enc; return all_outputs_;
bool concurrent{false}; };
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs; private:
MTL::ComputeCommandEncoder* enc_;
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
struct Fence {
Fence(MTL::Fence* fence) : fence(fence) {}
~Fence() {
fence->release();
}
MTL::Fence* fence;
};
struct DeviceStream {
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
~DeviceStream() {
queue->release();
if (buffer != nullptr) {
buffer->release();
}
};
MTL::CommandQueue* queue;
// A map of prior command encoder outputs to their corresponding fence
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
}; };
class Device { class Device {
@@ -114,29 +156,9 @@ class Device {
} }
} }
MTL::Library* get_library(const std::string& name);
MTL::Library* get_library( MTL::Library* get_library(
const std::string& name, const std::string& name,
const std::string& source_string, const std::function<std::string(void)>& builder);
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::ComputePipelineState* get_kernel( MTL::ComputePipelineState* get_kernel(
const std::string& base_name, const std::string& base_name,
@@ -155,11 +177,20 @@ class Device {
MTL::ArgumentEncoder* argument_encoder( MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
// Record temporary arrays for the given stream index
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
void set_residency_set(const MTL::ResidencySet* residency_set);
private: private:
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
MTL::Library* get_library_cache_(const std::string& name); MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string); MTL::Library* get_library_(const std::string& name);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc); MTL::Library* build_library_(const std::string& source_string);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib); MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
@@ -181,13 +212,22 @@ class Device {
const MTL::Function* mtl_function, const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions); const MTL::LinkedFunctions* linked_functions);
MTL::ComputePipelineState* get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::Device* device_; MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; std::unordered_map<int32_t, DeviceStream> stream_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_; std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_; std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_; std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_; const MTL::ResidencySet* residency_set_{nullptr};
}; };
Device& device(mlx::core::Device); Device& device(mlx::core::Device);
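
The header above adds <shared_mutex>, replaces the single mtx_ with per-map kernel_mtx_ and library_mtx_, and changes get_library to take a builder callback instead of exposing separate source-string and stitched-descriptor overloads. The likely intent is a build-on-miss cache: readers share the lock, and a miss builds the kernel source exactly once under an exclusive lock. A self-contained sketch of that pattern (the exact locking scheme inside Device is an assumption; only the mutex members and the builder-taking signature are visible in this hunk):

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

template <typename Value>
class BuilderCache {
 public:
  Value get(const std::string& key, const std::function<Value()>& builder) {
    {
      std::shared_lock lock(mtx_);  // many concurrent readers
      if (auto it = map_.find(key); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock lock(mtx_);  // exclusive writer on a miss
    if (auto it = map_.find(key); it != map_.end()) {
      return it->second;  // someone else built it while we waited
    }
    auto value = builder();
    map_.emplace(key, value);
    return value;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, Value> map_;
};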

View File

@@ -27,4 +27,9 @@ void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value()); static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
} }
bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
value();
}
} // namespace mlx::core } // namespace mlx::core
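
The added Event::is_signaled treats the event as a monotonic counter: it reports true once the shared event's signaled value has reached the event's target value, without blocking. A tiny host-side sketch of the same semantics (names are illustrative, not the MLX Event API):

#include <atomic>
#include <cstdint>

struct CounterEventSketch {
  std::atomic<uint64_t> signaled{0};
  uint64_t target{1};

  void signal() { signaled.store(target, std::memory_order_release); }
  bool is_signaled() const {
    return signaled.load(std::memory_order_acquire) >= target;
  }
};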

View File

@@ -575,8 +575,7 @@ void fft_op(
auto plan = plan_fft(n); auto plan = plan_fft(n);
if (plan.four_step) { if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s); four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return; return;
} }
@@ -741,8 +740,8 @@ void fft_op(
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
} }
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); }); d.add_temporaries(std::move(copies), s.index);
} }
void fft_op( void fft_op(
@@ -785,8 +784,7 @@ void nd_fft_op(
} }
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(temp_arrs), s.index);
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
} }
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) { void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
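
The fft.cpp changes above drop the per-callsite addCompletedHandler lambdas, which captured the temporary arrays only to clear them on completion, in favor of a single d.add_temporaries(std::move(copies), s.index). The stream now owns the temporaries and releases them once per command buffer. A minimal sketch of that ownership hand-off (hypothetical names; ArraySketch stands in for mlx::core::array):

#include <iterator>
#include <vector>

struct ArraySketch {};  // stands in for mlx::core::array

struct StreamStateSketch {
  std::vector<ArraySketch> temporaries;
};

void add_temporaries(StreamStateSketch& s, std::vector<ArraySketch> arrays) {
  if (arrays.empty()) {
    return;
  }
  s.temporaries.insert(
      s.temporaries.end(),
      std::make_move_iterator(arrays.begin()),
      std::make_move_iterator(arrays.end()));
}

void on_command_buffer_completed(StreamStateSketch& s) {
  s.temporaries.clear();  // one release point instead of many handlers
}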

View File

@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
return source.str(); return source.str();
} }
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) { void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
auto [n, m] = decompose_hadamard(in.shape(axis)); int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) { if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument( throw std::invalid_argument(
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel_name = kname.str(); auto kernel_name = kname.str();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m); auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard(); kernel_source << metal::utils() << codelet << metal::hadamard();
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
n, n,
m, m,
read_width); read_width);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
int batch_size = in.size() / n; int batch_size = in.size() / n;
int threads_per = n / max_radix; int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
};
if (m > 1) { if (m > 1) {
// When m is greater than 1, we decompose the // When m is greater than 1, we decompose the
// computation into two uploads to the GPU: // computation into two uploads to the GPU:
@@ -171,33 +164,17 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
temp.set_data(allocator::malloc_or_wait(temp.nbytes())); temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp); copies.push_back(temp);
launch_hadamard( launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel // Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP); threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per; batch_size = in.size() / m / read_width / threads_per;
launch_hadamard( launch_hadamard(temp, out, "m" + kernel_name, scale_);
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else { } else {
launch_hadamard( launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
} }
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
} }
} // namespace mlx::core } // namespace mlx::core
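
In the Hadamard change above, the free launch_hadamard function, which had to be handed batch_size, threads_per, and the stream on every call, becomes a local lambda that captures the encoder and launch state by reference; the call sites then pass only what varies (input, output, kernel name, scale), and the second launch automatically sees the re-tuned threads_per and batch_size. A small illustrative sketch of that capture-by-reference refactor (names are made up for the example):

#include <algorithm>
#include <iostream>
#include <string>

void dispatch_two_stage(int batch_size, int threads_per) {
  // Capturing by reference means later updates to batch_size/threads_per are
  // visible to subsequent calls, just like in the refactored eval_gpu.
  auto launch = [&](const std::string& kernel_name, float scale) {
    std::cout << kernel_name << ": grid=(" << batch_size << ", " << threads_per
              << "), scale=" << scale << "\n";
  };

  launch("n_kernel", 1.0f);
  threads_per = std::min(threads_per, 256);  // re-tuned between the two stages
  batch_size /= 4;
  launch("m_kernel", 0.5f);
}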

View File

@@ -64,8 +64,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name = lib_name; kernel_name = lib_name;
} }
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather(); kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype()); std::string out_type_str = get_type_string(out.dtype());
@@ -83,8 +82,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args, idx_args,
idx_arr, idx_arr,
idx_ndim); idx_ndim);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
@@ -236,8 +235,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name = kname.str(); kernel_name = kname.str();
} }
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter(); << metal::scatter();
@@ -264,7 +262,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break; break;
} }
if (reduce_type_ != Scatter::None) { if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str); op_type = fmt::format(fmt::runtime(op_type), out_type_str);
} }
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
@@ -277,8 +275,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx, nidx,
idx_args, idx_args,
idx_arr); idx_arr);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
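
The small-looking change op_type = fmt::format(fmt::runtime(op_type), out_type_str) in the Scatter path is the C++20 compatibility fix: recent {fmt} validates format strings at compile time, so a format string that is only known at run time (op_type is chosen in a switch above) must be wrapped in fmt::runtime. A standalone example of the same wrapping (requires the {fmt} library, version 8 or newer; the template string is illustrative):

#include <fmt/core.h>
#include <string>

std::string specialize_op(const std::string& op_template, const std::string& t) {
  // op_template is only known at run time, e.g. "Sum<{0}>", so it has to go
  // through fmt::runtime when compile-time format checking is enabled.
  return fmt::format(fmt::runtime(op_template), t);
}

// specialize_op("Sum<{0}>", "float") == "Sum<float>"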

View File

@@ -1,100 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -1,26 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -1,12 +1,9 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h" #include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h" #include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h" #include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h" #include "mlx/backend/metal/jit/steel_gemm.h"
@@ -27,37 +24,38 @@ MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const array& out) { const array& out) {
const auto& lib_name = kernel_name; auto lib = d.get_library(kernel_name, [&]() {
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source kernel_source << metal::utils() << metal::arange()
<< metal::utils() << metal::arange() << fmt::format(
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype())); arange_kernels,
lib = d.get_library(lib_name, kernel_source.str()); kernel_name,
} get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
MTL::ComputePipelineState* get_unary_kernel( MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) { auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto u_def = get_template_definition( kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
"v" + lib_name, "unary_v", get_type_string(out_type), op); kernel_source << get_template_definition(
auto u2_def = get_template_definition( "v_" + lib_name, "unary_v", in_t, out_t, op);
"v2" + lib_name, "unary_v2", get_type_string(out_type), op); kernel_source << get_template_definition(
auto g_def = get_template_definition( "v2_" + lib_name, "unary_v2", in_t, out_t, op);
"g" + lib_name, "unary_g", get_type_string(out_type), op); kernel_source << get_template_definition(
kernel_source << metal::utils() << metal::unary_ops() << metal::unary() "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
<< u_def << u2_def << g_def; return kernel_source.str();
lib = d.get_library(lib_name, kernel_source.str()); });
}
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -67,7 +65,7 @@ void add_binary_kernels(
Dtype out_type, Dtype out_type,
const std::string op, const std::string op,
std::ostringstream& kernel_source) { std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = { const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
{"vs", "binary_vs"}, {"vs", "binary_vs"},
{"sv", "binary_sv"}, {"sv", "binary_sv"},
@@ -78,31 +76,24 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"}, {"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"}, {"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"}, {"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"}, }};
{"g5", "binary_g_nd"}, for (auto& [name, func] : kernel_types) {
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
std::string template_def; std::string template_def;
if (name == "g4" || name == "g5") { template_def = get_template_definition(
int dim = std::stoi(name.substr(1)); name + "_" + lib_name,
template_def = get_template_definition( func,
name + lib_name, get_type_string(in_type),
func, get_type_string(out_type),
get_type_string(in_type), op);
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
kernel_source << template_def; kernel_source << template_def;
} }
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
} }
MTL::ComputePipelineState* get_binary_kernel( MTL::ComputePipelineState* get_binary_kernel(
@@ -111,14 +102,13 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(2); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary(); kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -128,15 +118,14 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(2); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two(); << metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -146,34 +135,26 @@ MTL::ComputePipelineState* get_ternary_kernel(
Dtype type, Dtype type,
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = { const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"}, {"v", "ternary_v"},
{"v2", "ternary_v2"}, {"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"}, {"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"}, {"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"}, {"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"}, }};
{"g5", "ternary_g_nd"},
};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary(); kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
std::string template_def; std::string template_def;
if (name == "g4" || name == "g5") { template_def = get_template_definition(
int dim = std::stoi(name.substr(1)); name + "_" + lib_name, func, get_type_string(type), op);
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
kernel_source << template_def; kernel_source << template_def;
} }
lib = d.get_library(lib_name, kernel_source.str()); kernel_source << get_template_definition(
} "gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -183,17 +164,33 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in, const array& in,
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::copy() kernel_source << metal::utils() << metal::copy()
<< fmt::format( << get_template_definition(
copy_kernels, "s_" + lib_name, "copy_s", in_type, out_type)
lib_name, << get_template_definition(
get_type_string(in.dtype()), "v_" + lib_name, "copy_v", in_type, out_type)
get_type_string(out.dtype())); << get_template_definition(
lib = d.get_library(lib_name, kernel_source.str()); "g1_" + lib_name, "copy_g_nd1", in_type, out_type)
} << get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -203,8 +200,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise, bool precise,
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&] {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax() kernel_source << metal::utils() << metal::softmax()
<< fmt::format( << fmt::format(
@@ -212,8 +208,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
lib_name, lib_name,
get_type_string(out.dtype()), get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype())); get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -226,22 +222,29 @@ MTL::ComputePipelineState* get_scan_kernel(
const array& in, const array& in,
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) { auto out_type = get_type_string(out.dtype());
std::string op_name = "Cum" + reduce_type; std::string op = "Cum" + reduce_type + "<" + out_type + ">";
op_name[3] = toupper(op_name[3]); op[3] = toupper(op[3]);
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan() kernel_source << metal::utils() << metal::scan();
<< fmt::format( const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
scan_kernels, {"contig_", "contiguous_scan"},
lib_name, {"strided_", "strided_scan"},
get_type_string(in.dtype()), }};
get_type_string(out.dtype()), for (auto& [prefix, kernel] : scan_kernels) {
op_name, kernel_source << get_template_definition(
inclusive, prefix + lib_name,
reverse); kernel,
lib = d.get_library(lib_name, kernel_source.str()); get_type_string(in.dtype()),
} get_type_string(out.dtype()),
op,
in.itemsize() <= 4 ? 4 : 2,
inclusive,
reverse);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
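
The rewritten get_scan_kernel above builds the op name inline: "Cum" plus the reduce name with its first letter upper-cased (index 3 of the string), parameterized on the output type, and then instantiates both the contiguous and strided scans with 4 elements per thread for item sizes up to 4 bytes and 2 otherwise. A quick host-side check of the name construction:

#include <cctype>
#include <iostream>
#include <string>

std::string scan_op_name(const std::string& reduce_type, const std::string& out_type) {
  std::string op = "Cum" + reduce_type + "<" + out_type + ">";
  op[3] = std::toupper(static_cast<unsigned char>(op[3]));
  return op;
}

int main() {
  std::cout << scan_op_name("sum", "float") << "\n";    // CumSum<float>
  std::cout << scan_op_name("max", "int32_t") << "\n";  // CumMax<int32_t>
}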
@@ -253,8 +256,7 @@ MTL::ComputePipelineState* get_sort_kernel(
int bn, int bn,
int tn) { int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype()); auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype()); auto out_type = get_type_string(out.dtype());
@@ -279,8 +281,8 @@ MTL::ComputePipelineState* get_sort_kernel(
bn, bn,
tn); tn);
} }
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -292,15 +294,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
int bn, int bn,
int tn) { int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort(); kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = { std::array<std::pair<std::string, std::string>, 3> kernel_types = {
{"sort_", "mb_block_sort"}, {{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"}, {"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}}; {"merge_", "mb_block_merge"}}};
for (auto [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
kernel_source << get_template_definition( kernel_source << get_template_definition(
name + lib_name, name + lib_name,
func, func,
@@ -310,8 +311,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
bn, bn,
tn); tn);
} }
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -319,8 +320,7 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const array& out) { const array& out) {
auto lib = d.get_library(kernel_name); auto lib = d.get_library(kernel_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
std::string op_type = op_name(out); std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]); op_type[0] = std::toupper(op_name(out)[0]);
@@ -329,8 +329,8 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition( kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op); kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -344,8 +344,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
int ndim /* = -1 */, int ndim /* = -1 */,
int bm /* = -1 */, int bm /* = -1 */,
int bn /* = -1 */) { int bn /* = -1 */) {
auto lib = d.get_library(kernel_name); auto lib = d.get_library(kernel_name, [&]() {
if (lib == nullptr) {
std::string op_type = op_name; std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]); op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source; std::ostringstream kernel_source;
@@ -363,8 +362,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
kernel_source << get_template_definition( kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op); kernel_name, func_name, in_type, out_type, op);
} }
lib = d.get_library(kernel_name, kernel_source.str()); return kernel_source.str();
} });
auto st = d.get_kernel(kernel_name, lib); auto st = d.get_kernel(kernel_name, lib);
return st; return st;
} }
@@ -383,8 +382,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int wm, int wm,
int wn) { int wn) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused() << metal::steel_gemm_fused()
@@ -399,8 +397,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
"wn"_a = wn, "wn"_a = wn,
"trans_a"_a = transpose_a, "trans_a"_a = transpose_a,
"trans_b"_a = transpose_b); "trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib, hash_name, func_consts); return d.get_kernel(kernel_name, lib, hash_name, func_consts);
} }
@@ -419,8 +417,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
bool mn_aligned, bool mn_aligned,
bool k_aligned) { bool k_aligned) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk() << metal::steel_gemm_splitk()
@@ -438,8 +435,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
"trans_b"_a = transpose_b, "trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned, "mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned); "k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -450,19 +447,19 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
const array& out, const array& out,
bool axbpy) { bool axbpy) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk() << metal::steel_gemm_splitk()
<< fmt::format( << fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels fmt::runtime(
: steel_gemm_splitk_accum_kernels, axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name, "name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()), "atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype())); "otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -482,8 +479,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned, bool mn_aligned,
bool k_aligned) { bool k_aligned) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value() auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype()) ? get_type_string((*mask_out).dtype())
@@ -507,8 +503,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
"trans_b"_a = transpose_b, "trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned, "mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned); "k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -527,8 +523,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
int tn, int tn,
bool contiguous) { bool contiguous) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value() auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype()) ? get_type_string((*mask_out).dtype())
@@ -550,8 +545,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
"tn"_a = tn, "tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "", "trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1"); "nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -567,8 +562,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization, int n_channel_specialization,
bool small_filter) { bool small_filter) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv() kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format( << fmt::format(
@@ -582,8 +576,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
"wn"_a = wn, "wn"_a = wn,
"n_channels"_a = n_channel_specialization, "n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter); "small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -597,8 +591,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm, int wm,
int wn) { int wn) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general() << metal::steel_conv_general()
@@ -611,8 +604,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
"bk"_a = bk, "bk"_a = bk,
"wm"_a = wm, "wm"_a = wm,
"wn"_a = wn); "wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -623,13 +616,12 @@ MTL::ComputePipelineState* get_fft_kernel(
const metal::MTLFCList& func_consts, const metal::MTLFCList& func_consts,
const std::string& template_def) { const std::string& template_def) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
std::string kernel_string; std::string kernel_string;
kernel_source << metal::fft() << template_def; kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib, hash_name, func_consts); return d.get_kernel(kernel_name, lib, hash_name, func_consts);
} }
@@ -638,13 +630,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& template_def) { const std::string& template_def) {
const auto& lib_name = kernel_name; const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name, [&]() {
if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized() kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def; << template_def;
lib = d.get_library(lib_name, kernel_source.str()); return kernel_source.str();
} });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }

View File

@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel( MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op); const std::string op);
@@ -208,10 +209,10 @@ get_template_definition(std::string name, std::string func, Args... args) {
}; };
(add_arg(args), ...); (add_arg(args), ...);
s << ">"; s << ">";
std::string base_string = R"( return fmt::format(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1}; "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
)"; name,
return fmt::format(base_string, name, s.str()); s.str());
} }
} // namespace mlx::core } // namespace mlx::core
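
The jit/kernels.h hunk above threads the input dtype through get_unary_kernel and rewrites get_template_definition so the instantiation line comes from a single fmt::format call rather than a raw-string template. Functionally it is a small variadic generator: stream every template argument into "func<...>" and emit an explicit instantiation carrying a [[host_name]] attribute. A standalone approximation using only the standard library (names in the example are illustrative):

#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string template_definition(
    const std::string& name, const std::string& func, Args... args) {
  std::ostringstream params;
  params << func << "<";
  bool first = true;
  auto add_arg = [&](const auto& arg) {
    if (!first) {
      params << ", ";
    }
    first = false;
    params << arg;
  };
  (add_arg(args), ...);
  params << ">";
  return "\ntemplate [[host_name(\"" + name + "\")]] [[kernel]] decltype(" +
      params.str() + ") " + params.str() + ";\n";
}

int main() {
  // Prints:
  // template [[host_name("v_copy_f32")]] [[kernel]]
  //     decltype(copy_v<float, float>) copy_v<float, float>;
  std::cout << template_definition("v_copy_f32", "copy_v", "float", "float");
}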

View File

@@ -1,38 +1,26 @@
set( set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
BASE_HEADERS
bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h
)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
-gline-tables-only
-frecord-sources)
endif() endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
${METAL_FLAGS} -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air" COMMENT "Building ${TARGET}.air"
VERBATIM VERBATIM)
)
endfunction(build_kernel_base) endfunction(build_kernel_base)
function(build_kernel KERNEL) function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
cmake_path(GET KERNEL STEM TARGET) cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}") build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE) set(KERNEL_AIR
${TARGET}.air ${KERNEL_AIR}
PARENT_SCOPE)
endfunction(build_kernel) endfunction(build_kernel)
build_kernel(arg_reduce) build_kernel(arg_reduce)
@@ -43,105 +31,66 @@ build_kernel(random)
build_kernel(rms_norm) build_kernel(rms_norm)
build_kernel(rope) build_kernel(rope)
build_kernel( build_kernel(
scaled_dot_product_attention scaled_dot_product_attention scaled_dot_product_attention_params.h
scaled_dot_product_attention_params.h sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
steel/defines.h
steel/gemm/transforms.h
steel/utils.h
)
set( set(STEEL_HEADERS
STEEL_HEADERS steel/defines.h
steel/defines.h steel/utils.h
steel/utils.h steel/conv/conv.h
steel/conv/conv.h steel/conv/loader.h
steel/conv/loader.h steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_l.h steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_channel_n.h steel/conv/loaders/loader_general.h
steel/conv/loaders/loader_general.h steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv.h steel/conv/kernels/steel_conv_general.h
steel/conv/kernels/steel_conv_general.h steel/gemm/gemm.h
steel/gemm/gemm.h steel/gemm/mma.h
steel/gemm/mma.h steel/gemm/loader.h
steel/gemm/loader.h steel/gemm/transforms.h
steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_masked.h steel/gemm/kernels/steel_gemm_splitk.h)
steel/gemm/kernels/steel_gemm_splitk.h
)
if (NOT MLX_METAL_JIT) if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h) build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h) build_kernel(binary binary.h binary_ops.h)
build_kernel(binary_two binary_two.h) build_kernel(binary_two binary_two.h)
build_kernel(copy copy.h) build_kernel(copy copy.h)
build_kernel( build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
fft build_kernel(
fft.h reduce
fft/radix.h atomic.h
fft/readwrite.h reduction/ops.h
) reduction/reduce_init.h
build_kernel( reduction/reduce_all.h
reduce reduction/reduce_col.h
atomic.h reduction/reduce_row.h)
reduction/ops.h build_kernel(quantized quantized.h ${STEEL_HEADERS})
reduction/reduce_init.h build_kernel(scan scan.h)
reduction/reduce_all.h build_kernel(softmax softmax.h)
reduction/reduce_col.h build_kernel(sort sort.h)
reduction/reduce_row.h build_kernel(ternary ternary.h ternary_ops.h)
) build_kernel(unary unary.h unary_ops.h)
build_kernel( build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
quantized build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
quantized.h build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
${STEEL_HEADERS} build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
) build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(gemv_masked steel/utils.h)
build_kernel(softmax softmax.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(
steel/conv/kernels/steel_conv
${STEEL_HEADERS}
)
build_kernel(
steel/conv/kernels/steel_conv_general
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_fused
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_masked
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_splitk
${STEEL_HEADERS}
)
build_kernel(gemv_masked steel/utils.h)
endif() endif()
add_custom_command( add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR} DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib" COMMENT "Building mlx.metallib"
VERBATIM VERBATIM)
)
add_custom_target( add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)
mlx-metallib
DEPENDS
${MLX_METAL_PATH}/mlx.metallib
)
add_dependencies( add_dependencies(mlx mlx-metallib)
mlx
mlx-metallib
)
# Install metallib # Install metallib
include(GNUInstallDirs) include(GNUInstallDirs)
@@ -149,5 +98,4 @@ include(GNUInstallDirs)
install( install(
FILES ${MLX_METAL_PATH}/mlx.metallib FILES ${MLX_METAL_PATH}/mlx.metallib
DESTINATION ${CMAKE_INSTALL_LIBDIR} DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT metallib COMPONENT metallib)
)

View File

@@ -70,16 +70,16 @@ IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)}; simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
} }
template <typename T, typename Op, int N_READS> template <typename T, typename Op, int N_READS = 4>
[[kernel]] void arg_reduce_general( [[kernel]] void arg_reduce_general(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]], device uint32_t* out [[buffer(1)]],
const device int* shape [[buffer(2)]], const constant int* shape [[buffer(2)]],
const device size_t* in_strides [[buffer(3)]], const constant size_t* in_strides [[buffer(3)]],
const device size_t* out_strides [[buffer(4)]], const constant size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]], const constant size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]], const constant size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]], const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]], uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]], uint lsize [[threads_per_threadgroup]],
@@ -159,28 +159,12 @@ template <typename T, typename Op, int N_READS>
} }
} }
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] [[kernel]] void \
arg_reduce_general<itype, op<itype>, 4>( \
const device itype* in [[buffer(0)]], \
device uint32_t* out [[buffer(1)]], \
const device int* shape [[buffer(2)]], \
const device size_t* in_strides [[buffer(3)]], \
const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off // clang-format off
#define instantiate_arg_reduce(name, itype) \ #define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \ instantiate_kernel( \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax) "argmin_" #name, arg_reduce_general, itype, ArgMin<itype>) \
instantiate_kernel( \
"argmax_" #name, arg_reduce_general, itype, ArgMax<itype>)
instantiate_arg_reduce(bool_, bool) instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t) instantiate_arg_reduce(uint8, uint8_t)
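
The arg_reduce.metal change above moves the buffer bindings into the constant address space, defaults N_READS to 4, and replaces the hand-written instantiate_arg_reduce_helper with the shared instantiate_kernel macro. The macro itself is defined elsewhere in the kernel utilities; based on the removed helper and on the host-side generator, it should expand to roughly the following (an assumption, not the verbatim definition):

// Assumed expansion of instantiate_kernel: an explicit instantiation of the
// kernel template under the given Metal host name, with trailing template
// arguments (here N_READS) left to their defaults.
#define instantiate_kernel(name, func, ...)          \
  template [[host_name(                              \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) \
      func<__VA_ARGS__>;

// instantiate_kernel("argmax_float32", arg_reduce_general, float, ArgMax<float>)
// then stands for the same instantiation the removed helper spelled out, with
// N_READS = 4 supplied by the new default.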

View File

@@ -93,7 +93,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides); auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides); auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y; size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
@@ -109,27 +109,11 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op, int DIM> template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -140,7 +124,16 @@ template <typename T, typename U, typename Op>
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); auto idx = elem_to_loc_2_nd(
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
c[out_idx] = Op()(a[idx.x], b[idx.y]); auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;
idx.y += b_xstride;
}
} }
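
The binary.h rewrite above drops the fixed-dimension binary_g_nd variants and gives the general binary_g kernel a work-per-thread parameter N: the launch grid shrinks by N along the innermost axis, each thread locates element {N*x, y, z} once, then walks N consecutive elements by adding the innermost input strides, with a bounds check so the ragged tail of the row is handled. A plain C++ transcription of that inner loop (CPU-side, for illustration only):

#include <cstddef>

template <typename T, typename U, typename Op, int N = 4>
void binary_g_row_chunk(
    const T* a, const T* b, U* c,
    int xshape,        // shape[ndim - 1]
    size_t a_xstride,  // a_strides[ndim - 1]
    size_t b_xstride,  // b_strides[ndim - 1]
    size_t a_idx,      // location of element {N*x, y, z} in a
    size_t b_idx,      // location of element {N*x, y, z} in b
    size_t out_idx,    // N*x + xshape * (y + grid_y * z)
    int x) {           // thread index along the innermost axis
  for (int i = 0; i < N && (N * x + i) < xshape; ++i) {
    c[out_idx++] = Op()(a[a_idx], b[b_idx]);
    a_idx += a_xstride;  // step to the next element along the last axis
    b_idx += b_xstride;
  }
}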

View File

@@ -9,20 +9,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_integer(op) \ #define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -217,7 +217,7 @@ struct Power {
template <> template <>
complex64_t operator()(complex64_t x, complex64_t y) { complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real); auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta; auto phase = y.imag * x_ln_r + y.real * x_theta;
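The one-line change above matters for the complex Power operator: atan(imag / real) loses the quadrant (and divides by zero on the imaginary axis), while atan2(imag, real) recovers the full argument of x. The same polar-form computation in plain C++ for reference, with std::pow over std::complex used only as a cross-check:

#include <cmath>
#include <complex>
#include <cstdio>

// Complex power via polar form, matching the kernel's formula:
//   x^y = exp(y * log(x)),  log(x) = ln|x| + i*arg(x),  arg(x) = atan2(im, re)
std::complex<float> cpow(std::complex<float> x, std::complex<float> y) {
  float x_theta = std::atan2(x.imag(), x.real());  // full-quadrant argument
  float x_ln_r = 0.5f * std::log(x.real() * x.real() + x.imag() * x.imag());
  float mag = std::exp(y.real() * x_ln_r - y.imag() * x_theta);
  float phase = y.imag() * x_ln_r + y.real() * x_theta;
  return {mag * std::cos(phase), mag * std::sin(phase)};
}

int main() {
  std::complex<float> x(-2.0f, 0.5f), y(0.3f, 1.2f);  // x in the second quadrant
  auto a = cpow(x, y);
  auto b = std::pow(x, y);  // library reference
  std::printf("cpow: (%g, %g)  std::pow: (%g, %g)\n", a.real(), a.imag(), b.real(), b.imag());
}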

View File

@@ -118,7 +118,7 @@ template <typename T, typename U, typename Op>
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides); auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides); auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y; size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
@@ -137,32 +137,13 @@ template <typename T, typename U, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
} }
template <typename T, typename U, typename Op, int DIM> template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
@@ -174,9 +155,18 @@ template <typename T, typename U, typename Op>
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); auto idx = elem_to_loc_2_nd(
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto out = Op()(a[idx.x], b[idx.y]); auto xshape = shape[ndim - 1];
c[out_idx] = out[0]; size_t out_idx =
d[out_idx] = out[1]; N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx++] = out[1];
idx.x += a_xstride;
idx.y += b_xstride;
}
} }
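binary_two.h gets the same work-per-thread treatment as binary.h; the only difference is that Op() returns a pair, so each loop iteration writes both c and d at the same contiguous output offset. A minimal sketch of that two-output inner loop, using DivMod as an assumed example of such an op:

#include <array>
#include <cstdio>
#include <vector>

// Two-output inner loop: the op yields {quotient, remainder} and both results
// land at the same contiguous output index.
struct DivMod {
  std::array<int, 2> operator()(int a, int b) const { return {a / b, a % b}; }
};

int main() {
  const int N = 4, xshape = 7;
  std::vector<int> a = {9, 8, 7, 6, 5, 4, 3}, b(7, 2), c(7), d(7);
  size_t a_xstride = 1, b_xstride = 1;  // contiguous last axis for brevity
  for (int tx = 0; tx * N < xshape; ++tx) {
    size_t a_idx = N * tx * a_xstride, b_idx = N * tx * b_xstride;
    size_t out_idx = N * tx;
    for (int i = 0; i < N && N * tx + i < xshape; ++i) {
      auto out = DivMod()(a[a_idx], b[b_idx]);
      c[out_idx] = out[0];
      d[out_idx++] = out[1];
      a_idx += a_xstride;
      b_idx += b_xstride;
    }
  }
  std::printf("c[0]=%d d[0]=%d c[6]=%d d[6]=%d\n", c[0], d[0], c[6], d[6]);  // 4 1 1 1
}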

View File

@@ -7,20 +7,18 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h" #include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \

View File

@@ -113,6 +113,7 @@ template <typename T, int N>
for (int i = N - 1; i >= 0; --i) { for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]); int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]); int ws_ = (wS % params->wS[i]);
out += ws_ * kernel_stride;
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
@@ -126,7 +127,6 @@ template <typename T, int N>
oS /= params->oS[i]; oS /= params->oS[i];
wS /= params->wS[i]; wS /= params->wS[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i]; kernel_stride *= params->wS[i];
} }
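The two-line move above changes which coordinate feeds the weight offset: out now accumulates the raw ws_ before the flip, and the flipped value is only used afterwards. A hedged C++ sketch of the index decomposition with an optional flip, to make the ordering concrete (which side consumes the flipped coordinate is inferred from the diff, not confirmed; the surrounding conv code is not reproduced):

#include <cstdio>
#include <vector>

// Decompose a flat weight-spatial index wS into per-axis coordinates, building
// the offset from the raw coordinate and keeping the (optionally flipped)
// coordinate separately -- mirroring the ordering after the fix.
int main() {
  std::vector<int> wS_shape = {3, 3};  // e.g. a 3x3 kernel
  bool flip = true;
  int wS = 5;                          // flat index into the kernel
  int out = 0, kernel_stride = 1;
  std::vector<int> flipped_coord(wS_shape.size());
  for (int i = int(wS_shape.size()) - 1; i >= 0; --i) {
    int ws_ = wS % wS_shape[i];
    out += ws_ * kernel_stride;                               // offset uses the raw index
    flipped_coord[i] = flip ? wS_shape[i] - ws_ - 1 : ws_;    // flipped index kept for the other side
    wS /= wS_shape[i];
    kernel_stride *= wS_shape[i];
  }
  std::printf("offset = %d, flipped coords = (%d, %d)\n", out, flipped_coord[0], flipped_coord[1]);
}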

View File

@@ -71,21 +71,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int DIM> template <typename T, typename U, int N = 1>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g( [[kernel]] void copy_g(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -94,10 +80,22 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto src_idx = elem_to_loc(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx = int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
src_idx += src_xstride;
}
} }
template <typename T, typename U> template <typename T, typename U>
@@ -136,20 +134,7 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int DIM> template <typename T, typename U, int N = 1>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg( [[kernel]] void copy_gg(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -158,7 +143,22 @@ template <typename T, typename U>
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto idx = elem_to_loc_2_nd(
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); {N * index.x, index.y, index.z},
dst[dst_idx] = static_cast<U>(src[src_idx]); src_shape,
src_strides,
dst_strides,
ndim);
if (N == 1) {
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);
idx.x += src_xstride;
idx.y += dst_xstride;
}
} }
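copy_g and copy_gg follow the same consolidation: the fixed-dimension copy_g_nd / copy_gg_nd kernels disappear, and a work-per-thread N walks the innermost axis, advancing the source (and, for copy_gg, the destination) by that axis' stride. A small C++ sketch of the copy_gg inner loop, with an illustrative transposed destination layout:

#include <cstddef>
#include <cstdio>
#include <vector>

// copy_gg-style inner loop: both source and destination are strided, so the
// loop advances two flat offsets by their respective last-axis strides.
int main() {
  const int N = 4;
  std::vector<float> src = {0, 1, 2, 3, 4, 5};  // logical shape {2, 3}
  std::vector<float> dst(6, -1.0f);
  size_t src_strides[2] = {3, 1};  // row-major source
  size_t dst_strides[2] = {1, 2};  // transposed destination layout
  int xshape = 3;                  // innermost axis length
  for (int y = 0; y < 2; ++y) {
    for (int tx = 0; tx * N < xshape; ++tx) {
      size_t src_idx = y * src_strides[0] + size_t(N * tx) * src_strides[1];
      size_t dst_idx = y * dst_strides[0] + size_t(N * tx) * dst_strides[1];
      for (int i = 0; i < N && N * tx + i < xshape; ++i) {
        dst[dst_idx] = src[src_idx];
        src_idx += src_strides[1];
        dst_idx += dst_strides[1];
      }
    }
  }
  for (float v : dst) std::printf("%g ", v);  // 0 3 1 4 2 5
  std::printf("\n");
}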

View File

@@ -16,12 +16,8 @@
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \ instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \ instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -8,6 +8,7 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) { inline U load_vector(const device T* x, thread U* x_thread) {
@@ -371,6 +372,64 @@ struct QuantizedBlockLoader {
} }
}; };
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = D / QUAD_SIZE;
constexpr int packs_per_thread = values_per_thread / pack_factor;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int results_per_quadgroup = 8;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
U s = sl[0];
U b = bl[0];
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
for (int row = 0; row < results_per_quadgroup; row++) {
result[row] = quad_sum(result[row]);
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
y[row * quads_per_simd] = static_cast<T>(result[row]);
}
}
}
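qmv_quad_impl is new: for small input dimensions (D = 64 or 128) it works on 4-thread quadgroups instead of full 32-thread simdgroups, so one simdgroup covers 8 quadgroups and each quadgroup accumulates results_per_quadgroup = 8 output rows. A scalar C++ model of the per-row partitioning and the quad reduction (dequantization is simplified to scale * q + bias with a single group covering the row; the names mirror the kernel but the arithmetic is not the packed 4-bit path):

#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar model of a 4-thread quadgroup computing y[row] = dot(dequant(w[row]), x)
// for a small input dimension D, with each "thread" owning D/4 consecutive values.
int main() {
  const int D = 64, QUAD_SIZE = 4, values_per_thread = D / QUAD_SIZE;
  const float scale = 0.25f, bias = -1.0f;  // one group covers the row here
  std::vector<uint8_t> w(D);                // quantized codes, one per byte for clarity
  std::vector<float> x(D);
  for (int i = 0; i < D; ++i) { w[i] = uint8_t(i % 16); x[i] = 0.5f; }

  float partial[QUAD_SIZE] = {0, 0, 0, 0};
  for (int lane = 0; lane < QUAD_SIZE; ++lane) {           // quad_lid
    for (int j = 0; j < values_per_thread; ++j) {
      int k = lane * values_per_thread + j;
      partial[lane] += (scale * float(w[k]) + bias) * x[k];  // the qdot step, unrolled
    }
  }
  float result = partial[0] + partial[1] + partial[2] + partial[3];  // quad_sum
  std::printf("y[row] = %g\n", result);
}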
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl( METAL_FUNC void qmv_fast_impl(
const device uint32_t* w, const device uint32_t* w,
@@ -586,10 +645,10 @@ METAL_FUNC void qmv_impl(
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl( METAL_FUNC void qvm_impl(
const device T* x,
const device uint32_t* w, const device uint32_t* w,
const device T* scales, const device T* scales,
const device T* biases, const device T* biases,
const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
const constant int& out_vec_size, const constant int& out_vec_size,
@@ -697,16 +756,16 @@ template <
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
METAL_FUNC void qmm_t_impl( METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w, const device uint32_t* w,
const device T* scales, const device T* scales,
const device T* biases, const device T* biases,
const device T* x,
device T* y, device T* y,
threadgroup T* Xs, threadgroup T* Xs,
threadgroup T* Ws, threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K, const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -818,16 +877,16 @@ template <
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
METAL_FUNC void qmm_n_impl( METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w, const device uint32_t* w,
const device T* scales, const device T* scales,
const device T* biases, const device T* biases,
const device T* x,
device T* y, device T* y,
threadgroup T* Xs, threadgroup T* Xs,
threadgroup T* Ws, threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K, const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -942,6 +1001,45 @@ METAL_FUNC void qmm_n_impl(
} }
} }
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
const device T*& biases,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
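The new adjust_matrix_offsets overload is what makes the quantized kernels batched: it advances x by its own (possibly multi-dimensional) batch strides, advances w/scales/biases by a broadcast location, and bumps y by one output matrix per threadgroup in z. A host-side C++ model of that offset computation, with elem_to_loc written out explicitly as the usual flat-index-to-offset helper:

#include <cstddef>
#include <cstdio>
#include <vector>

// Flat batch index -> byte-free element offset for a shape/strides pair.
size_t elem_to_loc(size_t idx, const std::vector<int>& shape, const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = int(shape.size()) - 1; i >= 0; --i) {
    loc += (idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}

int main() {
  // Batch of 6 = 2 x 3 activations; the weight batch is broadcast along the first axis.
  std::vector<int> x_shape = {2, 3}, w_shape = {1, 3};
  std::vector<size_t> x_strides = {3 * 1024, 1024}, w_strides = {0, 256};
  size_t out_stride = 512;  // out_vec_size (or M * N for the matmul kernels)
  for (size_t tid_z = 0; tid_z < 6; ++tid_z) {
    size_t x_off = elem_to_loc(tid_z, x_shape, x_strides);
    size_t w_off = elem_to_loc(tid_z, w_shape, w_strides);  // broadcast via the 0 stride
    size_t y_off = tid_z * out_stride;
    std::printf("tid.z=%zu  x+=%zu  w+=%zu  y+=%zu\n", tid_z, x_off, w_off, y_off);
  }
}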
template <typename T> template <typename T>
METAL_FUNC void adjust_matrix_offsets( METAL_FUNC void adjust_matrix_offsets(
const device T*& x, const device T*& x,
@@ -996,7 +1094,58 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride; y += tid.z * output_stride;
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_quad_impl<T, group_size, bits, D>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast( [[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@@ -1005,9 +1154,35 @@ template <typename T, int group_size, int bits>
device T* y [[buffer(4)]], device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_fast_impl<T, group_size, bits>( qmv_fast_impl<T, group_size, bits>(
w, w,
scales, scales,
@@ -1021,7 +1196,7 @@ template <typename T, int group_size, int bits>
simd_lid); simd_lid);
} }
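Each quantized kernel above now takes a bool batched template parameter plus the full set of batch shape/stride buffers; when batched is true it calls adjust_matrix_offsets before the existing *_impl, and when false the extra arguments are ignored so the single-matrix path is unchanged. On the host side the choice reduces to picking the right instantiation name; a hypothetical sketch of that selection (illustrative only, not MLX's actual dispatch code):

#include <cstdio>
#include <string>

// Hypothetical host-side helper: pick between the "_batch_1" and "_batch_0"
// instantiations of a quantized kernel, mirroring the new template parameter.
std::string quantized_kernel_name(const std::string& base, bool has_batch) {
  return base + (has_batch ? "_batch_1" : "_batch_0");
}

int main() {
  std::printf("%s\n", quantized_kernel_name("qmv_fast_float16_t_gs_64_b_4", true).c_str());
}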
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv( [[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@@ -1030,9 +1205,35 @@ template <typename T, const int group_size, const int bits>
device T* y [[buffer(4)]], device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_impl<T, group_size, bits>( qmv_impl<T, group_size, bits>(
w, w,
scales, scales,
@@ -1046,23 +1247,49 @@ template <typename T, const int group_size, const int bits>
simd_lid); simd_lid);
} }
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm( [[kernel]] void qvm(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qvm_impl<T, group_size, bits>( qvm_impl<T, group_size, bits>(
x,
w, w,
scales, scales,
biases, biases,
x,
y, y,
in_vec_size, in_vec_size,
out_vec_size, out_vec_size,
@@ -1076,18 +1303,27 @@ template <
const int group_size, const int group_size,
const int bits, const int bits,
const bool aligned_N, const bool aligned_N,
const bool batched,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void qmm_t( [[kernel]] void qmm_t(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]], const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]], const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]], const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1099,26 +1335,53 @@ template <
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded]; threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>( qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template < template <
typename T, typename T,
const int group_size, const int group_size,
const int bits, const int bits,
const bool batched,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void qmm_n( [[kernel]] void qmm_n(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
device T* y [[buffer(4)]], device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]], const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]], const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]], const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1131,8 +1394,27 @@ template <
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded]; threadgroup T Ws[BK * BN_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_n_impl<T, group_size, bits, BM, BK, BN>( qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
@@ -1141,23 +1423,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]], device T* y [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
device T* y [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]], const constant int& x_batch_ndims [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]], const constant int* x_shape [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]], const constant size_t* x_strides [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int& w_batch_ndims [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]], const constant int* w_shape [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]], const constant size_t* w_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]], const constant size_t* s_strides [[buffer(13)]],
const constant int* x_shape [[buffer(14)]], const constant size_t* b_strides [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]], const constant int& batch_ndims [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]], const constant int* batch_shape [[buffer(16)]],
const constant int* w_shape [[buffer(17)]], const device uint32_t* lhs_indices [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]], const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]], const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]], const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1202,23 +1484,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]], const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]], device T* y [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
device T* y [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]], const constant int& x_batch_ndims [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]], const constant int* x_shape [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]], const constant size_t* x_strides [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int& w_batch_ndims [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]], const constant int* w_shape [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]], const constant size_t* w_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]], const constant size_t* s_strides [[buffer(13)]],
const constant int* x_shape [[buffer(14)]], const constant size_t* b_strides [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]], const constant int& batch_ndims [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]], const constant int* batch_shape [[buffer(16)]],
const constant int* w_shape [[buffer(17)]], const device uint32_t* lhs_indices [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]], const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]], const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]], const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1259,27 +1541,27 @@ template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits> template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm( [[kernel]] void bs_qvm(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]], device T* y [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]], const constant int& in_vec_size [[buffer(5)]],
device T* y [[buffer(6)]], const constant int& out_vec_size [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]], const constant int& x_batch_ndims [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]], const constant int* x_shape [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]], const constant size_t* x_strides [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int& w_batch_ndims [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]], const constant int* w_shape [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]], const constant size_t* w_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]], const constant size_t* s_strides [[buffer(13)]],
const constant int* x_shape [[buffer(14)]], const constant size_t* b_strides [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]], const constant int& batch_ndims [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]], const constant int* batch_shape [[buffer(16)]],
const constant int* w_shape [[buffer(17)]], const device uint32_t* lhs_indices [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]], const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]], const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]], const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1306,10 +1588,10 @@ template <typename T, int group_size, int bits>
b_strides, b_strides,
tid); tid);
qvm_impl<T, group_size, bits>( qvm_impl<T, group_size, bits>(
x,
w, w,
scales, scales,
biases, biases,
x,
y, y,
in_vec_size, in_vec_size,
out_vec_size, out_vec_size,
@@ -1327,28 +1609,28 @@ template <
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_t( [[kernel]] void bs_qmm_t(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]], device T* y [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]], const constant int& K [[buffer(5)]],
device T* y [[buffer(6)]], const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]], const constant int& x_batch_ndims [[buffer(8)]],
const constant int& K [[buffer(9)]], const constant int* x_shape [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]], const constant size_t* x_strides [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]], const constant int& w_batch_ndims [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]], const constant int* w_shape [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]], const constant size_t* w_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]], const constant size_t* s_strides [[buffer(14)]],
const constant int* x_shape [[buffer(15)]], const constant size_t* b_strides [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(17)]],
const constant int* w_shape [[buffer(18)]], const device uint32_t* lhs_indices [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]], const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]], const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]], const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1383,7 +1665,7 @@ template <
b_strides, b_strides,
tid); tid);
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>( qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template < template <
@@ -1394,28 +1676,28 @@ template <
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void bs_qmm_n( [[kernel]] void bs_qmm_n(
const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(1)]],
const device T* scales [[buffer(2)]], const device T* biases [[buffer(2)]],
const device T* biases [[buffer(3)]], const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]], device T* y [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]], const constant int& K [[buffer(5)]],
device T* y [[buffer(6)]], const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]], const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]], const constant int& x_batch_ndims [[buffer(8)]],
const constant int& K [[buffer(9)]], const constant int* x_shape [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]], const constant size_t* x_strides [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]], const constant int& w_batch_ndims [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]], const constant int* w_shape [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]], const constant size_t* w_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]], const constant size_t* s_strides [[buffer(14)]],
const constant int* x_shape [[buffer(15)]], const constant size_t* b_strides [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]], const constant int& batch_ndims [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]], const constant int* batch_shape [[buffer(17)]],
const constant int* w_shape [[buffer(18)]], const device uint32_t* lhs_indices [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]], const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]], const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]], const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]], uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1451,17 +1733,17 @@ template <
b_strides, b_strides,
tid); tid);
qmm_n_impl<T, group_size, bits, BM, BK, BN>( qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
} }
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize( METAL_FUNC void affine_quantize_impl(
const device T* w [[buffer(0)]], const device T* w,
device uint8_t* out [[buffer(1)]], device uint8_t* out,
device T* scales [[buffer(2)]], device T* scales,
device T* biases [[buffer(3)]], device T* biases,
uint2 index [[thread_position_in_grid]], uint2 index,
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim) {
constexpr T eps = T(1e-7); constexpr T eps = T(1e-7);
constexpr int simd_size = 32; constexpr int simd_size = 32;
constexpr int uint8_bits = 8; constexpr int uint8_bits = 8;
@@ -1538,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
} }
} }
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
affine_quantize_impl<T, group_size, bits>(
w, out, scales, biases, index, grid_dim);
}
template <typename T, const int group_size, const int bits> template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases( [[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]], const device T* w [[buffer(0)]],
@@ -1601,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
out[oindex + i] = scale * d + bias; out[oindex + i] = scale * d + bias;
} }
} }
template <typename T, const int group_size, const int bits>
[[kernel]] void kv_update(
const device T* new_keys [[buffer(0)]],
const device T* new_values [[buffer(1)]],
device uint8_t* keys [[buffer(2)]],
device T* key_scales [[buffer(3)]],
device T* key_biases [[buffer(4)]],
device uint8_t* values [[buffer(5)]],
device T* value_scales [[buffer(6)]],
device T* value_biases [[buffer(7)]],
const constant int& offset,
const constant int& batch_stride,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// Compute the packed offset into the quantized cache for this write position
// and batch row (handling of the head dimension is still to be added).
constexpr int pack_factor = 8 / bits;
uint batch_idx = index.y * batch_stride * 4 + offset;
new_keys += index.y * 128;
new_values += index.y * 128;
// Index to the correct slice of the cache and the matching scale/bias group.
uint group_idx = batch_idx * pack_factor / group_size;
keys += batch_idx;
key_scales += group_idx;
key_biases += group_idx;
values += batch_idx;
value_scales += group_idx;
value_biases += group_idx;
uint2 new_index = {index.x, 0};
affine_quantize_impl<T, group_size, bits>(
new_keys, keys, key_scales, key_biases, new_index, grid_dim);
affine_quantize_impl<T, group_size, bits>(
new_values, values, value_scales, value_biases, new_index, grid_dim);
}
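kv_update ties the pieces together for a quantized KV cache: it offsets into the packed key/value buffers and the matching scale/bias groups for the current position, then reuses affine_quantize_impl to quantize the new key and value rows in place. A rough C++ model of the offset bookkeeping only; it mirrors the arithmetic as written in this in-progress kernel (including the *4 factor), and the batch_stride / offset values are made up for illustration:

#include <cstdio>

// Offset bookkeeping for writing one new token into a quantized KV cache:
// the packed byte offset into keys/values, and the group offset into the
// scales/biases, for a given write position and batch row.
int main() {
  const int bits = 4, group_size = 64;
  const int pack_factor = 8 / bits;  // elements packed per byte
  int batch_stride = 4096;           // packed bytes per batch row (assumed)
  int offset = 32;                   // packed byte offset of the write position (assumed)
  for (int batch = 0; batch < 2; ++batch) {
    int batch_idx = batch * batch_stride * 4 + offset;     // as in the kernel above
    int group_idx = batch_idx * pack_factor / group_size;  // scale/bias group index
    std::printf("batch %d: byte offset %d, scale/bias group %d\n", batch, batch_idx, group_idx);
  }
}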

View File

@@ -5,67 +5,104 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h" #include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \ #define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \ #name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \ name, \
type, \ type, \
group_size, \ group_size, \
bits) bits)
#define instantiate_quantized_types(name, group_size, bits) \ #define instantiate_quantized_batched(name, type, group_size, bits, batched) \
instantiate_quantized(name, float, group_size, bits) \ instantiate_kernel( \
instantiate_quantized(name, float16_t, group_size, bits) \ #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
instantiate_quantized(name, bfloat16_t, group_size, bits) name, \
type, \
group_size, \
bits, \
batched)
#define instantiate_quantized_groups(name, bits) \ #define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_quantized_types(name, 128, bits) \ instantiate_kernel( \
instantiate_quantized_types(name, 64, bits) \ #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
instantiate_quantized_types(name, 32, bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \ name, \
type, \ type, \
group_size, \ group_size, \
bits, \ bits, \
aligned) aligned)
#define instantiate_quantized_types_aligned(name, group_size, bits) \ #define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \ instantiate_kernel( \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \ #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \ name, \
instantiate_quantized_aligned(name, float, group_size, bits, false) \ type, \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \ group_size, \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false) bits, \
aligned, \
batched)
#define instantiate_quantized_groups_aligned(name, bits) \ #define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
instantiate_quantized_types_aligned(name, 128, bits) \ instantiate_kernel( \
instantiate_quantized_types_aligned(name, 64, bits) \ #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
instantiate_quantized_types_aligned(name, 32, bits) name, \
type, \
group_size, \
bits, \
D, \
batched)
#define instantiate_quantized_all_aligned(name) \ #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_groups_aligned(name, 2) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_groups_aligned(name, 4) \ instantiate_quantized_batched(name, type, group_size, bits, 0)
instantiate_quantized_groups_aligned(name, 8) \
instantiate_quantized_all_aligned(qmm_t) #define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on
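The rewritten instantiation file inverts the macro nesting: instead of one macro family per kernel, instantiate_quantized_funcs collects everything for a (type, group_size, bits) triple, and the outer loops expand it over {2, 4, 8} bits, {32, 64, 128} group sizes, and {float, float16_t, bfloat16_t}. A quick count of what one expansion produces, written as plain C++ arithmetic over the macro bodies above:

#include <cstdio>

// Variant count per (type, group_size, bits) triple, read off the macros:
//   single:  3 quantize/dequantize kernels + 4 bs_* kernels                  = 7
//   batched: 4 kernels x {batch 0, 1}                                        = 8
//   aligned: bs_qmm_t x {true,false} + qmm_t x {true,false} x {batch 0,1}    = 6
//   quad:    qmv_quad x {D 64,128} x {batch 0,1}                             = 4
int main() {
  int per_triple = 7 + 8 + 6 + 4;
  int total = per_triple * 3 /*bit widths*/ * 3 /*group sizes*/ * 3 /*types*/;
  std::printf("%d variants per triple, %d total instantiations\n", per_triple, total);
}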

View File

@@ -69,9 +69,9 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
device char* out, device char* out,
device const bool& odd, device const bool& odd,
device const uint& bytes_per_key, device const uint& bytes_per_key,
device const int& ndim, constant const int& ndim,
device const int* key_shape, constant const int* key_shape,
device const size_t* key_strides, constant const size_t* key_strides,
uint2 grid_dim [[threads_per_grid]], uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x; auto kidx = 2 * index.x;

View File

@@ -1,11 +1,11 @@
#include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal; using namespace metal;
using namespace mlx::steel; using namespace mlx::steel;
@@ -886,6 +886,9 @@ template <
} }
} }
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \ #define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \ itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
@@ -922,548 +925,42 @@ instantiate_fast_inference_self_attention_kernel(
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
template < // SDPA vector instantiations
typename T, #define instantiate_sdpa_vector(type, head_dim) \
typename T2, instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim)
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; #define instantiate_sdpa_vector_heads(type) \
#pragma clang loop unroll(full) instantiate_sdpa_vector(type, 64) \
for (uint i = 0; i < 8; i++) { instantiate_sdpa_vector(type, 96) \
smemFlush instantiate_sdpa_vector(type, 128)
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; instantiate_sdpa_vector_heads(float)
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; instantiate_sdpa_vector_heads(bfloat16_t)
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; instantiate_sdpa_vector_heads(float16_t)
const device T* baseK = // Quantized SDPA vector instantiations
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; #define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
const device T* baseQ = Q + tgroup_query_head_offset; instantiate_kernel( \
"quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector, type, head_dim, group_size, bits)
device T4* simdgroupQueryData = (device T4*)baseQ; #define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; #define instantiate_quant_sdpa_vector_group_size(type, heads) \
float threadAccum[ACCUM_PER_GROUP]; instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)
#pragma clang loop unroll(full) #define instantiate_quant_sdpa_vector_heads(type) \
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; instantiate_quant_sdpa_vector_group_size(type, 64) \
threadAccumIndex++) { instantiate_quant_sdpa_vector_group_size(type, 96) \
threadAccum[threadAccumIndex] = -INFINITY; instantiate_quant_sdpa_vector_group_size(type, 128)
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; instantiate_quant_sdpa_vector_heads(float)
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; instantiate_quant_sdpa_vector_heads(bfloat16_t)
const bool LAST_TILE_ALIGNED = instantiate_quant_sdpa_vector_heads(float16_t)
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4; // clang-format on
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
}
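// At this point threadAccum holds, per lane, the partial dot product of q
// with one key row (over that lane's 4-element slice of the head dimension)
// for this simdgroup's rows of the key tile; the simd_sum below completes
// each score across the simdgroup and stages it in threadgroup memory.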
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
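// pvals now holds this tile's softmax probabilities (normalized within the
// tile only); groupMax and lse are kept so the reduce_tiles kernel can
// renormalize the partial outputs across tiles at the end.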
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
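Taken together, for its key/value tile the kernel above writes three partial quantities per (batch, query head, tile): a row max to p_maxes, a log-sum-exp to p_lse, and a tile-local attention output to O_partials. As a compact restatement of the code (notation introduced here, with s_j the INV_ALPHA-scaled score q·k_j over the keys of tile t):

$$ m_t = \max_j s_j, \qquad \mathrm{lse}_t = \log \sum_j e^{\,s_j - m_t}, \qquad O_t = \sum_j \frac{e^{\,s_j - m_t}}{\sum_{j'} e^{\,s_{j'} - m_t}} \, v_j. $$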
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// Reserve a fixed number of registers for the per-tile statistics; this
// assumes params.KV_TILES never exceeds 128.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
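The weight/denominator loops above are a log-sum-exp merge of those per-tile partials. With m_max = max_t m_t, the code forms w_t = exp(lse_t) * exp(m_t - m_max) and emits, per output element (notation as in the note after the partials kernel):

$$ O = \frac{\sum_t w_t \, O_t}{\sum_t w_t}, \qquad w_t = e^{\,\mathrm{lse}_t + m_t - m_{\max}}, $$

which equals the softmax over the full key sequence applied to V, since w_t O_t = \sum_{j \in t} e^{s_j - m_max} v_j and \sum_t w_t = \sum_j e^{s_j - m_max}.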
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
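For reference, the [[host_name]] pattern above means the host side selects a partials kernel by dtype, KV tile size, and simdgroup count, and a reduce kernel by dtype alone. The sketch below only assembles those names; the helper functions are illustrative and are not MLX's actual dispatch code.

#include <string>

// Illustrative only: builds kernel names matching the [[host_name]] pattern
// instantiated above (dtype in {"float", "half"}, tile_size in
// {64, 128, 256, 512}, nsimdgroups in {4, 8}).
std::string sdpa_partials_kernel_name(
    const std::string& dtype,
    int tile_size,
    int nsimdgroups) {
  return "fast_inference_sdpa_compute_partials_" + dtype + "_" +
      std::to_string(tile_size) + "_" + std::to_string(nsimdgroups);
}

std::string sdpa_reduce_kernel_name(const std::string& dtype) {
  return "fast_inference_sdpa_reduce_tiles_" + dtype;
}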

Some files were not shown because too many files have changed in this diff.