Compare commits

..

100 Commits

Author SHA1 Message Date
Awni Hannun
bb303c45a5 version (#1617) 2024-11-22 12:00:03 -08:00
Alex Barron
6f7986d592 Cleaner qmv/qvm (#1616) 2024-11-22 11:14:08 -08:00
Awni Hannun
7cbb4aef17 Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Jagrit Digani
02bec0bb6d Matrix Attention kernel (#1610)
* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Alex Barron
c79f6a4a8c 3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
Awni Hannun
0c5eea226b Reduce specializations (#1607)
* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
2024-11-21 19:53:00 -08:00
Awni Hannun
dcca0d7477 contiguous op / prim (#1612) 2024-11-21 19:51:49 -08:00
Cocoa
0d5e7716ad fix typo: accross -> across (#1609)
Signed-off-by: Cocoa <i@uwucocoa.moe>
2024-11-20 15:30:51 -08:00
Angelos Katharopoulos
d8c824c594 Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
Awni Hannun
61d787726a Fix view scalar bug segfault (#1603)
* fix view scalar bug

* fix view scalar bug

* one more fix
2024-11-19 10:54:05 -08:00
Angelos Katharopoulos
5e89aace9b Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
2af7e8a9a6 fix cmake version (#1601) 2024-11-19 08:45:05 -08:00
Awni Hannun
2419edd5b2 Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
2024-11-18 19:52:00 -08:00
Awni Hannun
bf481e8e5d Fix sibling leak (#1590)
* add test

* fix + test

* fix fix
2024-11-18 19:17:01 -08:00
Awni Hannun
9d7fa6b8e6 Use osx deployment target to pick Metal version (#1595)
* choose metal based on deployment target rather than system version

* nit

* unused compile def
2024-11-18 19:16:49 -08:00
Angelos Katharopoulos
073076ac7d 2-Pass Sdpa Inference Kernel (#1597) 2024-11-18 17:31:53 -08:00
Awni Hannun
9bd03dd9b4 More buffer donation with no-ops (#1591)
* more donation

* fix test

* fix build
2024-11-18 08:35:41 -08:00
Awni Hannun
6931f84412 fix dispatch threads for a few kernels (#1594) 2024-11-18 08:35:25 -08:00
xnorai
16ec0556a0 Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)
* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB

* Set the upper size limit for the header to 100K as in Rust safetensors
2024-11-18 07:22:51 -08:00
Awni Hannun
610af352d4 Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT

* fix extension

* fix extension build

* fix extension build

* Update utils.h
2024-11-15 16:54:36 -08:00
Awni Hannun
b35f1e3c9c fix donation in sdpa (#1587) 2024-11-13 17:21:13 -08:00
Awni Hannun
dfa0b9aab4 Cpu fast quantize (#1578)
* cpu quantize

* fix
2024-11-08 20:10:39 -08:00
Alex Barron
a4c47b0276 OOB QMV fix (#1579)
* fix oob access in qmv

* skip more

* fix small case
2024-11-08 17:59:45 -08:00
Alex Barron
111fefd5e9 Fix OOB access in qmv (#1577)
* fix oob access in qmv

* skip more
2024-11-08 15:41:30 -08:00
Awni Hannun
c1fe1ef081 Bfs width limit (#1568)
* width limit

* fix

* large limit

* put env vars in env namespace
2024-11-08 15:00:46 -08:00
Awni Hannun
8c34c9dac4 throw for invalid case and remove test (#1575) 2024-11-08 12:04:03 -08:00
Awni Hannun
91c0277356 fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
Awni Hannun
9f0d5c12fc Fully wrap the command encoder (#1572)
* fully wrap the command encoder

* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Awni Hannun
59247c2b62 add groups in conv2d (#1569) 2024-11-07 13:57:53 -08:00
Awni Hannun
9a3842a2d9 fix (#1566) 2024-11-06 17:10:33 -08:00
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Alex Barron
26be608470 Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c Reductions update (#1351) 2024-11-04 22:25:16 -08:00
Awni Hannun
76f275b4df error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Awni Hannun
f1951d6cce Use fewer barriers (#1561)
* use fewer barriers

* comment
2024-11-04 10:26:49 -08:00
Angelos Katharopoulos
62f297b51d Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Chris Offner
46d8b16ab4 Fix vmap example in docs (#1556) 2024-11-02 17:44:14 -07:00
Chris Offner
42533931fa Fix typo "it's" -> "its" (#1555) 2024-11-02 06:06:34 -07:00
Awni Hannun
9bd3a7102f add python 3.13 to circle (#1553) 2024-11-01 20:55:35 -07:00
Alex Barron
9e516b71ea Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
eac961ddb1 patch (#1550) 2024-10-31 16:10:14 -07:00
Awni Hannun
57c6aa7188 fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
Awni Hannun
cde5b4ad80 patch (#1546) 2024-10-30 19:31:22 -07:00
Awni Hannun
4f72c66911 improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Jagrit Digani
960e3f0f05 Gemm update (#1518) 2024-10-30 19:30:28 -07:00
Awni Hannun
884af42da2 Fix thread group for large arrays (#1543)
* fix thread group for large arrays

* comment

* one more
2024-10-30 16:25:12 -07:00
Alex Barron
048fabdabd Fix vmap constant output size (#1524)
* use inputs to determine output size

* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Léo
917252a5a1 Add favicon to docs (#1545)
* add sphinx's html_favicon config

* removed unneeded newline

* ran pre-commit hooks
2024-10-30 13:54:13 -07:00
Carlo Cabrera
1a992e31e8 Skip using Residency sets in VMs (#1537)
* Skip using Residency sets in VMs

Attempting to use residency sets in a VM throws[^1]

    libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to construct residency set.

Not quite sure if this is the best fix, but it does make the error go
away.

Note that it was previously possible to run simple programs that used
mlx in a VM prior to 0eb56d5be0. See
related discussion at Homebrew/homebrew-core#195627.

[^1]: https://github.com/Homebrew/homebrew-core/actions/runs/11525831492/job/32105148462#step:3:56

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* change residency check

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-29 19:37:23 -07:00
Awni Hannun
d2ff04a4f2 fix format (#1539) 2024-10-28 18:29:14 -07:00
Awni Hannun
015c247393 change wino dispatch conditoin (#1534) 2024-10-28 11:13:44 -07:00
Awni Hannun
d3cd26820e Faster bits and bernoulli (#1535)
* faster bits and bernoulli

* fix bernoulli
2024-10-28 11:11:00 -07:00
Awni Hannun
91f6c499d7 fix (#1529) 2024-10-25 19:25:35 -07:00
Awni Hannun
35e9c87ab9 patch bump (#1528) 2024-10-25 13:13:23 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
c21331d47f version bump (#1477) 2024-10-10 13:05:17 -07:00
Awni Hannun
e1c9600da3 Add mx.random.permutation (#1471)
* random permutation

* comment
2024-10-08 19:42:19 -07:00
Awni Hannun
1fa0d20a30 consistently handle all -inf in softmax (#1470) 2024-10-08 09:54:02 -07:00
Awni Hannun
3274c6a087 Fix array is_available race cases (#1468) 2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739 Add the roll op (#1455) 2024-10-07 17:21:42 -07:00
Awni Hannun
f374b6ca4d Bump nanobind to 2.2 (#1461)
* bump nanobind

* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40 Fix deep recursion with siblings (#1462)
* fix recursion with siblings

* fix

* add test

* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
95d04805b3 Fix complex power on Metal (#1460) 2024-10-06 19:58:30 -07:00
Awni Hannun
e4534dac17 Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
fef3c4ec1d Fix mpi test in CI (#1456)
* Fix mpi test in CI

* Set bind to none
2024-10-06 06:09:17 -07:00
Awni Hannun
1bdc038bf9 fix argpartition + faster {arg} sorts / partitions (#1453) 2024-10-03 14:21:25 -07:00
Awni Hannun
5523d9c426 faster cpu indexing (#1450) 2024-10-03 13:53:47 -07:00
Angelos Katharopoulos
d878015228 Fix normalization check_input (#1452) 2024-10-03 13:26:56 -07:00
Cheng
5900e3249f Fix building on Linux (#1446) 2024-09-30 07:00:39 -07:00
Angelos Katharopoulos
bacced53d3 Fix row reduce with very few rows (#1447) 2024-09-29 20:00:35 -07:00
Lucas Newman
4a64d4bff1 Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
223 changed files with 11450 additions and 6409 deletions

View File

@@ -13,8 +13,62 @@ parameters:
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@@ -31,7 +85,7 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -77,13 +131,13 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@3.8
brew install python@3.9
brew install openmpi
python3.8 -m venv env
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install numpy
pip install torch
pip install tensorflow
@@ -105,7 +159,7 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
@@ -172,7 +226,7 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install twine
@@ -208,7 +262,7 @@ jobs:
- store_artifacts:
path: dist/
build_linux_test_release:
build_linux_release:
parameters:
python_version:
type: string
@@ -237,12 +291,13 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.1.0
pip install nanobind==2.2.0
pip install --upgrade setuptools
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
@@ -253,6 +308,11 @@ jobs:
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -272,6 +332,7 @@ workflows:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
@@ -288,9 +349,17 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
@@ -317,7 +386,7 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
@@ -328,17 +397,17 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_test_release:
- build_linux_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.18.0)
set(MLX_VERSION 0.21.0)
endif()
# --------------------- Processor tests -------------------------
@@ -47,17 +47,20 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
@@ -86,25 +89,27 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_VERSION} LESS 14.0)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
)
# Get the metal version
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
@@ -112,8 +117,6 @@ elseif(MLX_BUILD_METAL)
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if(MLX_BUILD_CPU)

View File

@@ -6,7 +6,7 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning research on Apple silicon,
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:

View File

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
dst[tuple(idx)] = x
mx.eval(dst)
idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
if device == torch.device("mps"):
torch.mps.synchronize()
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
@@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(200_000,),
(20_000_000,),
(10000, 64),
(100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -1,62 +1,189 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
from time_utils import time_fn
import numpy as np
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def bench(f, *args):
for i in range(N_warmup):
f(*args)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
time_self_attention_sdpa()
time_self_attention_primitives()
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -0,0 +1,58 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
loops = 10
def attention(q, k, v):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
def sdpa(q, k, v):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -60,6 +60,7 @@ html_theme_options = {
},
}
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output ---------------------------------------------

View File

@@ -1,3 +1,5 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
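For reference, a minimal sketch of calling such a kernel from Python, using the ``myexp`` source this page discusses; the input shape, template type, and threadgroup size below are illustrative and not taken from the diff:

.. code-block:: python

    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """
    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    a = mx.random.normal(shape=(4096,))
    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),        # mx.prod(grid) threads are launched
        threadgroup=(256, 1, 1),    # each dimension <= the corresponding grid dimension
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    out = outputs[0]
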
Using Shape/Strides

View File

@@ -494,7 +494,7 @@ below.
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots.
Metal kernel cache persists across reboots.
Troubleshooting
^^^^^^^^^^^^^^^
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again.
wipe your build cache with ``rm -rf build/`` and try again.
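Not part of the diff, but a quick way to run the architecture check described above from Python:

.. code-block:: python

    import platform

    # Expect "arm" on a native Apple silicon interpreter;
    # "i386" or "x86_64" suggests the shell or Python is running under Rosetta.
    print(platform.processor())
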

View File

@@ -12,5 +12,4 @@ Fast
layer_norm
rope
scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -16,3 +16,5 @@ Linear Algebra
cross
qr
svd
eigvalsh
eigh

View File

@@ -14,6 +14,7 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -12,6 +12,7 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
@@ -41,6 +42,7 @@ Layers
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU

View File

@@ -80,6 +80,7 @@ Operations
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
@@ -125,11 +126,13 @@ Operations
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save

View File

@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
truncated_normal
uniform
laplace
permutation

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python
def gelu(x):
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0)))
Compiling Training Graphs
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
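Pulling the fragments of the compiled training example above into one place, a self-contained sketch looks roughly like this; the model, loss, optimizer, and data are placeholders rather than the doc's exact values:

.. code-block:: python

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)                  # placeholder model
    optimizer = optim.SGD(learning_rate=0.1)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    # The state that will be captured as input and output
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    x = mx.random.uniform(shape=(8, 4))      # placeholder data
    y = mx.random.uniform(shape=(8, 1))
    loss = step(x, y)
    mx.eval(state, loss)                     # evaluate the updated parameters and the loss
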

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
function. To get the second derivative you can do:
.. code-block:: shell
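    # Sketch of the snippet this directive introduces (cut off by the hunk);
    # the printed value is approximate.
    >>> import mlx.core as mx
    >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
    >>> d2fdx2(mx.array(0.1))
    array(-0.0998334, dtype=float32)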
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
graphs.
.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
where the vectorized axes should be in the outputs.
Let's time these two different versions:
@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
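As a self-contained sketch of the vectorized version (the ``xs`` shape is inferred here as ``(4096, 100)`` so the shapes line up; treat it as illustrative):

.. code-block:: python

    import mlx.core as mx

    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    # Vectorize over axis 0 of xs and axis 1 of ys
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
    out = vmap_add(xs, ys)   # shape (4096, 100) with the default out_axes=0
    mx.eval(out)
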

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:
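A sketch of the kind of update the elided example shows:

.. code-block:: python

    import mlx.core as mx

    a = mx.array([1, 2, 3])
    a[2] = 0      # indexed assignment updates a in place
    print(a)      # array([1, 2, 0], dtype=int32)
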

View File

@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -109,14 +109,14 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
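For instance, a minimal illustration of implicit evaluation (not part of the diff):

.. code-block:: python

    import mlx.core as mx

    a = mx.array(1.0)
    b = mx.exp(a)    # only the graph is recorded, nothing is computed yet
    print(b)         # printing implicitly evaluates b
    v = b.item()     # item() on a scalar array also triggers evaluation
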

View File

@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
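A sketch of that round trip:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.arange(3.0)
    b = np.array(a)   # MLX -> NumPy via the buffer protocol (copies the data)
    c = mx.array(b)   # NumPy -> MLX
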
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
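In practice that advice amounts to something like the following sketch (assumes PyTorch is installed):

.. code-block:: python

    import mlx.core as mx
    import numpy as np
    import torch

    a = mx.arange(3.0)
    t = torch.tensor(np.array(a))   # cast to NumPy first, as advised above
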

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
gradient with respect to the function's input.
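A minimal sketch (function and values purely illustrative):

.. code-block:: python

    import mlx.core as mx

    def loss_fn(w):
        return mx.sum(mx.square(w))

    w = mx.array([1.0, 2.0, 3.0])
    loss, grad = mx.value_and_grad(loss_fn)(w)   # loss = 14, grad = 2 * w
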

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
- Multiple arrays
* - Safetensors
- ``.safetensors``
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
:func:`load` depends on the format.
Here's an example of saving a single array to a file:
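A sketch of the single-array case:

.. code-block:: python

    import mlx.core as mx

    a = mx.array([1.0])
    mx.save("array", a)      # writes array.npy
    b = mx.load("array.npy")
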

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
without needing to move them from one memory location to another. For example:
.. code-block:: python
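    # Sketch of the elided snippet: the device is chosen per operation via a stream
    import mlx.core as mx

    a = mx.random.normal((100,))
    b = mx.random.normal((100,))

    mx.add(a, b, stream=mx.cpu)   # run this add on the CPU
    mx.add(a, b, stream=mx.gpu)   # and this one on the GPU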

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim if needed
if (!contiguous_kernel) {
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
}
// We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
#else // Metal is not available

View File

@@ -2,7 +2,6 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T>
@@ -60,4 +59,4 @@ template <typename T>
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
instantiate_axpby(complex64, complex64_t);

View File

@@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.17.0",
"nanobind==2.1.0",
"mlx>=0.18.0",
"nanobind==2.2.0",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.17.0
nanobind==2.1.0
mlx>=0.21.0
nanobind==2.2.0

View File

@@ -28,10 +28,19 @@ endif()
if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set_and_check(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS}
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
)
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif()
set_target_properties(mlx PROPERTIES
@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
}
void free(Buffer buffer) {
return allocator().free(buffer);
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -95,13 +95,29 @@ void array::detach() {
array_desc_->primitive = nullptr;
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
}
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::unscheduled) {
mlx::core::eval({*this});
} else {
wait();
}
}
@@ -162,8 +178,10 @@ void array::move_shared_buffer(
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
@@ -196,6 +214,8 @@ array::~array() {
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
@@ -242,25 +262,46 @@ array::ArrayDesc::~ArrayDesc() {
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
if (inputs.empty()) {
return;
}
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
std::unordered_map<std::uintptr_t, array> input_map;
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
}
}
ad.inputs.clear();
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
}
}

View File

@@ -344,11 +344,33 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
}
enum Status { unscheduled, scheduled, available };
enum Status {
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
bool is_available() const {
return status() == Status::available;
}
// The output of a computation which has been scheduled but `eval_*` has
// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const {
return array_desc_->status;

View File

@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -18,49 +18,61 @@ void _qmm_t_4_64(
const float* biases,
int M,
int N,
int K) {
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
*result = simd_reduce_add(sum);
result++;
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
x += K;
}
}
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(),
M,
N,
K);
K,
B,
batched_w);
} else {
eval(inputs, out);
}

View File

@@ -33,8 +33,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x) {
x *= 1.442695; // multiply with log_2(e)
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
@@ -53,7 +53,9 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x;
// Avoid suppressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

View File

@@ -31,6 +31,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

View File

@@ -2,46 +2,12 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
return move_or_copy(in, out, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
move_or_copy(in, out, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
move_or_copy(inputs[0], out);
}
void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
move_or_copy(inputs[j], outputs[i]);
}
}
@@ -81,7 +81,7 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
move_or_copy(inputs[i], outputs[i]);
}
}
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
move_or_copy(in, out, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
move_or_copy(inputs[0], out);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
move_or_copy(in, out, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -4,6 +4,8 @@
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
@@ -12,22 +14,7 @@
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -44,15 +31,41 @@ void* compile(
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
std::list<DLib> libs;
std::unordered_map<std::string, void*> kernels;
std::shared_mutex mtx;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,8 +103,8 @@ void* compile(
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
@@ -103,10 +116,10 @@ void* compile(
}
// load library
libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -114,7 +127,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}
@@ -266,7 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
auto contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -316,10 +329,7 @@ void Compiled::eval_cpu(
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +344,8 @@ void Compiled::eval_cpu(
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
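The cache refactor above follows a read-mostly, double-checked pattern: lookups take a shared lock, and only a miss builds the source and takes the exclusive lock. A minimal generic sketch of that pattern, with assumed names rather than the MLX types:

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Memoizing cache guarded by a shared_mutex: cheap concurrent reads, and the
// expensive builder runs at most once per key under the exclusive lock.
class KernelCache {
 public:
  void* get_or_build(const std::string& key,
                     const std::function<void*()>& builder) {
    {
      std::shared_lock lock(mtx_); // fast path: many concurrent readers
      if (auto it = map_.find(key); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock lock(mtx_); // slow path: exclusive writer
    if (auto it = map_.find(key); it != map_.end()) {
      return it->second; // another thread won the race
    }
    void* fn = builder();
    map_.emplace(key, fn);
    return fn;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, void*> map_;
};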

View File

@@ -3,13 +3,8 @@
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

View File

@@ -1,14 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {

117
mlx/backend/common/eigh.cpp Normal file
View File

@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -81,11 +80,18 @@ void gather(
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]);
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
@@ -99,9 +105,10 @@ void gather(
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
}
@@ -223,21 +230,29 @@ void scatter(
update_size *= us;
}
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]);
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates);
auto out_loc = elem_to_loc(j, update_shape, out.strides());
op(updates.data<InT>()[update_loc],
out.data<InT>() + out_offset + out_loc);
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
}

View File

@@ -2,39 +2,19 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif

View File

@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -1,15 +1,10 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

View File

@@ -295,6 +295,13 @@ struct Floor {
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
@@ -337,6 +344,13 @@ struct Negative {
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -273,6 +284,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -398,6 +413,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
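The new Real and Imag ops are per-element std::real/std::imag over complex64 inputs with float outputs; for example:

#include <complex>
#include <cstdio>

// Per-element behavior of the new Real/Imag ops (illustrative only).
int main() {
  std::complex<float> z(3.0f, -4.0f);
  printf("real=%g imag=%g\n", std::real(z), std::imag(z)); // real=3 imag=-4
  return 0;
}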
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -598,7 +617,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}

View File

@@ -2,14 +2,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>

View File

@@ -2,13 +2,38 @@
#include <cassert>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
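For the 3-bit case, eight values form a 24-bit little-endian stream, i.e. exactly 3 bytes, which is why extract_bits<T, 3> reads w_in[0..2] with the masks shown. A minimal round-trip sketch under that assumed layout:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t q[8] = {1, 7, 0, 5, 3, 6, 2, 4}; // 3-bit values in [0, 7]
  // Pack into a 24-bit little-endian word, 3 bits per value.
  uint32_t word = 0;
  for (int k = 0; k < 8; ++k) {
    word |= uint32_t(q[k]) << (3 * k);
  }
  uint8_t packed[3] = {uint8_t(word & 0xff),
                       uint8_t((word >> 8) & 0xff),
                       uint8_t((word >> 16) & 0xff)};
  // Unpack with the same bit positions and check the round trip.
  uint32_t back =
      packed[0] | (uint32_t(packed[1]) << 8) | (uint32_t(packed[2]) << 16);
  for (int k = 0; k < 8; ++k) {
    assert(((back >> (3 * k)) & 0x7) == q[k]);
  }
  return 0;
}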
template <typename T, int bits, int group_size>
void _qmm(
T* result,
@@ -20,13 +45,12 @@ void _qmm(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -40,13 +64,25 @@ void _qmm(
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++;
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
@@ -67,13 +103,12 @@ void _qmm_t(
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
@@ -85,12 +120,26 @@ void _qmm_t(
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++;
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
@@ -102,6 +151,55 @@ void _qmm_t(
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
@@ -116,79 +214,29 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 3:
_qmm_dispatch_group<T, 3>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 4:
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 8:
_qmm_dispatch_group<T, 8>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
@@ -201,55 +249,61 @@ void _qmm_dispatch(
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
scales.data<float16_t>(),
biases.data<float16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
scales.data<bfloat16_t>(),
biases.data<bfloat16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
@@ -398,4 +452,114 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_);
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
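A simplified scalar illustration of the affine scheme (one scale/bias per group, q clipped to [0, 2^bits - 1], dequantization scale * q + bias). This sketch uses bias = w_min rather than the edge-snapping logic above:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int bits = 4;
  const float n_bins = (1 << bits) - 1; // 15
  float w[4] = {-0.9f, 0.2f, 0.7f, 1.3f};
  float w_min = *std::min_element(w, w + 4);
  float w_max = *std::max_element(w, w + 4);
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  float bias = w_min;
  for (float v : w) {
    // Quantize, clip to the representable range, then dequantize.
    float q = std::clamp(std::rint((v - bias) / scale), 0.0f, n_bins);
    printf("w=%+.3f  q=%2.0f  dequant=%+.3f\n", v, q, scale * q + bias);
  }
  return 0;
}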
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core

View File

@@ -120,48 +120,56 @@ struct MinReduce {
};
template <typename InT>
void reduce_dispatch_out(
void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
switch (rtype) {
case Reduce::And: {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
break;
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
case Reduce::Or: {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
break;
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
break;
}
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
} else {
reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
break;
}
case Reduce::Min: {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
break;
}
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
} // namespace
void nd_loop(
@@ -190,46 +198,114 @@ void nd_loop(
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
case uint8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
case uint16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
}
}
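One reason the new sum/prod dispatch accumulates small integer types in int32: a narrow accumulator overflows almost immediately. A tiny illustration:

#include <cstdint>
#include <cstdio>

int main() {
  int8_t acc8 = 0;
  int32_t acc32 = 0;
  for (int i = 0; i < 200; ++i) {
    acc8 = int8_t(acc8 + 1); // wraps once the count passes 127
    acc32 += 1;              // exact
  }
  printf("int8 accumulator: %d, int32 accumulator: %d\n", int(acc8), acc32);
  return 0;
}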

View File

@@ -34,7 +34,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core

View File

@@ -111,7 +111,8 @@ void sort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -123,14 +124,16 @@ void sort(const array& in, array& out, int axis) {
int axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out.data<T>() + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
}
@@ -160,11 +163,15 @@ void argsort(const array& in, array& out, int axis) {
int axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
@@ -192,7 +199,8 @@ void partition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -206,9 +214,11 @@ void partition(const array& in, array& out, int axis, int kth) {
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out.data<T>() + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
@@ -227,37 +237,49 @@ void argpartition(const array& in, array& out, int axis, int kth) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator md(idx_ptr, axis_stride, kth);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -2,7 +2,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {

View File

@@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
}
}
template <typename T, typename Op>
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename Op>
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {

View File

@@ -4,6 +4,28 @@
namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(

View File

@@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int i = dims_;
int dims = shape_.size();
if (dims == 0) {
return;
}
int i = dims - 1;
while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i];
@@ -98,15 +102,41 @@ struct ContiguousIterator {
loc += strides_[i];
}
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
loc += q_and_r.rem * strides_[i];
pos_[i] = q_and_r.rem;
n = q_and_r.quot;
}
}
void reset() {
loc = 0;
std::fill(pos_.begin(), pos_.end(), 0);
}
ContiguousIterator() {};
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
dims_ = shape_.size() - 1;
pos_ = std::vector<int>(dims_ + 1, 0);
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
StrideT loc{0};
@@ -115,7 +145,6 @@ struct ContiguousIterator {
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
int dims_;
};
template <typename StrideT>
@@ -149,4 +178,13 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra;
}
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
} // namespace mlx::core
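The iterator changes above trade a per-element elem_to_loc (a divide/modulo per dimension) for an incrementally maintained offset. A standalone sketch of the idea, with assumed simplified types:

#include <cstddef>
#include <cstdio>
#include <vector>

// Keep a per-dimension counter and a running flat offset `loc`, so each step()
// is a few adds/compares instead of recomputing the location from scratch.
struct TinyIterator {
  std::vector<int> shape, pos;
  std::vector<size_t> strides;
  size_t loc = 0;
  TinyIterator(std::vector<int> s, std::vector<size_t> st)
      : shape(std::move(s)), pos(shape.size(), 0), strides(std::move(st)) {}
  void step() {
    int i = int(shape.size()) - 1;
    while (i > 0 && pos[i] == shape[i] - 1) {
      pos[i] = 0;
      loc -= size_t(shape[i] - 1) * strides[i]; // carry into the next dimension
      --i;
    }
    ++pos[i];
    loc += strides[i];
  }
};

int main() {
  // Walk a 2x3 array stored with row-major strides {3, 1}.
  TinyIterator it({2, 3}, {3, 1});
  for (int n = 0; n < 6; ++n) {
    printf("%zu ", it.loc); // prints 0 1 2 3 4 5
    it.step();
  }
  printf("\n");
  return 0;
}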

View File

@@ -14,20 +14,27 @@ function(make_jit_source SRC_FILE)
COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
${SRC_FILE}
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h)
make_jit_source(
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(hadamard)
if(MLX_METAL_JIT)
@@ -99,6 +106,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
if(NOT MLX_METAL_PATH)

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit;
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
if (size == 0) {
@@ -205,7 +215,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)};
}
@@ -230,7 +242,11 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
@@ -246,15 +262,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
@@ -265,6 +275,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}

View File

@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal {
@@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
@@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator
BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
std::mutex mutex_;

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -23,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
bool large,
int ndim,
int work_per_thread) {
std::ostringstream kname;
std::string kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
kname = "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
kname = (large ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
kname = (large ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
kname = (large ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
kname = "g";
if (ndim <= 3) {
kname << ndim;
kname += std::to_string(ndim);
} else {
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
concatenate(kname, "n", std::to_string(work_per_thread));
}
if (large) {
kname += "large";
}
break;
}
kname << "_" << op << type_to_name(a);
return kname.str();
concatenate(kname, "_", op, type_to_name(a));
return kname;
}
void binary_op_gpu_inplace(
@@ -82,19 +81,23 @@ void binary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX;
bool large = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread =
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
int work_per_thread;
if (bopt == BinaryOpType::General) {
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
@@ -111,6 +114,7 @@ void binary_op_gpu_inplace(
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (bopt == BinaryOpType::General) {
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -118,39 +122,33 @@ void binary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
compute_encoder.set_vector_bytes(shape, arg_idx++);
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder.set_bytes<int>(ndim, arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder.set_vector_bytes(strides_a, arg_idx++);
compute_encoder.set_vector_bytes(strides_b, arg_idx++);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -11,10 +12,12 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
inline void build_kernel(
std::ostream& os,
std::string& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
@@ -23,7 +26,8 @@ inline void build_kernel(
bool contiguous,
int ndim,
bool dynamic_dims,
bool use_big_index = false) {
bool use_big_index = false,
int work_per_thread = 1) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
@@ -38,8 +42,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
os += fmt::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments
for (auto& x : inputs) {
@@ -51,135 +55,203 @@ inline void build_kernel(
}
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
} else {
if (!is_scalar(x) && !contiguous) {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
}
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
os += fmt::format(
" constant const size_t* in_strides [[buffer({0})]],\n", cnt++);
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
os += fmt::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
os += fmt::format(
" constant const size_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
if (use_big_index) {
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
std::string idx_type = use_big_index ? "size_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
os += " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else {
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
}
os << std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
os += fmt::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
// Read the inputs in tmps
int nc_in_count = 0;
// Read constant / contiguous inputs into tmps
std::vector<array> nc_inputs;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");" << std::endl;
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
os += fmt::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
os += fmt::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
nc_inputs.push_back(x);
}
}
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_1<size_t, uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_2<size_t, {0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
"elem_to_loc_3<size_t, {0}>(pos, in_strides + {1});\n",
idx_type,
offset);
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
} else {
os += fmt::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os += " uint zpos = pos.z;\n";
if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
}
os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname);
if (dynamic_dims) {
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else {
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
}
}
os += " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os += fmt::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
os += fmt::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
os += fmt::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else {
x.primitive().print(os);
os << "()(";
std::ostringstream ss;
x.primitive().print(ss);
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os += fmt::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else {
os += fmt::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
}
}
os += " index++;\n }\n";
}
// Finish the kernel
os << "}" << std::endl;
os += "}\n";
if (cnt > 31) {
std::ostringstream msg;
@@ -202,13 +274,10 @@ void Compiled::eval_gpu(
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
auto lib = d.get_library(kernel_lib_, [&]() {
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -221,7 +290,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_big",
kernel_lib_ + "_contiguous_large",
inputs_,
outputs_,
tape_,
@@ -240,7 +309,23 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? 2 : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
}
build_kernel(
kernel,
@@ -251,14 +336,27 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
lib = d.get_library(kernel_lib_, kernel.str());
}
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ 2);
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, output_shape);
auto contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
@@ -306,13 +404,19 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
}
bool use_2d = false;
bool large;
if (contiguous) {
size_t max_size = 0;
for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
use_2d = (max_size > UINT32_MAX);
large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
}
// Get the kernel from the lib
@@ -325,12 +429,13 @@ void Compiled::eval_gpu(
} else {
kernel_name += std::to_string(shape.size());
}
} else if (use_2d) {
kernel_name += "_big";
}
if (large) {
kernel_name += "_large";
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Put the inputs in
int cnt = 0;
@@ -351,8 +456,7 @@ void Compiled::eval_gpu(
}
}
if (!in_strides.empty()) {
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
compute_encoder.set_vector_bytes(in_strides, cnt++);
}
compiled_allocate_outputs(
@@ -365,36 +469,43 @@ void Compiled::eval_gpu(
// Put the output shape and strides in
if (!contiguous) {
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
compute_encoder.set_bytes(ndim, cnt++);
}
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
MTL::Size grid_dims = large
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
int pow2;
if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
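
The non-contiguous launch above shrinks the innermost grid dimension by `work_per_thread` and picks a threadgroup budget from the pipeline's maximum. A small sketch of that arithmetic, assuming (as the code does) that each thread handles `work_per_thread` consecutive elements along the last axis; the function and struct names are hypothetical.

// Illustrative: plan the dispatch extents for a strided kernel where each
// thread processes `work_per_thread` elements along the last axis.
#include <cstddef>
#include <stdexcept>

struct DispatchPlan {
  size_t dim0, dim1, rest;
  int pow2; // log2 of the threadgroup thread budget (10 -> 1024 threads)
};

DispatchPlan plan_strided_dispatch(size_t dim0, size_t dim1, size_t rest,
                                   int work_per_thread,
                                   size_t max_threads_per_group) {
  DispatchPlan p{(dim0 + work_per_thread - 1) / work_per_thread, dim1, rest, 10};
  if (max_threads_per_group == 1024) {
    p.pow2 = 10;
  } else if (max_threads_per_group > 512) {
    p.pow2 = 9; // fall back to 512-thread groups
  } else {
    throw std::runtime_error("expected threadgroups of more than 512 threads");
  }
  return p;
}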

View File

@@ -44,23 +44,24 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -72,7 +73,7 @@ void explicit_gemm_conv_ND_gpu(
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped};
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
@@ -122,23 +123,24 @@ void explicit_gemm_conv_group_ND_gpu(
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
@@ -155,22 +157,27 @@ void explicit_gemm_conv_group_ND_gpu(
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
}
@@ -232,7 +239,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -247,8 +254,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -347,7 +354,7 @@ void implicit_gemm_conv_2D_gpu(
wn,
n_channel_specialization,
small_filter);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -363,11 +370,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -501,7 +508,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -518,17 +525,15 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(jump_params, 5);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
compute_encoder.set_vector_bytes(base_h, 6);
compute_encoder.set_vector_bytes(base_w, 7);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -617,18 +622,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
compute_encoder.set_bytes(C_c, 2);
compute_encoder.set_bytes(O_c, 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -645,18 +650,17 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -693,18 +697,17 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
@@ -747,10 +750,6 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
@@ -764,10 +763,13 @@ void conv_2D_gpu(
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
@@ -913,12 +915,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
}
// Clear copies
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
// Record copies
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core
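
The `add_temporaries` call replaces a per-buffer completed handler: temporaries are now recorded on the stream and released together once the stream's command buffer completes. A tiny standard-C++ model of that deferred-release idea; the type and member names here are hypothetical stand-ins, not MLX's types.

// Hypothetical model: accumulate temporaries per stream and release them all
// at once when the stream's batch of GPU work has completed.
#include <memory>
#include <utility>
#include <vector>

struct StreamState {
  std::vector<std::shared_ptr<void>> temporaries;

  void add_temporaries(std::vector<std::shared_ptr<void>> arrs) {
    temporaries.insert(temporaries.end(),
                       std::make_move_iterator(arrs.begin()),
                       std::make_move_iterator(arrs.end()));
  }

  void on_command_buffer_completed() {
    temporaries.clear(); // buffers are released only after the GPU is done
  }
};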

View File

@@ -74,44 +74,46 @@ void copy_gpu_inplace(
};
auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size();
bool use_2d = out.data_size() > UINT32_MAX;
bool large;
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
auto& d = metal::device(s.device);
int work_per_thread = 1;
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else if (shape[ndim - 1] >= 4) {
work_per_thread = 4;
kname << "n4";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
switch (ctype) {
case CopyType::Scalar:
kernel_name = (large ? "s2" : "s");
break;
case CopyType::Vector:
kernel_name = (large ? "v2" : "v");
break;
case CopyType::General:
kernel_name = "g";
break;
case CopyType::GeneralGeneral:
kernel_name = "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
}
if (large) {
kernel_name += "large";
}
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
@@ -120,15 +122,16 @@ void copy_gpu_inplace(
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
compute_encoder.set_output_array(out, 1, out_offset);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2);
compute_encoder.set_vector_bytes(shape, ndim, 2);
}
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
compute_encoder.set_vector_bytes(strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
compute_encoder.set_vector_bytes(strides_out, ndim, 4);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -140,29 +143,27 @@ void copy_gpu_inplace(
int rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
compute_encoder.set_bytes(ndim, 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
@@ -194,26 +195,26 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool use_2d = out.data_size() > UINT32_MAX;
bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
} // namespace mlx::core
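
The copy path now assembles its kernel name from the copy type, an optional dimension or `n<work_per_thread>` suffix, a `large` marker, and the dtype pair. A sketch of the same naming scheme as standalone string code, using `std::string` directly rather than MLX's `concatenate` helper; the resulting names (e.g. something like "gg3_copy..." or "ggn4large_copy...") are for illustration.

// Illustrative: reproduce the kernel-name pattern used above.
#include <string>

std::string copy_kernel_name(const std::string& base, int ndim,
                             int max_specialized_dims, bool large,
                             const std::string& in_t, const std::string& out_t) {
  std::string name = base; // "s", "v", "g", or "gg"
  if (base == "g" || base == "gg") {
    if (ndim <= max_specialized_dims) {
      name += std::to_string(ndim);
    } else {
      name += "n" + std::to_string(large ? 4 : 2); // work per thread
    }
    if (large) {
      name += "large";
    }
  }
  return name + "_copy" + in_t + out_t;
}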

View File

@@ -32,20 +32,18 @@ void CustomKernel::eval_gpu(
return copies.back();
}
};
std::vector<const array> checked_inputs;
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i];
@@ -55,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index);
compute_encoder.set_bytes(ndim, index);
index++;
}
}
@@ -74,15 +72,13 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast
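
The custom-kernel launch now clamps each threadgroup dimension to the grid dimension, so a small grid never requests more threads per group than it dispatches. The clamp in isolation, with a hypothetical `Dim3` stand-in for `MTL::Size`:

// Illustrative: clamp user-provided threadgroup dims to the grid dims before
// dispatching, so the group size never exceeds the grid in any axis.
#include <algorithm>
#include <cstddef>

struct Dim3 { size_t x, y, z; };

Dim3 clamp_group_to_grid(Dim3 group, Dim3 grid) {
  return {std::min(group.x, grid.x),
          std::min(group.y, grid.y),
          std::min(group.z, grid.z)};
}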

View File

@@ -20,18 +20,21 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
auto get_metal_version() {
auto get_metal_version_ = []() {
if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
return MTL::LanguageVersion3_2;
} else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
return MTL::LanguageVersion3_1;
} else {
return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
}
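
Picking the Metal language version now happens at run time via `__builtin_available`, memoized with a function-local static so the availability check runs only once. The same pattern in isolation, returning an illustrative enum instead of `MTL::LanguageVersion`; this only compiles with Clang targeting Apple platforms.

// Illustrative: choose a feature level once, based on the OS the binary is
// actually running on, and cache the result in a function-local static.
enum class MetalLang { V3_0, V3_1, V3_2 };

MetalLang metal_language_version() {
  static MetalLang v = [] {
    if (__builtin_available(macOS 15, iOS 18, *)) {
      return MetalLang::V3_2;
    } else if (__builtin_available(macOS 14, iOS 17, *)) {
      return MetalLang::V3_1;
    }
    return MetalLang::V3_0;
  }();
  return v;
}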
auto load_device() {
@@ -121,33 +124,29 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
enc_->endEncoding();
enc_->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
enc_->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
@@ -156,55 +155,49 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs.insert(buf);
next_outputs_.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
}
next_outputs_.clear();
}
void CommandEncoder::dispatch_threadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatch_threads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreads(grid_dims, group_dims);
}
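
`set_input_array` no longer emits a barrier immediately; it only raises `needs_barrier_` when an input was produced by an earlier dispatch in the same encoder, and `maybeInsertBarrier` then emits at most one buffer-scope barrier before the next dispatch. A compact standard-C++ model of that bookkeeping, with resource handles reduced to `const void*` and the barrier reduced to a callback; the struct and member names are hypothetical.

// Hypothetical model of the encoder's deferred barrier tracking.
#include <functional>
#include <unordered_set>
#include <utility>

struct BarrierTracker {
  bool needs_barrier = false;
  std::unordered_set<const void*> prev_outputs; // outputs of earlier dispatches
  std::unordered_set<const void*> next_outputs; // outputs of the pending dispatch

  void set_input(const void* buf) {
    needs_barrier |= (prev_outputs.count(buf) > 0);
  }
  void set_output(const void* buf) {
    next_outputs.insert(buf);
  }
  void before_dispatch(const std::function<void()>& emit_barrier) {
    if (needs_barrier) {
      emit_barrier(); // one barrier covers everything written so far
      needs_barrier = false;
      prev_outputs = std::move(next_outputs);
    } else {
      prev_outputs.insert(next_outputs.begin(), next_outputs.end());
    }
    next_outputs.clear();
  }
};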
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
}
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -219,69 +212,134 @@ void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
return get_stream_(index).buffer_ops;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
get_stream_(index).buffer_ops++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
stream.buffer->retain();
}
return bit->second.second;
return stream.buffer;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
auto& stream = get_stream_(index);
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
}
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
}
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
void Device::end_encoding(int index) {
encoder_map_.erase(index);
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoder's fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits.
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoder's inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc.update_fence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
}
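
A standalone model of the scheme described in the comment above: each encoder gets one fence, a per-stream map from output buffers to fences tells a later encoder which fences to wait on, and a completion handler prunes entries that still point at the finished encoder's fence. The types below are simplified stand-ins for `MTL::Fence` and MLX's arrays, and the method names are hypothetical.

// Hypothetical sketch of fence bookkeeping across command encoders.
#include <memory>
#include <unordered_map>
#include <unordered_set>

struct Fence { int id; }; // stand-in for an MTL::Fence wrapper

struct StreamSync {
  std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;

  // Called when an encoder ends: collect the fences guarding its inputs,
  // then publish this encoder's fence for everything it wrote.
  std::unordered_set<std::shared_ptr<Fence>> end_encoder(
      const std::unordered_set<const void*>& inputs,
      const std::unordered_set<const void*>& enc_outputs,
      std::shared_ptr<Fence> enc_fence) {
    std::unordered_set<std::shared_ptr<Fence>> waiting_on;
    for (auto in : inputs) {
      if (auto it = outputs.find(in); it != outputs.end()) {
        waiting_on.insert(it->second); // the encoder would waitForFence here
      }
    }
    for (auto out : enc_outputs) {
      outputs[out] = enc_fence; // the encoder would updateFence at the end
    }
    return waiting_on; // kept alive until the command buffer completes
  }

  // Completion handler: drop map entries that still point at this fence.
  void on_completed(const std::unordered_set<const void*>& enc_outputs,
                    const std::shared_ptr<Fence>& enc_fence) {
    for (auto o : enc_outputs) {
      if (auto it = outputs.find(o);
          it != outputs.end() && it->second == enc_fence) {
        outputs.erase(it);
      }
    }
  }
};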
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
return *(eit->second);
return *stream.encoder;
}
void Device::register_library(
@@ -293,20 +351,7 @@ void Device::register_library(
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
MTL::Library* Device::build_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
@@ -322,26 +367,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -465,68 +491,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
return kernel;
}
MTL::Library* Device::get_library(const std::string& name) {
MTL::Library* Device::get_library_(const std::string& name) {
std::shared_lock lock(library_mtx_);
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
const std::function<std::string(void)>& builder) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto mtl_lib = build_library_(builder());
library_map_.insert({name, mtl_lib});
return mtl_lib;
}
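
`get_library` now takes a builder callback and guards the cache with a shared mutex: readers take a shared lock, and the writer re-checks the map under the unique lock so a library is built at most once even under contention. The pattern in isolation, with `std::string` standing in for `MTL::Library*`; the class name is hypothetical.

// Hypothetical: double-checked, lazily built cache keyed by name.
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

class LibraryCache {
 public:
  const std::string& get(const std::string& name,
                         const std::function<std::string()>& builder) {
    {
      std::shared_lock rlock(mtx_); // many concurrent readers
      if (auto it = map_.find(name); it != map_.end()) {
        return it->second;
      }
    }
    std::unique_lock wlock(mtx_); // single writer
    if (auto it = map_.find(name); it != map_.end()) {
      return it->second; // someone else built it while we waited
    }
    auto it = map_.emplace(name, builder()).first; // expensive build, done once
    return it->second;
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<std::string, std::string> map_;
};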
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -547,34 +537,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Single writer allowed
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
auto pool = new_scoped_memory_pool();
// Pull kernel from library
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
MTL::ComputePipelineState* Device::get_kernel(
@@ -583,16 +594,34 @@ MTL::ComputePipelineState* Device::get_kernel(
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
}
Device& device(mlx::core::Device) {

View File

@@ -7,6 +7,7 @@
#include <filesystem>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -44,43 +45,114 @@ struct CommandEncoder {
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true;
enc.concurrent_ = true;
}
~ConcurrentContext() {
enc.concurrent = false;
enc.outputs.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
enc.concurrent_outputs.clear();
enc.concurrent_ = false;
enc.prev_outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
private:
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}
void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}
void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}
template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder();
private:
void maybe_split();
// Inputs to all kernels in the encoder including temporaries
std::unordered_set<const void*>& inputs() {
return all_inputs_;
};
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*> outputs() {
return all_outputs_;
};
private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
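
The templated setters above replace raw `setBytes` calls: `set_bytes` binds a single POD value, while `set_vector_bytes` binds a vector's contents (optionally only the first `nelems`). A hedged usage sketch against a minimal stand-in encoder, just to show which overload computes which byte count; `MockEncoder` is not an MLX type.

// Illustrative stand-in showing the byte counts each overload passes on.
#include <cstddef>
#include <cstdio>
#include <vector>

struct MockEncoder {
  void setBytes(const void*, size_t length, int idx) {
    std::printf("buffer %d <- %zu bytes\n", idx, length);
  }
  template <typename T>
  void set_bytes(const T& v, int idx) { setBytes(&v, sizeof(T), idx); }
  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, size_t n, int idx) {
    setBytes(v.data(), n * sizeof(T), idx);
  }
  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int idx) {
    set_vector_bytes(v, v.size(), idx);
  }
};

int main() {
  MockEncoder enc;
  int ndim = 3;
  std::vector<size_t> strides{4, 2, 1};
  enc.set_bytes(ndim, 0);           // sizeof(int) bytes
  enc.set_vector_bytes(strides, 1); // 3 * sizeof(size_t) bytes
}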
struct Fence {
Fence(MTL::Fence* fence) : fence(fence) {}
~Fence() {
fence->release();
}
MTL::Fence* fence;
};
struct DeviceStream {
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
~DeviceStream() {
queue->release();
if (buffer != nullptr) {
buffer->release();
}
};
MTL::CommandQueue* queue;
// A map of prior command encoder outputs to their corresponding fence
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
};
class Device {
@@ -94,6 +166,10 @@ class Device {
return device_;
};
const std::string& get_architecture() {
return arch_;
}
void new_queue(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
@@ -114,29 +190,9 @@ class Device {
}
}
MTL::Library* get_library(const std::string& name);
MTL::Library* get_library(
const std::string& name,
const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
const std::function<std::string(void)>& builder);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
@@ -155,11 +211,20 @@ class Device {
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
// Record temporary arrays for the given stream index
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
void set_residency_set(const MTL::ResidencySet* residency_set);
private:
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Library* get_library_(const std::string& name);
MTL::Library* build_library_(const std::string& source_string);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
@@ -181,13 +246,23 @@ class Device {
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::ComputePipelineState* get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
};
Device& device(mlx::core::Device);

View File

@@ -27,4 +27,9 @@ void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}
bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
value();
}
} // namespace mlx::core
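
`Event::is_signaled` adds a non-blocking check: it compares the shared event's current signaled value against the event's target value. One hedged way such a check might be used on the host side; this polling loop is hypothetical usage, not code from the repository.

// Hypothetical polling loop built on a non-blocking is_signaled() check.
#include <chrono>
#include <thread>

template <typename EventLike>
void spin_until_signaled(const EventLike& ev) {
  while (!ev.is_signaled()) {
    std::this_thread::sleep_for(std::chrono::microseconds(50));
  }
}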

View File

@@ -575,8 +575,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -700,7 +699,7 @@ void fft_op(
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
@@ -712,9 +711,9 @@ void fft_op(
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder.set_bytes(n, 4);
compute_encoder.set_bytes(plan.bluestein_n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
@@ -724,25 +723,25 @@ void fft_op(
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
compute_encoder.set_bytes(n, 5);
compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder.set_bytes(plan.rader_n, 7);
} else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
compute_encoder.set_bytes(four_step_params.n1, 2);
compute_encoder.set_bytes(four_step_params.n2, 3);
compute_encoder.set_bytes(total_batch_size, 4);
} else {
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
compute_encoder.set_bytes(n, 2);
compute_encoder.set_bytes(total_batch_size, 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
d.add_temporaries(std::move(copies), s.index);
}
void fft_op(
@@ -785,8 +784,7 @@ void nd_fft_op(
}
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
d.add_temporaries(std::move(temp_arrs), s.index);
}
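Both fft_op and nd_fft_op now hand their temporary arrays to d.add_temporaries instead of registering a completed-handler per call. A minimal sketch of the pattern this relies on (assumed semantics, not the actual Device implementation): temporaries are collected per stream and released once that stream's command buffer completes.

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// `Array` stands in for mlx::core::array; holding a copy keeps its buffer alive.
using Array = std::shared_ptr<std::string>;

struct TemporaryKeeper {
  std::unordered_map<int, std::vector<Array>> temporaries_;

  // Called from ops like fft_op above: stash temporaries for stream `index`.
  void add_temporaries(std::vector<Array> arrays, int index) {
    auto& t = temporaries_[index];
    for (auto& a : arrays) {
      t.push_back(std::move(a));
    }
  }

  // Called when the command buffer for stream `index` finishes: dropping the
  // vector is the point at which the temporary buffers can actually be freed.
  void command_buffer_completed(int index) {
    temporaries_.erase(index);
  }
};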
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
@@ -171,33 +164,17 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core
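The two launch_hadamard calls above realize an (n*m)-point transform as an n-point pass followed by an m-point pass over the intermediate result. As a plain CPU illustration (not the Metal kernel), the n-point pass is the usual power-of-two Walsh-Hadamard butterfly:

#include <vector>

// In-place unnormalized Walsh-Hadamard transform for x.size() == n, n a power
// of two; the Metal kernel performs the same butterflies with a larger radix
// and many rows per dispatch.
void wht_pow2(std::vector<float>& x) {
  for (size_t len = 1; len < x.size(); len <<= 1) {
    for (size_t i = 0; i < x.size(); i += 2 * len) {
      for (size_t j = i; j < i + len; ++j) {
        float a = x[j];
        float b = x[j + len];
        x[j] = a + b;       // butterfly: (a + b, a - b)
        x[j + len] = a - b;
      }
    }
  }
}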

View File

@@ -53,28 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large = large_index || large_src || large_out;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string kernel_name = fmt::format(
"gather{0}{1}_{2}_{3}_{4}",
type_to_name(out),
idx_type_name,
nidx,
idx_ndim,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source << fmt::format(
kernel_source += fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
@@ -82,13 +85,14 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
idx_args,
idx_arr,
idx_ndim);
lib = d.get_library(lib_name, kernel_source.str());
}
idx_ndim,
large ? "size_t" : "uint");
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
@@ -114,17 +118,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
// Set all the buffers
@@ -132,21 +136,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
compute_encoder.set_vector_bytes(src.shape(), 2);
compute_encoder.set_vector_bytes(src.strides(), 3);
compute_encoder.set_bytes(ndim, 4);
compute_encoder.set_vector_bytes(slice_sizes_, 5);
compute_encoder.set_vector_bytes(axes_, 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
compute_encoder.set_vector_bytes(idx_shapes, 7);
compute_encoder.set_vector_bytes(idx_strides, 8);
compute_encoder.set_vector_bytes(idx_contigs, 9);
compute_encoder.set_bytes(idx_ndim, 10);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -154,7 +157,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -173,12 +176,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
CopyType copy_type;
if (inputs[0].data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (inputs[0].flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(inputs[0], out, copy_type);
auto& upd = inputs.back();
// Empty update
if (inputs.back().size() == 0) {
if (upd.size() == 0) {
return;
}
@@ -187,23 +198,22 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
size_t idx_size = nidx ? inputs[1].size() : 1;
// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
auto idx_to_out = idx_size / out.size();
int nwork;
if (idx_ndim <= 1 || idx_to_out < 1) {
nwork = 1;
} else if (idx_to_out <= 4) {
nwork = 4;
} else if (idx_to_out < 16) {
nwork = 8;
} else if (idx_to_out < 32) {
nwork = 16;
} else {
nwork = 32;
}
// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
@@ -223,24 +233,25 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "min";
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
type_to_name(out),
idx_type_name,
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "size_t" : "uint");
std::string lib_name = kernel_name;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
@@ -264,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str);
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format(
kernel_source += fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
@@ -276,126 +287,105 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_type,
nidx,
idx_args,
idx_arr);
lib = d.get_library(lib_name, kernel_source.str());
}
idx_arr,
upd_contig,
nwork,
large ? "size_t" : "uint");
return kernel_source;
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Set all the buffers
compute_encoder.set_input_array(upd, 1);
compute_encoder.set_output_array(out, 2);
// Set update info
uint upd_ndim = upd.ndim();
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't complain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 3);
compute_encoder.set_bytes(stride_, 4);
} else {
compute_encoder.set_vector_bytes(upd.shape(), 3);
compute_encoder.set_vector_bytes(upd.strides(), 4);
}
compute_encoder.set_bytes(upd_ndim, 5);
compute_encoder.set_bytes(upd_size, 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder.set_bytes(shape_, 7);
compute_encoder.set_bytes(stride_, 8);
} else {
compute_encoder.set_vector_bytes(out.shape(), 7);
compute_encoder.set_vector_bytes(out.strides(), 8);
}
compute_encoder.set_bytes(out_ndim, 9);
compute_encoder.set_vector_bytes(axes_, 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
compute_encoder.set_vector_bytes(idx_shapes, 11);
compute_encoder.set_vector_bytes(idx_strides, 12);
compute_encoder.set_vector_bytes(idx_contigs, 13);
compute_encoder.set_bytes(idx_ndim, 14);
compute_encoder.set_bytes(idx_size, 15);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
auto grid_y = (nthreads / upd_size);
grid_y = (grid_y + nwork - 1) / nwork;
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
} // namespace mlx::core
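A worked example of the nwork and grid sizing above, with assumed sizes:

// Assume idx_ndim = 2, one index array with idx_size = 65536 entries, and
// out.size() = 4096, so idx_to_out = 65536 / 4096 = 16 and nwork = 16.
// With upd_size = 8 elements per update, nthreads = upd.size() = 524288, and
//   grid_y = ((524288 / 8) + 16 - 1) / 16 = 4096,
// giving grid_dims = MTL::Size(8, 4096, 1): each thread applies 16 updates,
// which is what the "nwork16" suffix in the kernel name selects.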

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
[[kernel]] void gather{0}_{3}_{6}_{7}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
@@ -11,14 +11,15 @@ constexpr std::string_view gather_kernels = R"(
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
return gather_impl<{1}, {2}, {3}, {6}, {7}>(
src,
out,
src_shape,
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}}
[[kernel]] void scatter{0}_{4}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
updates,
out,
upd_shape,
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
out_strides,
out_ndim,
axes,
idx_size,
idxs,
gid);
}}

View File

@@ -1,26 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -1,10 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
@@ -25,48 +23,50 @@ MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
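Every getter in this file now follows the same shape: d.get_library(name, builder), where the builder lambda assembles the kernel source only when the library is not already cached. A self-contained sketch of that pattern (the locking and member names mirror library_mtx_ / library_map_ / build_library_ from the Device header above, but this is an assumption, not the actual Device code):

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct LibraryCache {
  using Builder = std::function<std::string()>;

  std::string get_library(const std::string& name, const Builder& builder) {
    {
      std::shared_lock lock(mtx_);
      if (auto it = libs_.find(name); it != libs_.end()) {
        return it->second;
      }
    }
    std::unique_lock lock(mtx_);
    if (auto it = libs_.find(name); it != libs_.end()) {
      return it->second; // another thread built it first
    }
    std::string lib = "compiled:" + builder(); // stand-in for build_library_()
    libs_.emplace(name, lib);
    return lib;
  }

  std::shared_mutex mtx_;
  std::unordered_map<std::string, std::string> libs_; // name -> compiled library
};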
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
kernel_source << get_template_definition(
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source +=
get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
void add_binary_kernels(
void append_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@@ -75,27 +75,24 @@ void add_binary_kernels(
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"gn", "binary_g"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
}};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
kernel_source +=
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}
MTL::ComputePipelineState* get_binary_kernel(
@@ -105,13 +102,13 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source;
kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -122,14 +119,12 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -139,28 +134,31 @@ MTL::ComputePipelineState* get_ternary_kernel(
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
auto lib = d.get_library(lib_name, [&]() {
auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
}};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
kernel_source +=
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -170,36 +168,45 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source
<< metal::utils() << metal::copy()
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source +=
get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -209,8 +216,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
@@ -218,8 +224,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -232,22 +238,29 @@ MTL::ComputePipelineState* get_scan_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
auto lib = d.get_library(lib_name, [&]() {
auto out_type = get_type_string(out.dtype());
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
op[3] = toupper(op[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source << metal::utils() << metal::scan();
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
{"contig_", "contiguous_scan"},
{"strided_", "strided_scan"},
}};
for (auto& [prefix, kernel] : scan_kernels) {
kernel_source << get_template_definition(
prefix + lib_name,
kernel,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op,
in.itemsize() <= 4 ? 4 : 2,
inclusive,
reverse);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
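For example, with reduce_type = "sum" over float32 input and output, inclusive and not reversed (values assumed), the builder above emits instantiations roughly like:

// op becomes "CumSum<float>", float32 has itemsize 4 so the itemsize-dependent
// argument is 4, and the generated host names pair up as:
//   contig_<lib_name>  -> contiguous_scan<float, float, CumSum<float>, 4, true, false>
//   strided_<lib_name> -> strided_scan<float, float, CumSum<float>, 4, true, false>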
@@ -259,8 +272,7 @@ MTL::ComputePipelineState* get_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
@@ -285,8 +297,8 @@ MTL::ComputePipelineState* get_sort_kernel(
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -298,8 +310,7 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
@@ -316,27 +327,28 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str());
}
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::reduce();
kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
@@ -345,32 +357,32 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& in,
const array& out,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_t + ">";
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
if (bm >= 0) {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
} else if (ndim >= 0) {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
} else {
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
kernel_source += get_template_definition(
kernel_name, func_name, in_t, out_t, op, idx_t);
}
lib = d.get_library(kernel_name, kernel_source.str());
}
return kernel_source;
});
auto st = d.get_kernel(kernel_name, lib);
return st;
}
@@ -389,8 +401,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
@@ -405,8 +416,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -425,8 +436,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
@@ -444,8 +454,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -456,19 +466,19 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -488,8 +498,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -513,8 +522,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -533,8 +542,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -556,8 +564,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -573,8 +581,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
@@ -588,8 +595,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -603,8 +610,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
@@ -617,8 +623,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -629,13 +635,12 @@ MTL::ComputePipelineState* get_fft_kernel(
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
std::string kernel_string;
kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -644,13 +649,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}

View File

@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
@@ -78,15 +79,18 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
const std::string& func_name,
const std::string& op_name,
const Dtype& out_type);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& in,
const array& out,
const Dtype& in_type,
const Dtype& out_type,
const std::string& idx_t,
int ndim = -1,
int bm = -1,
int bn = -1);
@@ -208,10 +212,10 @@ get_template_definition(std::string name, std::string func, Args... args) {
};
(add_arg(args), ...);
s << ">";
std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
return fmt::format(base_string, name, s.str());
return fmt::format(
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
name,
s.str());
}
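As a concrete illustration (argument values assumed), a call such as get_template_definition("g2_copy_float32", "copy_g_nd2", "float", "float", "int") expands to roughly:

// template [[host_name("g2_copy_float32")]] [[kernel]]
// decltype(copy_g_nd2<float, float, int>) copy_g_nd2<float, float, int>;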
} // namespace mlx::core

View File

@@ -1,13 +1,27 @@
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h)
set(BASE_HEADERS
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
-I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
@@ -30,8 +44,7 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
set(STEEL_HEADERS
steel/defines.h
@@ -49,7 +62,27 @@ set(STEEL_HEADERS
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h)
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \

View File

@@ -2,8 +2,6 @@
#pragma once
#include "mlx/backend/metal/kernels/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
@@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(

View File

@@ -77,12 +77,12 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -91,13 +91,13 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -106,14 +106,18 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -124,13 +128,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;

View File

@@ -9,19 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File

@@ -217,7 +217,7 @@ struct Power {
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;

View File

@@ -99,14 +99,14 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
@@ -116,15 +116,15 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
@@ -134,16 +134,20 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
@@ -155,13 +159,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
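As a rough illustration of what the extra IdxT template parameter buys, here is a plain C++ sketch (not the Metal code) of a strided-offset computation done in 32-bit versus 64-bit arithmetic. When every offset fits in 32 bits, the narrower type is sufficient and cheaper; once offsets no longer fit, the 64-bit variant is required to avoid wraparound. The function name and layout are assumptions for the example only.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Compute a flat offset from an N-d coordinate and strides, doing the
// arithmetic in IdxT. This mirrors the idea of templating elem_to_loc-style
// helpers on the index type; it is not the MLX implementation.
template <typename IdxT>
IdxT flat_offset(const std::vector<uint32_t>& coord,
                 const std::vector<int64_t>& strides) {
  IdxT loc = 0;
  for (size_t i = 0; i < coord.size(); ++i) {
    loc += static_cast<IdxT>(coord[i]) * static_cast<IdxT>(strides[i]);
  }
  return loc;
}

int main() {
  // A "large" array: a row stride of 2^32 elements cannot be held in 32 bits.
  std::vector<uint32_t> coord = {3, 1000000};
  std::vector<int64_t> strides = {4294967296LL, 1};

  std::printf("32-bit: %u\n", flat_offset<uint32_t>(coord, strides));  // wraps
  std::printf("64-bit: %llu\n",
              (unsigned long long)flat_offset<uint64_t>(coord, strides));
  return 0;
}
```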

View File

@@ -7,19 +7,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \

View File

@@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const
@@ -113,6 +113,7 @@ template <typename T, int N>
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
out += ws_ * kernel_stride;
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
@@ -126,7 +127,6 @@ template <typename T, int N>
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i];
}

View File

@@ -42,36 +42,36 @@ template <typename T, typename U>
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
IdxT dst_idx =
index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(
auto src_idx = elem_to_loc<int64_t, IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) {
int64_t dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx =
index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
return;
}
auto xshape = src_shape[ndim - 1];
int64_t dst_idx =
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -105,36 +104,36 @@ template <typename T, typename U>
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
auto src_idx = elem_to_loc_1<int64_t, int>(index, src_stride);
auto dst_idx = elem_to_loc_1<int64_t, int>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
auto src_idx = elem_to_loc_2<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
auto src_idx = elem_to_loc_3<int64_t, IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3<int64_t, IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int N = 1>
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<int64_t, IdxT>(
{N * index.x, index.y, index.z},
src_shape,
src_strides,
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
dst[idx.y] = static_cast<U>(src[idx.x]);
return;
}
auto src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1];
IdxT src_xstride = src_strides[ndim - 1];
IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]);
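The instantiations in the next file pair each general copy kernel with a `...large` twin, so a host-side dispatcher can use 32-bit index math whenever the copy fits and fall back to 64-bit offsets otherwise. A minimal sketch of that selection logic, with a hypothetical helper name and threshold (the real dispatch lives in the C++ backend and is not shown in this diff):

```cpp
#include <cstdint>
#include <string>

// Hypothetical helper: pick the 32-bit-index kernel ("g2_copy...") or its
// 64-bit variant ("g2large_copy...") depending on whether the largest flat
// offset fits in int32. Name and threshold are assumptions for illustration.
std::string select_copy_kernel(const std::string& base,   // e.g. "g2"
                               const std::string& tname,  // e.g. "float16float16"
                               int64_t max_offset) {
  bool large = max_offset > INT32_MAX;
  return base + (large ? "large" : "") + "_copy" + tname;
}

// select_copy_kernel("g2", "float16float16", 1 << 20)
//   -> "g2_copyfloat16float16"
// select_copy_kernel("g2", "float16float16", int64_t(1) << 40)
//   -> "g2large_copyfloat16float16"
```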

View File

@@ -2,24 +2,27 @@
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -16,34 +16,36 @@ METAL_FUNC void gather_impl(
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0;
LocT src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
LocT idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
} else {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc<size_t, LocT>(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
}
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
auto src_offset =
elem_to_loc<size_t, LocT>(index.z, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.z;
LocT out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
out_idx += static_cast<LocT>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}
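The static_cast<LocT>(...) before each multiply in the gather path above is what keeps the product from being computed in 32 bits and wrapping before it is widened. A tiny standalone C++ sketch of the difference:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t idx = 70000;
  uint32_t stride = 70000;

  // Multiply first, widen later: the 32-bit product wraps.
  uint64_t wrong = static_cast<uint64_t>(idx * stride);
  // Widen one operand first: the multiply happens in 64 bits.
  uint64_t right = static_cast<uint64_t>(idx) * stride;

  std::printf("wrong: %llu\nright: %llu\n",
              (unsigned long long)wrong, (unsigned long long)right);
  return 0;
}
```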

View File

@@ -3,8 +3,6 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
@@ -912,4 +910,4 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -4,8 +4,6 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h"

View File

@@ -9,11 +9,12 @@ struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant bool* row_contiguous;
const int ndim;
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {

View File

@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on
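The jit_if / jit_else / jit_endif defines above read as a way to defer the version check: when this header is preprocessed to assemble JIT kernel source, the macros expand to literal #if / #else / #endif text, so the __METAL_VERSION__ branch is only evaluated when the generated source is later compiled for the device. A minimal stand-alone demonstration of the same preprocessor trick (placeholder typedefs, not the MLX headers), assuming you inspect it with `c++ -E -P defer.cpp`:

```cpp
// defer.cpp -- after `c++ -E -P defer.cpp`, the jit_* macros have expanded to
// literal preprocessor directives, which survive into the output and are only
// evaluated by whatever compiler consumes that output later.
#define jit_if #if
#define jit_else #else
#define jit_endif #endif

jit_if (__METAL_VERSION__ >= 310)
typedef float native_bf16_stand_in;            // "native" branch (placeholder)
jit_else
typedef unsigned short emulated_bf16_stand_in; // "fallback" branch (placeholder)
jit_endif
```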

View File

@@ -3,8 +3,6 @@
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;

View File

@@ -6,12 +6,6 @@
using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
@@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal
#pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
#endif
#include "mlx/backend/metal/kernels/bf16_math.h"
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
}

View File

@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return as_type<uint16_t>(x);
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return as_type<bfloat16_t>(x);
}
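For orientation, a bfloat16 value is just the high 16 bits of an IEEE float32, which is why the native path above can round-trip through as_type<uint16_t>. A host-side C++ sketch of the same bit-level relationship (truncating conversion only; real converters typically round to nearest even):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Plain C++ sketch, not the Metal header: bfloat16 is the top 16 bits of a
// float32, so the uint16 round trip is a pure bit reinterpretation.
uint16_t bfloat16_bits_from_float(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return static_cast<uint16_t>(u >> 16);  // truncate; real converters round
}

float float_from_bfloat16_bits(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

int main() {
  float x = 3.140625f;  // exactly representable in bfloat16
  uint16_t bits = bfloat16_bits_from_float(x);
  std::printf("bits=0x%04x back=%f\n", (unsigned)bits, float_from_bfloat16_bits(bits));
  return 0;
}
```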

File diff suppressed because it is too large.


View File

@@ -5,67 +5,119 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_quantized_types(name, group_size, bits) \
instantiate_quantized(name, float, group_size, bits) \
instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_quantized(name, bfloat16_t, group_size, bits)
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched)
#define instantiate_quantized_groups(name, bits) \
instantiate_quantized_types(name, 128, bits) \
instantiate_quantized_types(name, 64, bits) \
instantiate_quantized_types(name, 32, bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)
#define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, float, group_size, bits, false) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched)
#define instantiate_quantized_groups_aligned(name, bits) \
instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_quantized_types_aligned(name, 32, bits)
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
D, \
batched)
#define instantiate_quantized_all_aligned(name) \
instantiate_quantized_groups_aligned(name, 2) \
instantiate_quantized_groups_aligned(name, 4) \
instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
name, \
type, \
group_size, \
bits, \
split_k)
instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on
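The group list now covers 2-, 3-, 4-, 6- and 8-bit widths. Independent of the packing layout those kernels use, the arithmetic of affine group quantization is small enough to sketch in a few lines of plain C++. This is a simplified illustration with assumed min/max scaling, not the MLX kernels, and it deliberately omits bit packing and how scales/biases are stored:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Affine quantization over one group: map [w_min, w_max] onto the integer
// range [0, 2^bits - 1], keeping a per-group scale and bias so values can be
// approximately reconstructed as q * scale + bias.
void quantize_group(const std::vector<float>& w, int bits,
                    std::vector<int>& q, float& scale, float& bias) {
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  int levels = (1 << bits) - 1;
  scale = (w_max - w_min) / levels;
  bias = w_min;
  q.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = std::min(levels, std::max(0, int((w[i] - bias) / scale + 0.5f)));
  }
}

int main() {
  std::vector<float> group = {-0.8f, -0.1f, 0.05f, 0.7f};
  std::vector<int> q;
  float scale, bias;
  for (int bits : {2, 3, 4, 6, 8}) {
    quantize_group(group, bits, q, scale, bias);
    std::printf("bits=%d  dequant(q[3]) = %f\n", bits, q[3] * scale + bias);
  }
  return 0;
}
```

More bits means more levels per group and therefore a smaller reconstruction error, at the cost of storage.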

View File

@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbitsc(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
constant const bool& odd,
constant const uint& bytes_per_key,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbits(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
constant const bool& odd,
constant const uint& bytes_per_key,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,

View File

@@ -10,178 +10,156 @@
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_init_reduce(name, tname, type, op) \
instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
instantiate_init_reduce(and, bool_, bool, And)
instantiate_init_reduce(or, bool_, bool, Or)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) \
inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_init_sum_prod(name, op) \
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
instantiate_init_sum_prod(sum, Sum)
instantiate_init_sum_prod(prod, Prod)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_init_min_max(name, op) \
instantiate_init_reduce(name, bool_, bool, op) \
instantiate_init_reduce(name, int8, int8_t, op) \
instantiate_init_reduce(name, int16, int16_t, op) \
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, uint8, uint8_t, op) \
instantiate_init_reduce(name, uint16, uint16_t, op) \
instantiate_init_reduce(name, uint32, uint32_t, op) \
instantiate_init_reduce(name, uint64, uint64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
instantiate_init_min_max(min, Min)
instantiate_init_min_max(max, Max)
#define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("all_reduce_" #name, \
all_reduce, \
itype, otype, op)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, size_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, size_t, dim)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, size_t, dim, bm, bn)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, size_t, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 5) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)
instantiate_col_reduce_looped(name, itype, otype, op, 5)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, size_t, dim)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, size_t, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 5) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 5) \
instantiate_kernel("row_reduce_simple_" #name, \
row_reduce_simple, \
itype, otype, op)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)
#define instantiate_and_or(name, op) \
instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
instantiate_and_or(and, And)
instantiate_and_or(or, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)
#define instantiate_min_max(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
// clang-format on
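Notice that the sum and prod instantiations map int8/int16 inputs to int32 outputs rather than accumulating in the input type. The reason is easy to demonstrate in plain C++: a narrow accumulator overflows almost immediately.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int8_t> x(1000, 100);  // true sum is 100000, far beyond int8 range

  int8_t narrow = 0;
  int32_t wide = 0;
  for (int8_t v : x) {
    narrow = static_cast<int8_t>(narrow + v);  // wraps repeatedly
    wide += v;                                 // exact
  }
  std::printf("int8 accumulator:  %d\n", int(narrow));
  std::printf("int32 accumulator: %d\n", wide);
  return 0;
}
```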

Some files were not shown because too many files have changed in this diff.