Compare commits

...

129 Commits

Author SHA1 Message Date
Awni Hannun
1086dc4db0 patch (#1320) 2024-08-12 16:13:33 -07:00
Brian Keene
19fb69e2ed Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to the memory-efficient GPU shader at a prescribed sequence
length. Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
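
A minimal usage sketch of the new kwarg (assuming it is forwarded through mx.fast.scaled_dot_product_attention, as the message suggests):

import math
import mlx.core as mx

B, H, S, D = 1, 8, 1024, 64
q = mx.random.uniform(shape=(B, H, S, D))
k = mx.random.uniform(shape=(B, H, S, D))
v = mx.random.uniform(shape=(B, H, S, D))

# Opt in to the memory-efficient shader once the sequence length reaches the
# given threshold; shorter sequences keep the default latency-optimized path.
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(D), memory_efficient_threshold=512
)
mx.eval(out)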
Awni Hannun
9231617eb3 Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317 CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
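
A short sketch of the new CPU routines (the docs diff below lists the ops as cholesky_inv and tri_inv; the stream argument is assumed to follow the other mx.linalg routines):

import mlx.core as mx

a = mx.array([[4.0, 1.0], [1.0, 3.0]])
# Both routines currently run on the CPU, per the commit notes.
chol = mx.linalg.cholesky(a, stream=mx.cpu)           # lower-triangular factor
a_inv = mx.linalg.cholesky_inv(chol, stream=mx.cpu)   # inverse of a from its factor
chol_inv = mx.linalg.tri_inv(chol, stream=mx.cpu)     # inverse of the triangular factor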
Angelos Katharopoulos
780c197f95 Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
Angelos Katharopoulos
eb8819e91e Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08 Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25 Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
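
A sketch of the new string-valued mode argument (assuming the default remains constant padding):

import mlx.core as mx

x = mx.array([[1, 2], [3, 4]])
y = mx.pad(x, pad_width=1, mode="edge")  # replicate border values instead of padding with zeros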
nicolov
8c9f0278b9 Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1 add bfloat conv for Winograd (#1306)
* add bfloat conv for Winograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501 fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
6c8dd307eb faster group norm (#1304) 2024-08-01 12:49:23 -07:00
Awni Hannun
43ffdab172 fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333 Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0 Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Atakan Tekparmak
6e06e3a904 feat: Added "tanh" option to GELU approximation (#1268) 2024-07-28 09:07:56 +02:00
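
A hedged sketch of the new option (assuming it is selected through the existing approx argument of nn.GELU):

import mlx.core as mx
import mlx.nn as nn

gelu = nn.GELU(approx="tanh")         # tanh-based approximation of GELU
y = gelu(mx.random.normal((4,)))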
Yaroslav
8cfb9fc0b8 Update requirements.txt (#1291) 2024-07-26 12:59:52 -07:00
Awni Hannun
7b456fd2c0 Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Awni Hannun
e9e53856d2 patch bump (#1287) 2024-07-25 11:42:09 -07:00
Anton Belov
5029894662 [Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
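
A minimal sketch of the new op, assuming a NumPy-like signature:

import mlx.core as mx

x = mx.array([float("nan"), float("inf"), -float("inf"), 1.0])
y = mx.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)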
Awni Hannun
baf9fa5f42 Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to separate file

* separated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
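
The commit above adds mx.einsum and mx.einsum_path; a minimal sketch of both (the einsum_path output format is assumed to mirror NumPy's):

import mlx.core as mx

a = mx.random.uniform(shape=(8, 16))
b = mx.random.uniform(shape=(16, 4))

c = mx.einsum("ij,jk->ik", a, b)          # plain matrix multiply written as an einsum
path = mx.einsum_path("ij,jk->ik", a, b)  # contraction order chosen greedily, NumPy-style
mx.eval(c)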
Jagrit Digani
7f914365fd Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50 Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
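
A hedged sketch of the loss call this commit hardens (assuming the with_logits flag distinguishes probabilities from logits):

import mlx.core as mx
import mlx.nn as nn

probs = mx.array([0.0, 1.0, 0.3])    # boundary values used to be numerically unstable
targets = mx.array([0.0, 1.0, 1.0])
loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False)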
fgranqvist
50eff6a10a Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
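
A sketch of the new sampler (the loc and scale parameter names are assumed, following the other mx.random distributions):

import mlx.core as mx

samples = mx.random.laplace(shape=(1024,), loc=0.0, scale=1.0)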
Alex Barron
c34a5ae7f7 Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae some fixes (#1281) 2024-07-23 11:49:05 -07:00
toji
6768c6a54a Adding missing type hints (#1243)
* added type hints for `run`, `tree_map` and `tree_map_with_path`

* fix lint

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Tim Gymnich
6307d166eb Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7 Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7 Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Awni Hannun
8c01a7893b minor fix in optimizer + docs (#1264) 2024-07-12 12:18:02 -07:00
Awni Hannun
218047c75a docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Alex Barron
d0da74209b version bump (#1260) 2024-07-11 11:17:55 -07:00
Angelos Katharopoulos
5c1fa64fb0 Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
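
A minimal sketch of the new op (the optional scale argument is assumed from the commit notes):

import mlx.core as mx

x = mx.random.normal(shape=(4, 4096))            # last axis a power of two (or m * 2^k)
y = mx.hadamard_transform(x)                     # default scaling
y_scaled = mx.hadamard_transform(x, scale=1.0 / 64.0)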
Angelos Katharopoulos
03cf033f82 Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63 add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195 CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Awni Hannun
d6383a1c6a version bump (#1239) 2024-06-27 10:43:13 -07:00
Angelos Katharopoulos
b05bcfd27f Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62 Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1 fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8 Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439 Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
David Koski
4eef1e8a3e fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06 Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
Awni Hannun
af9079cc1f version bump (#1212) 2024-06-14 11:28:51 -07:00
Jagrit Digani
2d6cd47713 Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea smaller CPU binary (#1203)
* smaller CPU binary

* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35 Build for macOS 15 (#1208)
* Build for macos 15

* metal32 as well

* comment

---------

Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Awni Hannun
e84ba8056d only allow openmpi (#1209) 2024-06-13 12:14:44 -07:00
Fangjun Kuang
f20e97b092 minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a Fix kernel deps to reduce build times (#1205) 2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29 Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Awni Hannun
709ccc6800 install mpi for release build (#1199) 2024-06-10 10:09:32 -07:00
Awni Hannun
cf236fc390 version (#1191) 2024-06-06 17:16:40 -07:00
Alex Barron
27d70c7d9d Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409 Add docstring for scatter (#1189)
* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893 Fix the hard-shrink test (#1185) 2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
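
A sketch of the new activations (function names taken from the nn docs diff further below; default parameters are assumed):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.25, 0.25, 2.0])
y1 = nn.softmin(x)       # softmax of the negated input
y2 = nn.hard_shrink(x)   # zeroes small-magnitude values (default lambda assumed)
y3 = nn.hard_tanh(x)     # clamps to [-1, 1] by default (assumed)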
Awni Hannun
83b11bc58d Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4 Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
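
A hedged sketch of the new view op (assumed to reinterpret the underlying bytes as another dtype of the same itemsize):

import mlx.core as mx

x = mx.array([1.0, 2.0], dtype=mx.float32)
bits = x.view(mx.int32)      # same buffer, reinterpreted as 32-bit integers
back = bits.view(mx.float32)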
nicolov
81def6ac76 Fix benchmark (#1175) 2024-06-04 07:50:46 -07:00
Angelos Katharopoulos
3de8ce3f3c In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24 Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30 Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Dominik Schlösser
3576b547c5 Doc error for default for scale in SinusoidalPositionalEncoding (#1174) 2024-06-02 13:42:45 -07:00
Awni Hannun
079882495d version bump (#1172) 2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
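
With the DLPack device metadata in place, zero-copy export to other frameworks looks roughly like this (assuming NumPy >= 1.22 and that NumPy can import the array via the DLPack protocol, as the updated numpy.rst suggests):

import mlx.core as mx
import numpy as np

x = mx.arange(4)
y = np.from_dlpack(x)   # shares memory with the MLX array via DLPack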
Awni Hannun
fd1c08137b stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46 Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef version bump (#1154) 2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032 Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67 Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6 More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336 Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
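
A minimal grouped-convolution sketch matching the benchmark changes further below (NHWC input, OHWI weights with C / groups input channels):

import mlx.core as mx

groups = 2
x = mx.random.uniform(shape=(4, 32, 32, 16))            # N, H, W, C
w = mx.random.uniform(shape=(32, 3, 3, 16 // groups))   # O, kH, kW, C // groups
y = mx.conv2d(x, w, stride=1, padding=1, groups=groups)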
Awni Hannun
eb8321d863 list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2 add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
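
A minimal sketch of the new op (the offset argument is assumed to follow NumPy):

import mlx.core as mx

a = mx.arange(9).reshape(3, 3)
t = mx.trace(a)               # sum of the main diagonal
t_off = mx.trace(a, offset=1)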
Awni Hannun
e110ca11e2 Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7 JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36 Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1 Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be Treat 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf Update README.md (#1136) 2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380 Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Awni Hannun
6a9b584f3d patch bump (#1131) 2024-05-16 20:51:33 -07:00
Awni Hannun
81dd33af66 allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Awni Hannun
8b76571896 Fix extensions (#1126)
* fix extensions

* title

* enable circle

* fix nanobind tag

* fix bug in doc

* try to fix config

* typo
2024-05-16 15:36:25 -07:00
Angelos Katharopoulos
e78a6518fa Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01 Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jacket
c417e42116 [Fix] minor typo in default argument for argpartition's "axis" parameter (#1125)
According to the documentation, argpartition's axis parameter can be None, but due to a previous typo it could not actually accept a None value.
2024-05-15 15:25:25 -07:00
Jagrit Digani
358e1fd6ab Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
631dfbe673 fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Cheng
56a4eaed72 Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7 Move args in conv_general (#1118)
Also fix a typo where padding_lo was passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6 Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
5be5daa6ef Use compiled function in Sigmoid module (#1116) 2024-05-14 06:25:57 -07:00
Cheng
60cb11764e Use correct module type in quantized.py (#1115) 2024-05-14 06:25:42 -07:00
Cheng
cbd5445ea7 The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Cheng
2c7e9b5158 Add missing docs for some ops (#1110) 2024-05-14 06:09:05 -07:00
Mike Drob
2263e4b279 Experiment with medium machines for CI (#1000) 2024-05-13 19:40:19 -07:00
Awni Hannun
863039da4c Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111 No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Ravindra R. Jaju
e7f9710499 Fix typo in a variable name in example code. (#1104)
* Fix typo in a variable name in example code.

* Rename df2dx2 to d2fdx2 - the appropriate naming for the second derivative

* Update CONTRIBUTING.md - add needed python packages, and a virtual-env hint

* Revert "Fix typo in a variable name in example code."

This reverts commit bc10a17534.

* Rename df2dx2 to d2fdx2
2024-05-13 06:04:23 -07:00
Max-Heinrich Laves
ff4223904d Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6 improve error messaging in eval (#1101) 2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0 Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
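
A sketch of the new conjugate op (both the free function and the array method from the commit message):

import mlx.core as mx

z = mx.array([1 + 2j, 3 - 4j])
zc1 = mx.conj(z)
zc2 = z.conj()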
Awni Hannun
8bd6bfa4b5 version (#1099) 2024-05-09 17:52:39 -07:00
Awni Hannun
8b1906abd0 Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605 Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Awni Hannun
b21242faf1 Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4 Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
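
A minimal sketch of the new op:

import mlx.core as mx

y = mx.array([1.0, -1.0])
x = mx.array([1.0, 1.0])
angles = mx.arctan2(y, x)    # elementwise atan2(y, x), in radians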
Jagrit Digani
fe96ceee66 Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
313 changed files with 28719 additions and 12045 deletions

View File

@@ -49,11 +49,6 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
@@ -69,13 +64,14 @@ jobs:
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -101,11 +97,14 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
@@ -117,7 +116,13 @@ jobs:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
build_release:
parameters:
@@ -132,13 +137,14 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip

View File

@@ -17,4 +17,4 @@ jobs:
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
pre-commit run --all-files

View File

@@ -10,12 +10,14 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -15,12 +15,16 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.12.2)
set(MLX_VERSION 0.16.3)
endif()
# --------------------- Processor tests -------------------------
@@ -79,22 +83,21 @@ elseif (MLX_BUILD_METAL)
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
else()
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
# Get the metal version
execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION
COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -104,55 +107,85 @@ elseif (MLX_BUILD_METAL)
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
if (MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
find_package(MPI)
if (MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET
)
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif (MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING
"MPI found but mpirun is not available. Building without MPI."
)
else()
set(MPI_FOUND FALSE)
message(
WARNING
"MPI which is not OpenMPI found. Building without MPI."
)
endif()
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -164,6 +197,14 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)

View File

@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
return torch.nn.functional.mish(y)
y = torch.nn.functional.mish(y)
sync_if_needed(x)
@@ -283,6 +283,14 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
@@ -446,5 +454,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError("Unknown benchmark")
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
def compare(args):

View File

@@ -9,7 +9,6 @@ from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -51,7 +50,6 @@ def bench_gelu():
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)

View File

@@ -28,11 +28,11 @@ def bench(f, a, b):
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
@@ -53,11 +53,12 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
@@ -67,15 +68,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
@@ -84,7 +85,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
@@ -95,35 +96,40 @@ if __name__ == "__main__":
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -3,6 +3,8 @@
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
return bandwidths
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
return np.array(bandwidths)
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
x = np.array(x)
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
if __name__ == "__main__":

View File

@@ -0,0 +1,70 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -0,0 +1,62 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,3 +1,4 @@
sphinx
breathe
sphinx-book-theme
mlx

View File

@@ -83,3 +83,15 @@ def setup(app):
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -1,5 +1,5 @@
Developer Documentation
=======================
Custom Extensions in MLX
========================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
@@ -486,15 +486,14 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -503,11 +502,11 @@ below.
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -531,7 +530,7 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -825,7 +824,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
print(f"c correct: {mx.all(c == 6.0).item()}")
Output:

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the llama attention layer which notably uses the RoPE
We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as `mnist`.
we will import as ``mnist``.
.. code-block:: python

View File

@@ -43,6 +43,7 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
.. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::

View File

@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
Then simply build and install MLX using pip:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
For developing use an editable install:
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
To make sure the install is working run the tests with:
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace
Run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your IDE:
Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
@@ -153,11 +153,18 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
@@ -176,10 +183,37 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
contains pre-built GPU kernels. It substantially reduces the size of the Metal
library by compiling kernels at run time the first time they are used in MLX on
a given machine. Note that run-time compilation incurs a cold-start cost which
can be anywhere from a few hundred milliseconds to a few seconds depending on
the application. Once a kernel is compiled, it is cached by the system. The
Metal kernel cache persists across reboots.
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~

View File

@@ -24,6 +24,7 @@ Array
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
@@ -57,3 +58,4 @@ Array
array.transpose
array.T
array.var
array.view

View File

@@ -0,0 +1,19 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
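A minimal sketch of how these functions fit together when launched under
``mpirun`` (see the :ref:`usage guide <usage_distributed>` for details):
.. code-block:: python
    import mlx.core as mx
    if mx.distributed.is_available():
        group = mx.distributed.init()
        # Sum an array of ones across all processes in the group
        total = mx.distributed.all_sum(mx.ones(4))
        print(group.rank(), group.size(), total)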

View File

@@ -8,5 +8,10 @@ Linear Algebra
.. autosummary::
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
qr
svd

View File

@@ -17,6 +17,8 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -15,15 +15,21 @@ Layers
BatchNorm
Conv1d
Conv2d
Conv3d
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -35,13 +41,19 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample

View File

@@ -10,6 +10,7 @@ Operations
abs
add
addmm
all
allclose
any
@@ -19,12 +20,14 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
@@ -32,11 +35,12 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d
@@ -53,6 +57,8 @@ Operations
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
@@ -64,8 +70,11 @@ Operations
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
hadamard_transform
identity
inner
isclose
@@ -73,6 +82,7 @@ Operations
isnan
isneginf
isposinf
issubdtype
left_shift
less
less_equal
@@ -96,6 +106,7 @@ Operations
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
@@ -103,11 +114,13 @@ Operations
outer
partition
pad
power
prod
quantize
quantized_matmul
radians
reciprocal
remainder
repeat
reshape
right_shift
@@ -141,11 +154,13 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note that not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is that if a parameter can be scheduled
then it will be included in the optimizer state.
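For instance, inspecting the flattened state is a quick way to see which
configuration values are serialized; this sketch reuses the Adam setup from the
example above:
.. code-block:: python
    import mlx.core as mx
    from mlx.utils import tree_flatten
    import mlx.optimizers as optim
    optimizer = optim.Adam(learning_rate=1e-2)
    optimizer.update({"w": mx.zeros((5, 5))}, {"w": mx.ones((5, 5))})
    # The flattened keys include the learning rate (and the per-parameter
    # moment estimates), but there are no entries for ``betas`` or ``eps``.
    print([k for k, _ in tree_flatten(optimizer.state)])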
.. toctree::
optimizers/optimizer

View File

@@ -44,3 +44,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
split
truncated_normal
uniform
laplace

View File

@@ -10,6 +10,7 @@ Transforms
eval
compile
custom_function
disable_compile
enable_compile
grad

View File

@@ -0,0 +1,166 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported yet, or may not be as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mpirun -np 4 ...`` would
print arrays of 4s, and so on.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines (an example entry is shown after
the note below).
.. note::
For an example hostname ``foo.bar.com``, MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
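For example, a minimal ``.ssh/config`` entry might look like the following,
where the host alias, domain, and user name are placeholders:
.. code::
    Host host1
        HostName host1.example.com
        User myuser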
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, i.e. one
process per host. The host file also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is done with the ``--hostfile`` command line argument, as shown below.
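For example, assuming the host file above is saved as ``hosts.txt``:
.. code:: shell
    $ mpirun --hostfile hosts.txt -np 2 python test.py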
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with the following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together, our training loop step looks as follows, with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
We are working on improving the performance of all reduce in MLX, but for now
the two main things you can do to get the most out of distributed training with MLX are:
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
connections between each host to improve bandwidth (see the example below)
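For example, a launch combining the host file from above with the extra TCP
links could look like the following, where ``train.py`` stands in for your
training script:
.. code:: shell
    $ mpirun --mca btl_tcp_links 4 --hostfile hosts.txt -np 2 python train.py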

View File

@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
MLX array supports conversion to and from other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
.. code-block:: python

View File

@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -89,8 +89,8 @@ void automatic_differentiation() {
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
}
int main() {

View File

@@ -1,5 +1,5 @@
## Build the extensions
## Build
```
pip install -e .
@@ -16,3 +16,9 @@ And then run:
```
python setup.py build_ext -j8 --inplace
```
## Test
```
python test.py
```

View File

@@ -249,15 +249,14 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +265,11 @@ void Axpby::eval_gpu(
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +295,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available

View File

@@ -2,4 +2,4 @@
import mlx.core as mx
from .mlx_sample_extensions import *
from ._ext import axpby

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
mlx>=0.16.2
nanobind==2.0

View File

@@ -0,0 +1,10 @@
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -6,6 +6,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
@@ -19,11 +20,17 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if (MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
elseif(MLX_BUILD_CPU)
target_sources(
mlx
PRIVATE

View File

@@ -17,6 +17,10 @@ bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -102,7 +106,7 @@ void array::eval() {
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -171,10 +175,11 @@ array::~array() {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
@@ -206,7 +211,7 @@ void array::ArrayDesc::init() {
strides[i] = size;
size *= shape[i];
}
for (auto& in : inputs) {
for (const auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
@@ -231,7 +236,7 @@ array::ArrayDesc::ArrayDesc(
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destory their corresponding descriptions and so on and so
// that may also destroy their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we

View File

@@ -73,32 +73,32 @@ class array {
this->array_desc_ = other.array_desc_;
}
return *this;
};
}
/** The size of the array's datatype in bytes. */
size_t itemsize() const {
return size_of(dtype());
};
}
/** The number of elements in the array. */
size_t size() const {
return array_desc_->size;
};
}
/** The number of bytes in the array. */
size_t nbytes() const {
return size() * itemsize();
};
}
/** The number of dimensions of the array. */
size_t ndim() const {
return array_desc_->shape.size();
};
}
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
return array_desc_->shape;
};
}
/**
* Get the size of the corresponding dimension.
@@ -107,12 +107,12 @@ class array {
* bounds checking. */
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
};
}
/** The strides of the array. */
const std::vector<size_t>& strides() const {
return array_desc_->strides;
};
}
/**
* Get the stride of the corresponding dimension.
@@ -121,12 +121,12 @@ class array {
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
}
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
};
}
/** Evaluate the array. */
void eval();
@@ -160,10 +160,10 @@ class array {
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
return a.arr.id() == b.arr.id() && a.idx == b.idx;
};
}
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
return !(a == b);
};
}
private:
const array& arr;
@@ -209,7 +209,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {};
: buffer(buffer), d(d) {}
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -230,22 +230,22 @@ class array {
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
};
}
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
};
}
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
};
}
/** The array's inputs. */
const std::vector<array>& inputs() const {
return array_desc_->inputs;
};
}
std::vector<array>& inputs() {
return array_desc_->inputs;
@@ -259,12 +259,12 @@ class array {
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
};
}
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
}
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
@@ -281,7 +281,7 @@ class array {
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
};
}
/** Detach the array from the graph. */
void detach();
@@ -289,19 +289,19 @@ class array {
/** Get the Flags bit-field. */
const Flags& flags() const {
return array_desc_->flags;
};
}
/** The size (in elements) of the underlying buffer the array points to. */
size_t data_size() const {
return array_desc_->data_size;
};
}
allocator::Buffer& buffer() {
return array_desc_->data->buffer;
};
}
const allocator::Buffer& buffer() const {
return array_desc_->data->buffer;
};
}
// Return a copy of the shared pointer
// to the array::Data struct
@@ -312,19 +312,20 @@ class array {
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
};
}
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
};
}
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
const Status status() const {
Status status() const {
return array_desc_->status;
}

View File

@@ -1,9 +1,9 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

View File

@@ -2,8 +2,7 @@
#include <cassert>
#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include <Accelerate/Accelerate.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"

View File

@@ -3,8 +3,7 @@
#include <cassert>
#include <cmath>
#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include <Accelerate/Accelerate.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
@@ -32,12 +31,12 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
@@ -47,8 +46,11 @@ DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
@@ -78,6 +80,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -99,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -114,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -129,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x + y; });
eval(inputs, out);
}
}
@@ -193,6 +196,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -264,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -277,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -292,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x / y; });
eval(inputs, out);
}
}
@@ -303,12 +326,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
eval(inputs, out);
}
}
@@ -370,12 +389,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
eval(inputs, out);
}
}
@@ -385,7 +400,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -400,7 +415,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x * y; });
eval(inputs, out);
}
}
@@ -411,7 +426,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
unary(in, out, [](auto x) { return -x; });
eval(inputs, out);
}
}
@@ -498,7 +513,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
eval(inputs, out);
}
}
@@ -524,7 +539,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -542,7 +557,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -554,7 +569,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
},
UseDefaultBinaryOp());
} else {
binary(a, b, out, [](auto x, auto y) { return x - y; });
eval(inputs, out);
}
}

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"

View File

@@ -3,7 +3,10 @@
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
@@ -53,25 +56,26 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
return (*(simd_float16*)&epart) * x;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
@@ -107,53 +111,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
};
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
@@ -170,7 +127,7 @@ struct NeonFp16SimdOps {
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
};
}
VT exp(VT x) {
return neon_fast_exp(x);
@@ -201,6 +158,55 @@ struct NeonFp16SimdOps {
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -362,12 +368,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:

View File

@@ -1,8 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <vecLib/BNNS/bnns.h>
#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"
namespace mlx::core {

View File

@@ -37,16 +37,20 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
@@ -55,6 +59,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
}
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -293,4 +307,25 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}
} // namespace mlx::core

View File

@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

View File

@@ -0,0 +1,101 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core

View File

@@ -0,0 +1,304 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(

View File

@@ -111,13 +111,17 @@ void slow_conv_2D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int C = in.shape(3); // In channels
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
const int groups = C / wt.shape(3);
const int C_per_group = wt.shape(3);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_W = in.strides()[2];
@@ -141,33 +145,35 @@ void slow_conv_2D(
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
@@ -219,41 +225,43 @@ void slow_conv_2D(
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;
@@ -310,6 +318,296 @@ void slow_conv_2D(
} // n
}
template <typename T>
void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
const int oD = out.shape(1); // Output spatial dim
const int oH = out.shape(2); // Output spatial dim
const int oW = out.shape(3); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(4); // In channels
const int wD = wt.shape(1); // Weight spatial dim
const int wH = wt.shape(2); // Weight spatial dim
const int wW = wt.shape(3); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_D = in.strides()[1];
const size_t in_stride_H = in.strides()[2];
const size_t in_stride_W = in.strides()[3];
const size_t in_stride_C = in.strides()[4];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_D = wt.strides()[1];
const size_t wt_stride_H = wt.strides()[2];
const size_t wt_stride_W = wt.strides()[3];
const size_t wt_stride_C = wt.strides()[4];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_D = out.strides()[1];
const size_t out_stride_H = out.strides()[2];
const size_t out_stride_W = out.strides()[3];
const size_t out_stride_O = out.strides()[4];
bool is_idil_one =
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
const T* wt_ptr_pt =
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
std::vector<int> base_d(f_out_jump_d);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
wd_base++;
id_loop += jump_d;
}
base_d[i] = wd_base;
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // iD, ih, iw check
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: od might put us out of bounds
for (int od = oD_border_0; od < oD_border_1; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 2: od in bounds
for (int od = oD_border_1; od < oD_border_2; ++od) {
// Case 2.1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case 2.2.1: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.2: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.3: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 3: od might put us out of bounds
for (int od = oD_border_2; od < oD_border_3; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
}
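The nested Case splits above confine the fully bounds-checked kernel (pt_conv_all_checks) to the border slabs of the output volume and run the interior through pt_conv_no_checks. A minimal 1-D sketch of the same partitioning, under simplifying assumptions (zero padding semantics, unit dilation, no flip; all names hypothetical):
#include <algorithm>
#include <vector>
// 1-D illustration of the border split: outputs in [0, border_lo) and
// [border_hi, oW) may index out of bounds and take the checked path;
// outputs in [border_lo, border_hi) are provably in range.
void conv1d_border_split(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    std::vector<float>& out,
    int stride,
    int padding) {
  int iW = static_cast<int>(in.size());
  int wW = static_cast<int>(wt.size());
  int oW = static_cast<int>(out.size());
  int border_lo = std::min(oW, (padding + stride - 1) / stride);
  int border_hi =
      std::min(oW, std::max(border_lo, (iW + padding - wW) / stride + 1));
  auto point = [&](int ow, bool checked) {
    float r = 0.f;
    for (int k = 0; k < wW; ++k) {
      int iw = ow * stride - padding + k;
      if (!checked || (iw >= 0 && iw < iW)) {
        r += in[iw] * wt[k]; // zero padding: out-of-range taps are skipped
      }
    }
    out[ow] = r;
  };
  for (int ow = 0; ow < border_lo; ++ow) point(ow, /* checked */ true);
  for (int ow = border_lo; ow < border_hi; ++ow) point(ow, /* checked */ false);
  for (int ow = border_hi; ow < oW; ++ow) point(ow, /* checked */ true);
}
For example, with iW = 8, wW = 3, stride = 1, padding = 1 the split is [0, 1), [1, 7), [7, 8). The oD/oH/oW border arithmetic above omits the +1 on the interior upper bound, which is safe: it only routes a few provably-in-range positions through the checked kernel.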
void dispatch_slow_conv_1D(
const array& in,
const array& wt,
@@ -358,6 +656,30 @@ void dispatch_slow_conv_2D(
}
}
void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_3D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_3D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_3D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@@ -582,6 +904,131 @@ void explicit_gemm_conv_2D_cpu(
}
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}
for (const auto& w : wDim) {
strided_reshape[1] *= w;
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
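Read as an im2col-style convolution: the padded input is exposed as a strided view of shape (N, oDim..., wDim..., C), materialized into an (N * prod(oDim)) x (prod(wDim) * C) matrix, and multiplied against the (O, prod(wDim) * C) weight with a single sgemm. A worked shape example with made-up sizes (not taken from this file):
// Batch 1, 8x8x8 spatial input, 3x3x3 kernel, C = 4 input channels,
// O = 16 output channels, stride 1, no padding or dilation, so each output
// spatial dim is 8 - 3 + 1 = 6. The GEMM dimensions passed above become:
constexpr int example_M = 1 * 6 * 6 * 6; // one row per output location
constexpr int example_K = 3 * 3 * 3 * 4; // kernel taps times input channels
constexpr int example_N = 16;            // output channels
static_assert(example_M == 216 && example_K == 108, "shapes as described");
// so gemm_out is a 216 x 16 matrix, which is exactly out viewed as
// (1, 6, 6, 6, 16).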
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@@ -617,6 +1064,19 @@ void conv_2D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -625,8 +1085,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& wt = inputs[1];
// 3D convolution
if (in.ndim() == (3 + 2)) {
return conv_3D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 2D convolution
if (in.ndim() == (2 + 2)) {
else if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,

View File

@@ -4,6 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -142,29 +143,31 @@ void copy_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
@@ -195,10 +198,10 @@ inline void copy_general_general_dims(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
int64_t i_offset,
int64_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
int axis = data_shape.size() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -209,7 +212,7 @@ inline void copy_general_general_dims(
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
int axis = data_shape.size() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -230,38 +233,76 @@ void copy_general_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) {
int64_t i_offset,
int64_t o_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
switch (new_shape.size()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
src_offset,
dst_offset);
}
}
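Both general copy paths now collapse contiguous dimensions before choosing a specialized kernel, so e.g. a dense (2, 3, 4) copy dispatches to the 1-D case as a single run of 24 elements. A minimal sketch of the collapsing rule for a single strides vector (hypothetical helper; the real collapse_contiguous_dims merges several stride sets consistently):
#include <cstddef>
#include <utility>
#include <vector>
// Merge axis i into the previous group whenever the group's stride equals
// shape[i] * strides[i], i.e. the two axes describe one contiguous run.
std::pair<std::vector<int>, std::vector<size_t>> collapse_dims_sketch(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() &&
        out_strides.back() == static_cast<size_t>(shape[i]) * strides[i]) {
      out_shape.back() *= shape[i]; // fold axis i into the previous group
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}
// collapse_dims_sketch({2, 3, 4}, {12, 4, 1}) -> ({24}, {1})
// collapse_dims_sketch({2, 3, 4}, {24, 4, 1}) -> ({2, 12}, {24, 1})
//   (gap after axis 0, so it cannot merge with the rest)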
@@ -444,8 +485,17 @@ void copy_inplace(
}
}
template <>
void copy_inplace<int64_t>(
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
@@ -453,24 +503,6 @@ void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
CopyType ctype);
} // namespace mlx::core

View File

@@ -5,7 +5,6 @@
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
@@ -34,6 +33,7 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -42,15 +42,17 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
@@ -66,6 +68,7 @@ DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
@@ -110,6 +113,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
namespace {

View File

@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
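The loop above is the standard in-place fast Walsh-Hadamard butterfly. One way to read it, for n = 2^k, is via the Sylvester recursion
    H_1 = [1],    H_{2h} = [ H_h   H_h ]
                           [ H_h  -H_h ]
so stage h replaces each pair (x, y) = (v[j], v[j + h]), where j runs over the indices whose h bit is clear, with (x + y, x - y); after log2(n) stages the n-length block holds H_n v. The caller-supplied scale is folded into the final stage (h == n/2), and is set to 1 when m > 1 so that hadamard_m below applies it instead.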
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,105 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
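Some concrete decompositions implied by the logic above (illustrative values; the map is iterated in ascending key order, so 12 is tried first):
// decompose_hadamard(64) -> {64, 1}   // already a power of two
// decompose_hadamard(96) -> {8, 12}   // 96 = 12 * 8
// decompose_hadamard(40) -> {2, 20}   // 12 does not divide 40, 20 does
// decompose_hadamard(28) -> {1, 28}
// decompose_hadamard(7)  -> throws    // no supported factor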
} // namespace mlx::core

View File

@@ -2,7 +2,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
@@ -11,9 +10,106 @@
#include <lapack.h>
#endif
// Wrapper to account for differences between LAPACK implementations
// (specifically, whether the hidden Fortran string lengths of the 'uplo'
// and 'diag' character arguments must be passed explicitly).

int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
namespace mlx::core {
void inverse_impl(const array& a, array& inv) {
void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
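For reference, the identity in question is
    (A^T)^{-1} = (A^{-1})^T
so handing LAPACK the row-major buffer (which it reads as the column-major matrix A^T) and inverting it in place leaves (A^{-1})^T in column-major order, which is exactly A^{-1} in the row-major layout the array uses. The same transposition explains the flipped uplo in tri_inv above: a row-major upper-triangular matrix looks lower-triangular to LAPACK, hence upper ? 'L' : 'U'.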
@@ -25,63 +121,11 @@ void inverse_impl(const array& a, array& inv) {
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
}
}
}
@@ -90,15 +134,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::inv(a, stream())}, {ax}};
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core

View File

@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}

View File

@@ -17,24 +17,25 @@ namespace mlx::core {
namespace {
template <typename T>
template <typename T, typename mask_t>
inline void mask_matrix(
T* data,
const bool* mask,
const mask_t* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
const size_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
if (do_mask != 1) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
if constexpr (std::is_same_v<mask_t, bool>) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
} else {
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
}
}
}
}
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& out_mask = inputs[2];
auto check_transpose = [](const array& arr, bool do_copy) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
const bool* mask_ptr = mask.data<bool>() +
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
return mask_matrix(
data,
mask_ptr,
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str);
if (mask.dtype() == bool_) {
return mask_matrix(
data,
mask.data<bool>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
} else {
return mask_matrix(
data,
mask.data<float>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
}
};
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[3];
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[4];
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
@@ -186,14 +209,16 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
);
// Zero out blocks in out
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -277,4 +302,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
}
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -108,133 +108,146 @@ struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
};
}
uint8_t operator()(uint8_t x) {
return x;
};
}
uint16_t operator()(uint16_t x) {
return x;
};
}
uint32_t operator()(uint32_t x) {
return x;
};
}
uint64_t operator()(uint64_t x) {
return x;
};
}
bool operator()(bool x) {
return x;
};
}
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
};
}
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
};
}
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
};
}
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
};
}
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
};
}
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
}
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
};
}
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
};
}
int8_t operator()(int8_t x) {
return x;
};
}
int16_t operator()(int16_t x) {
return x;
};
}
int32_t operator()(int32_t x) {
return x;
};
}
int64_t operator()(int64_t x) {
return x;
};
}
uint8_t operator()(uint8_t x) {
return x;
};
}
uint16_t operator()(uint16_t x) {
return x;
};
}
uint32_t operator()(uint32_t x) {
return x;
};
}
uint64_t operator()(uint64_t x) {
return x;
};
}
bool operator()(bool x) {
return x;
};
}
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
};
}
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
};
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
};
}
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
};
}
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
};
}
complex64_t operator()(complex64_t x) {
return std::exp(x);
@@ -245,83 +258,83 @@ struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
};
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
};
}
int8_t operator()(int8_t x) {
return x;
};
}
int16_t operator()(int16_t x) {
return x;
};
}
int32_t operator()(int32_t x) {
return x;
};
}
int64_t operator()(int64_t x) {
return x;
};
}
uint8_t operator()(uint8_t x) {
return x;
};
}
uint16_t operator()(uint16_t x) {
return x;
};
}
uint32_t operator()(uint32_t x) {
return x;
};
}
uint64_t operator()(uint64_t x) {
return x;
};
}
bool operator()(bool x) {
return x;
};
}
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
};
}
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
};
}
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
};
}
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
}
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
}
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
}
};
struct Round {
@@ -366,49 +379,49 @@ struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
};
}
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
};
}
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
}
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
};
}
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
};
}
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
};
}
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
};
}
};
struct Add {
@@ -541,7 +554,7 @@ struct LogAddExp {
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
};
}
};
struct Multiply {
@@ -589,14 +602,14 @@ struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
}
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
}
};
struct Select {
@@ -610,35 +623,35 @@ struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
}
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
}
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
}
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
}
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
}
};
} // namespace mlx::core::detail

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -8,9 +8,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
@@ -113,61 +113,6 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -203,9 +148,15 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
@@ -232,81 +183,6 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -437,20 +313,6 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
unary(in, out, detail::LogicalNot());
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -536,63 +398,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -600,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -663,49 +478,6 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
@@ -716,7 +488,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {
@@ -737,18 +510,6 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
@@ -786,58 +547,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -854,11 +563,6 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
}
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -883,38 +587,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * obytes / ibytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General);
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
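When one of the three conditions holds, the buffer is rebound in place: every stride except the innermost is rescaled from input-element units to output-element units. Otherwise the input is first materialized into a row-contiguous temporary and the output takes over that buffer. A small worked example of the zero-copy path, with assumed shapes (not taken from a test):
// Assumed: `in` is row-contiguous float32 (ibytes = 4) with shape (2, 3) and
// element strides (3, 1); the view's output dtype is int16 (obytes = 2).
//   strides[0] = 3 * 4 / 2 = 6   // leading stride rescaled to int16 units
//   strides[1] = 1               // innermost stride left as-is
// The output (whose shape the op will presumably have set to (2, 6) for this
// dtype change) then aliases the same allocation with strides (6, 1).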
} // namespace mlx::core

View File

@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
}
void _qmm_dispatch(
array out,
array& out,
const array& x,
const array& w,
const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
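Put differently, each position i of the index arrays selects one ordinary quantized matmul: lhs_indices picks the x matrix and rhs_indices picks the matching (w, scales, biases) triple. Assuming the group-wise affine dequantization implied by the scales/biases arguments, the per-position computation is, schematically,
    out[i] = x[lhs_indices[i]] @ dequant(w[rhs_indices[i]], scales[rhs_indices[i]], biases[rhs_indices[i]])
    dequant(q, s, b) = s * unpack_bits(q) + b    (s, b broadcast over groups of `group_size` weights)
with w used transposed or not according to transposed_w; the actual arithmetic is delegated to _qmm_dispatch_typed above.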
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
} // namespace mlx::core

View File

@@ -104,48 +104,14 @@ void reduce_dispatch_out(
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
switch (out.dtype()) {
case bool_:
reduction_op<InT, bool>(in, out, axes, false, op);
break;
case uint8:
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
break;
case uint16:
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
break;
case uint32:
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
break;
case uint64:
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
break;
case int8:
reduction_op<InT, int8_t>(in, out, axes, 0, op);
break;
case int16:
reduction_op<InT, int16_t>(in, out, axes, 0, op);
break;
case int32:
reduction_op<InT, int32_t>(in, out, axes, 0, op);
break;
case int64:
reduction_op<InT, int64_t>(in, out, axes, 0, op);
break;
case float16:
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
break;
case float32:
reduction_op<InT, float>(in, out, axes, 0.0f, op);
break;
case bfloat16:
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
break;
case complex64:
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
break;
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
} break;
break;
}
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op);
@@ -168,6 +134,29 @@ void reduce_dispatch_out(
} // namespace
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
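nd_loop walks the index space depth-first and passes the callback the flattened offset of each element, so the caller only has to add its own base pointer. An illustrative call (hypothetical, not from a test):
// nd_loop([](int offset) { /* visit element at data[offset] */ },
//         /* shape   */ {2, 3},
//         /* strides */ {3, 1});
// visits offsets 0, 1, 2, 3, 4, 5 in order; with strides {1, 2} (a transposed
// 3 x 2 buffer) the order would be 0, 2, 4, 1, 3, 5.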
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -49,47 +49,18 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
namespace {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
const std::vector<size_t>& strides);
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
const std::vector<int>& axes);
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
@@ -123,102 +94,6 @@ struct DefaultContiguousReduce {
}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 3. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// b.stride = a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
break;
}
size *= x.shape(i);
}
if (size >= strides.back()) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
template <typename T, typename U, typename OpS, typename OpC, typename Op>
void reduction_op(
const array& x,
@@ -361,6 +236,4 @@ void reduction_op(
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
} // namespace
} // namespace mlx::core

View File

@@ -0,0 +1,118 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/reduce.h"
namespace mlx::core {
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
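shapes_without_reduction_axes simply drops each reduced axis from both the shape and the stride vector, walking the axes back to front so earlier indices stay valid after each erase. A quick usage check (x here stands for any row-contiguous (2, 3, 4) array, so the concrete values are illustrative):

// Given x with shape {2, 3, 4} and row-contiguous strides {12, 4, 1}:
auto [shape, strides] = shapes_without_reduction_axes(x, {1});
// shape == {2, 4}, strides == {12, 1}: axis 1 has been erased from both.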
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 3. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// If b.stride == a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
break;
}
size *= x.shape(i);
}
if (size >= strides.back()) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core
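The core trick in get_reduction_plan is the merge loop: reduction axes are sorted by decreasing stride, and any adjacent pair with b.stride == a.shape * a.stride collapses into one larger axis, so a whole row-contiguous block can be reduced in a single pass. A minimal standalone sketch of just that merge, using plain pairs rather than the ReductionPlan types from the header (the names below are illustrative, not the library API):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Merge reduction axes (shape, stride) that are contiguous in memory.
std::vector<std::pair<int, size_t>> merge_reduction_axes(
    std::vector<std::pair<int, size_t>> reductions) {
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    return a.second > b.second;
  });
  for (int i = reductions.size() - 1; i >= 1; i--) {
    auto a = reductions[i];
    auto b = reductions[i - 1];
    // b follows a contiguously in memory, so fuse them into one axis.
    if (b.second == a.first * a.second) {
      reductions.erase(reductions.begin() + i);
      reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
    }
  }
  return reductions;
}

int main() {
  // Reducing axes of shape 4 (stride 8) and shape 8 (stride 1) in a
  // row-contiguous (4, 8) block: they merge into one axis of size 32.
  auto merged = merge_reduction_axes({{4, 8}, {8, 1}});
  for (auto [size, stride] : merged) {
    std::cout << "size " << size << " stride " << stride << "\n";
  }
}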

View File

@@ -234,7 +234,7 @@ void scan_dispatch(
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
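The one-line change above fixes the identity element for the integer cumulative-maximum scan: the running value has to start at the lowest representable value, not the highest. Assuming init seeds the accumulator of the inclusive scan as sketched here, the difference looks like this:

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  std::vector<int> x = {3, 1, 4};
  // Inclusive cummax: out[i] = max(out[i - 1], x[i]), seeded with `init`.
  auto cummax = [&](int init) {
    std::vector<int> out;
    int acc = init;
    for (int v : x) {
      acc = std::max(acc, v);
      out.push_back(acc);
    }
    return out;
  };
  // Seeding with max() pins every element at INT_MAX; min() gives {3, 3, 4}.
  for (int v : cummax(std::numeric_limits<int>::max())) std::cout << v << " ";
  std::cout << "\n";
  for (int v : cummax(std::numeric_limits<int>::min())) std::cout << v << " ";
  std::cout << "\n";
}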

View File

@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// This means we sliced a broadcasted dimension, so leave the "no holes"
// flag alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core
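prepare_slice is just per-axis arithmetic: the new data offset is the dot product of the start indices with the input strides, and each output stride is the input stride scaled by the step (a copy is only forced for negative steps). A small sketch of that arithmetic on plain vectors (the function name and values here are illustrative, not the mlx API):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Offset and strides of a strided slice, given start indices and steps.
std::pair<int64_t, std::vector<int64_t>> slice_layout(
    const std::vector<int64_t>& in_strides,
    const std::vector<int>& start,
    const std::vector<int>& step) {
  int64_t offset = 0;
  std::vector<int64_t> out_strides(in_strides.size());
  for (size_t i = 0; i < in_strides.size(); ++i) {
    offset += start[i] * in_strides[i];       // where the view begins
    out_strides[i] = in_strides[i] * step[i]; // stretch stride by the step
  }
  return {offset, out_strides};
}

int main() {
  // A row-contiguous (4, 6) array has strides {6, 1}. Slicing [1:, ::2]
  // starts at element 6 and has strides {6, 2}.
  auto [offset, strides] = slice_layout({6, 1}, {1, 0}, {1, 2});
  std::cout << offset << " " << strides[0] << " " << strides[1] << "\n";
}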

View File

@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
} // namespace mlx::core

View File

@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
@@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -3,7 +3,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}
} // namespace mlx::core

View File

@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
return strides;
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
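make_contiguous_strides above is the usual row-major stride computation: the innermost axis gets stride 1 and each earlier axis gets the product of the sizes to its right. A quick usage check (values only, not part of the diff):

// For shape {2, 3, 4}: strides[2] = 1, strides[1] = 1 * 4 = 4,
// strides[0] = 4 * 3 = 12, i.e. {12, 4, 1}.
auto strides = make_contiguous_strides<size_t>({2, 3, 4});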

View File

@@ -1,33 +1,140 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
function(make_jit_source SRC_FILE)
# This function takes a metal header file,
# runs the C preprocessor on it, and makes
# the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}()
#
# To use the function, declare it in jit/includes.h and
# include jit/includes.h.
#
# Additional arguments to this function are treated as dependencies
# in the CMake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_BINARY_DIR}/jit
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
)
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
make_jit_source(
utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
add_dependencies(mlx compiled_preamble)
if (MLX_METAL_JIT)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange)
make_jit_source(copy)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(
reduce
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h
kernels/steel/defines.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/transforms.h
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized)
make_jit_source(gemv_masked)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
@@ -37,10 +144,13 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -242,8 +242,17 @@ void MetalAllocator::free(Buffer buffer) {
}
MetalAllocator& allocator() {
static MetalAllocator allocator_;
return allocator_;
// By creating the |allocator_| on the heap, the destructor of MetalAllocator
// will not be called on exit and all the buffers will be leaked. This is
// necessary because releasing buffers can take more than 30 seconds when the
// program holds a lot of RAM (for example when running inference on an LLM),
// and the process would appear frozen to users while exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
size_t set_cache_limit(size_t limit) {

View File

@@ -0,0 +1,296 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
}
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << ndim;
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel =
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
binary_op_gpu_inplace(inputs, out, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU_MULTI(DivMod)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
break;
}
}
} // namespace mlx::core
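In the General branch above, only the last two collapsed axes map to the x/y grid dimensions; everything else is folded into the z dimension. A quick worked example of that fold (plain arithmetic, not the MTL::Size calls):

// Collapsed shape {5, 7, 3, 4} with out.size() == 420:
//   dim0 = shape[3] = 4        (innermost axis, maps to grid x)
//   dim1 = shape[2] = 3        (maps to grid y)
//   rest = 420 / (4 * 3) = 35  (all leading axes folded into grid z)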

View File

@@ -0,0 +1,33 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s);
} // namespace mlx::core

View File

@@ -4,8 +4,8 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -56,12 +56,15 @@ inline void build_kernel(
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
}
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
@@ -110,13 +113,17 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (auto& x : inputs) {
int nc_in_count = 0;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ";" << std::endl;
os << ");" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
@@ -124,17 +131,20 @@ inline void build_kernel(
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
}
}
@@ -190,7 +200,8 @@ void Compiled::eval_gpu(
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -295,6 +306,7 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
@@ -302,13 +314,17 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
in_strides.insert(
in_strides.end(),
strides[stride_idx].begin(),
strides[stride_idx].end());
stride_idx++;
}
}
if (!in_strides.empty()) {
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);
@@ -336,7 +352,7 @@ void Compiled::eval_gpu(
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +363,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
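The net effect of the compiled-kernel change above is that all non-contiguous inputs share a single strides buffer: input k finds its ndim strides at offset k * ndim. A small sketch of the packing on the host and the lookup the generated source performs (per_input_strides is a hypothetical name, used only to show the indexing):

// Host side: concatenate per-input strides back to back.
std::vector<size_t> in_strides;
for (const auto& s : per_input_strides) {
  in_strides.insert(in_strides.end(), s.begin(), s.end());
}
// Generated kernel side: non-contiguous input k reads dimension d as
//   index_d * in_strides[k * ndim + d]
// instead of index_d * <name>_strides[d] from a dedicated buffer.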

View File

@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
@@ -59,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
@@ -247,7 +248,7 @@ void slow_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -257,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
const int implicit_N = O_per_group;
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
// Determine block and warp tiles
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
@@ -281,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int channel_k_iters = ((C_per_group + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
if (C_per_group <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
n_channel_specialization = C_per_group;
} else if (C_per_group <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
n_channel_specialization = C_per_group;
}
bool small_filter = (!n_channel_specialization) &&
@@ -331,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
out,
bm,
bn,
bk,
wm,
wn,
n_channel_specialization,
small_filter);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -340,7 +355,7 @@ void implicit_gemm_conv_2D_gpu(
size_t grid_dim_x = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
// Encode arrays
compute_encoder.set_input_array(in, 0);
@@ -352,7 +367,7 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -484,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -512,7 +528,7 @@ void implicit_gemm_conv_2D_general_gpu(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -613,7 +629,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -641,7 +657,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -689,7 +705,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
@@ -703,6 +719,7 @@ void conv_2D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
const int groups,
bool flip,
std::vector<array>& copies) {
// Make conv params
@@ -718,12 +735,12 @@ void conv_2D_gpu(
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
@@ -735,6 +752,18 @@ void conv_2D_gpu(
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
@@ -759,6 +788,56 @@ void conv_2D_gpu(
}
}
void conv_3D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
/* const int idil[NDIM] = */
{in_dilation[0], in_dilation[1], in_dilation[2]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0],
in.strides()[1],
in.strides()[2],
in.strides()[3],
in.strides()[4]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0],
wt.strides()[1],
wt.strides()[2],
wt.strides()[3],
wt.strides()[4]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0],
out.strides()[1],
out.strides()[2],
out.strides()[3],
out.strides()[4]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -783,8 +862,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy;
}
// 3D conv
if (out.ndim() == 5) {
conv_3D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 2D conv
if (out.ndim() == 4) {
else if (out.ndim() == 4) {
conv_2D_gpu(
s,
d,
@@ -795,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
copies);
}
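With groups in play, the implicit GEMM above is sized per group: M stays N * oH * oW, while N and K shrink to O / groups and kH * kW * C / groups, and the group index becomes the third grid dimension. A quick numeric check with made-up sizes:

int N = 8, oH = 32, oW = 32, C = 64, O = 128, kH = 3, kW = 3, groups = 4;
int C_per_group = C / groups;            // 16
int O_per_group = O / groups;            // 32
int implicit_M = N * oH * oW;            // 8192
int implicit_N = O_per_group;            // 32
int implicit_K = kH * kW * C_per_group;  // 144
// The grid's z-dimension is `groups`, so each group runs its own
// 8192 x 32 x 144 implicit GEMM.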

View File

@@ -4,12 +4,14 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
@@ -31,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (out.size() == 0) {
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
@@ -55,34 +54,46 @@ void copy_gpu_inplace(
int64_t out_offset,
CopyType ctype,
const Stream& s) {
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "scopy";
break;
case CopyType::Vector:
kname << "vcopy";
break;
case CopyType::General:
kname << "gcopy";
break;
case CopyType::GeneralGeneral:
kname << "ggcopy";
break;
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
kname << type_to_name(in) << type_to_name(out);
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +117,7 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
@@ -126,16 +137,17 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -14,7 +14,6 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -25,9 +24,34 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -37,7 +61,6 @@ auto load_device() {
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
@@ -116,6 +139,76 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@@ -130,9 +223,6 @@ Device::~Device() {
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
(*e.second)->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -169,27 +259,26 @@ void Device::increment_command_buffer_ops(int index) {
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
return bit->second.second;
}
void Device::commit_command_buffer(int index) {
@@ -200,25 +289,15 @@ void Device::commit_command_buffer(int index) {
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
(*eit->second)->endEncoding();
(*eit->second)->release();
encoder_map_.erase(eit);
}
encoder_map_.erase(index);
}
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder =
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
.first;
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
}
return *(eit->second);
}
@@ -232,13 +311,9 @@ void Device::register_library(
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
void Device::register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
@@ -248,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}
@@ -262,7 +337,11 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();
// Throw error if unable to compile library
if (!mtl_lib) {
@@ -344,7 +423,6 @@ MTL::Function* Device::get_function_(
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
@@ -513,11 +591,13 @@ MTL::ComputePipelineState* Device::get_kernel(
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}

View File

@@ -9,36 +9,16 @@
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false) {};
CommandEncoder(MTL::CommandBuffer* cbuf);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -61,41 +41,24 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder();
private:
void maybe_split();
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
@@ -112,7 +75,6 @@ class Device {
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
@@ -123,10 +85,8 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
void register_library(const std::string& lib_name);
MTL::Library* get_library(const std::string& name);

View File

@@ -1,106 +1,803 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
#include <numeric>
#include <set>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
auto& in = inputs[0];
#define MAX_STOCKHAM_FFT_SIZE 4096
#define MAX_RADER_FFT_SIZE 2048
#define MAX_BLUESTEIN_FFT_SIZE 2048
// Threadgroup memory batching improves throughput for small n
#define MIN_THREADGROUP_MEM_SIZE 256
// For strided reads/writes, coalesce at least this many complex64s
#define MIN_COALESCE_WIDTH 4
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
inline const std::vector<int> supported_radices() {
// Ordered by preference in decomposition.
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
}
std::vector<int> prime_factors(int n) {
int z = 2;
std::vector<int> factors;
while (z * z <= n) {
if (n % z == 0) {
factors.push_back(z);
n /= z;
} else {
z++;
}
}
if (n > 1) {
factors.push_back(n);
}
return factors;
}
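prime_factors does plain trial division, so repeated factors appear once per occurrence. For example:

// prime_factors(60)   -> {2, 2, 3, 5}
// prime_factors(13)   -> {13}            (prime, returned as-is)
// prime_factors(4096) -> {2, 2, ..., 2}  (twelve 2s)
auto factors = prime_factors(60);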
struct FourStepParams {
bool required = false;
bool first_step = true;
int n1 = 0;
int n2 = 0;
};
// Forward Declaration
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s);
struct FFTPlan {
int n = 0;
// Number of steps for each radix in the Stockham decomposition
std::vector<int> stockham;
// Number of steps for each radix in the Rader decomposition
std::vector<int> rader;
// Rader factor, 1 if there are no Rader factors
int rader_n = 1;
int bluestein_n = -1;
// Four step FFT
bool four_step = false;
int n1 = 0;
int n2 = 0;
};
int next_fast_n(int n) {
return next_power_of_2(n);
}
std::vector<int> plan_stockham_fft(int n) {
auto radices = supported_radices();
std::vector<int> plan(radices.size(), 0);
int orig_n = n;
if (n == 1) {
return plan;
}
for (int i = 0; i < radices.size(); i++) {
int radix = radices[i];
// Manually tuned radices for powers of 2
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
continue;
}
while (n % radix == 0) {
plan[i] += 1;
n /= radix;
if (n == 1) {
return plan;
}
}
}
throw std::runtime_error("Unplannable");
}
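plan_stockham_fft returns one step count per entry of supported_radices(), ordered {13, 11, 8, 7, 6, 5, 4, 3, 2}. Two hand-traced examples of the loop above (a sanity check rather than documentation):

// plan_stockham_fft(60)  -> {0, 0, 0, 0, 1, 1, 0, 0, 1}   (60 = 6 * 5 * 2)
// plan_stockham_fft(256) -> {0, 0, 0, 0, 0, 0, 4, 0, 0}   (256 = 4^4; radices
//                           above 4 are skipped for small powers of two)
auto steps = plan_stockham_fft(60);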
FFTPlan plan_fft(int n) {
auto radices = supported_radices();
std::set<int> radices_set(radices.begin(), radices.end());
FFTPlan plan;
plan.n = n;
plan.rader = std::vector<int>(radices.size(), 0);
auto factors = prime_factors(n);
int remaining_n = n;
// Four Step FFT when N is too large for shared mem.
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
// For powers of two we have a fast, no-transpose four-step implementation.
plan.four_step = true;
// Rough heuristic for choosing faster powers of two when we can
plan.n2 = n > 65536 ? 1024 : 64;
plan.n1 = n / plan.n2;
return plan;
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
// Otherwise we use a multi-upload Bluestein's
plan.four_step = true;
plan.bluestein_n = next_fast_n(2 * n - 1);
return plan;
}
size_t n = in.shape(axes_[0]);
for (int factor : factors) {
// Make sure the factor is a supported radix
if (radices_set.find(factor) == radices_set.end()) {
// We only support a single Rader factor currently
// TODO(alexbarron) investigate weirdness with large
// Rader sizes -- possibly a compiler issue?
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
if (radices_set.find(rf) == radices_set.end()) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
}
plan.rader = plan_stockham_fft(factor - 1);
plan.rader_n = factor;
remaining_n /= factor;
}
}
plan.stockham = plan_stockham_fft(remaining_n);
return plan;
}
int compute_elems_per_thread(FFTPlan plan) {
// Heuristics for selecting an efficient number
// of threads to use for a particular mixed-radix FFT.
auto n = plan.n;
std::vector<int> steps;
auto radices = supported_radices();
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
std::set<int> used_radices;
for (int i = 0; i < steps.size(); i++) {
int radix = radices[i % radices.size()];
if (steps[i] > 0) {
used_radices.insert(radix);
}
}
// Manual tuning for 7/11/13
if (used_radices.find(7) != used_radices.end() &&
(used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end())) {
return 7;
} else if (
used_radices.find(11) != used_radices.end() &&
used_radices.find(13) != used_radices.end()) {
return 11;
}
// TODO(alexbarron) Some really weird stuff is going on
// for certain `elems_per_thread` on large composite n.
// Possibly a compiler issue?
if (n == 3159)
return 13;
if (n == 3645)
return 5;
if (n == 3969)
return 7;
if (n == 1982)
return 5;
if (used_radices.size() == 1) {
return *(used_radices.begin());
}
if (used_radices.size() == 2 &&
(used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end())) {
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
}
// In all other cases use the second smallest radix.
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// Rader
int mod_exp(int x, int y, int n) {
int out = 1;
while (y) {
if (y & 1) {
out = out * x % n;
}
y >>= 1;
x = x * x % n;
}
return out;
}
int primitive_root(int n) {
auto factors = prime_factors(n - 1);
for (int r = 2; r < n - 1; r++) {
bool found = true;
for (int factor : factors) {
if (mod_exp(r, (n - 1) / factor, n) == 1) {
found = false;
break;
}
}
if (found) {
return r;
}
}
return -1;
}
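// Illustrative standalone sketch (not part of this diff): the same
// square-and-multiply modular exponentiation and primitive-root idea as
// above, checked on the small prime p = 7. Assumes only the C++
// standard library.
#include <cassert>
#include <cstdio>

static int mod_exp_demo(int x, int y, int n) {
  int out = 1;
  while (y) {
    if (y & 1) {
      out = out * x % n;
    }
    y >>= 1;
    x = x * x % n;
  }
  return out;
}

int main() {
  // For p = 7, the powers of 3 are 3, 2, 6, 4, 5, 1: they cover all of
  // {1, ..., 6}, so 3 is a primitive root mod 7 (2 is not, since 2^3 = 1).
  assert(mod_exp_demo(2, 3, 7) == 1);
  for (int i = 1; i <= 6; i++) {
    std::printf("3^%d mod 7 = %d\n", i, mod_exp_demo(3, i, 7));
  }
  return 0;
}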
std::tuple<array, array, array> compute_raders_constants(
int rader_n,
const Stream& s) {
int proot = primitive_root(rader_n);
// Fermat's little theorem
int inv = mod_exp(proot, rader_n - 2, rader_n);
std::vector<short> g_q(rader_n - 1);
std::vector<short> g_minus_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
g_q[i] = mod_exp(proot, i, rader_n);
g_minus_q[i] = mod_exp(inv, i, rader_n);
}
array g_q_arr(g_q.begin(), {rader_n - 1});
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
std::vector<std::complex<float>> b_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
b_q[i] = std::exp(std::complex<float>(0, pi_i));
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
size_t fft_size = rader_n - 1;
// This FFT is always small (<4096, batch 1) so save some overhead
// and do it on the CPU
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ b_q.data(),
/* data_out= */ b_q_fft_ptr,
/* scale= */ 1.0f);
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
}
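// Illustrative standalone sketch (not part of this diff): the Rader
// index permutations g_q and g_minus_q computed above, written out for
// the small prime rader_n = 5 (primitive root 2, inverse root 3).
// Assumes only the C++ standard library.
#include <cstdio>

int main() {
  const int p = 5, root = 2, inv_root = 3; // 2 * 3 = 6 = 1 (mod 5)
  int g_q = 1, g_minus_q = 1;
  for (int i = 0; i < p - 1; i++) {
    std::printf("i=%d  g_q=%d  g_minus_q=%d\n", i, g_q, g_minus_q);
    g_q = g_q * root % p;                 // 1, 2, 4, 3
    g_minus_q = g_minus_q * inv_root % p; // 1, 3, 4, 2
  }
  return 0;
}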
// Bluestein
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// We need to calculate the Bluestein twiddle factors
// in double precision for the overall numerical stability
// of Bluestein's FFT algorithm to be acceptable.
//
// Metal doesn't support float64, so instead we
// manually implement the required operations on cpu.
//
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
for (int i = -n + 1; i < n; i++) {
double theta = pow(i, 2) * M_PI / (double)n;
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
if (i >= 0) {
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
}
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
std::ptrdiff_t item_size = w_q.itemsize();
size_t fft_size = bluestein_n;
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ w_q_vec.data(),
/* data_out= */ w_q_ptr,
/* scale= */ 1.0f);
return std::make_pair(w_k, w_q);
}
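// Illustrative standalone sketch (not part of this diff): the chirp
// identity behind Bluestein's algorithm, exp(-2*pi*i*j*k/n) =
// w_j * w_k * conj(w_{j-k}) with w_m = exp(-i*pi*m^2/n), checked
// numerically. This is why only w_k and the FFT of its reciprocal (w_q)
// are needed above. Assumes only the C++ standard library.
#include <cassert>
#include <cmath>
#include <complex>

int main() {
  const int n = 7;
  auto w = [&](int m) {
    return std::exp(std::complex<double>(0.0, -M_PI * m * m / n));
  };
  for (int j = 0; j < n; j++) {
    for (int k = 0; k < n; k++) {
      std::complex<double> lhs =
          std::exp(std::complex<double>(0.0, -2.0 * M_PI * j * k / n));
      std::complex<double> rhs = w(j) * w(k) * std::conj(w(j - k));
      assert(std::abs(lhs - rhs) < 1e-12);
    }
  }
  return 0;
}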
void multi_upload_bluestein_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array>& copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for multi-upload Bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
array temp1(temp_shape, complex64, nullptr, {});
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
} else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s);
} else {
temp.copy_shared_buffer(in);
}
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
pad_temp,
pad_temp1,
axis,
/*inverse=*/false,
/*real=*/false,
FourStepParams(),
/*inplace=*/false,
s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
pad_temp,
pad_temp1,
axis,
/* inverse= */ true,
/* real= */ false,
FourStepParams(),
/*inplace=*/true,
s);
int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array>& copies,
const Stream& s) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
fft_op(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
}
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s) {
auto& d = metal::device(s.device);
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
if (n == 1) {
out.copy_shared_buffer(in);
return;
}
if (four_step_params.required) {
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [&axis, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axis] == 1 &&
(x.flags().row_contiguous || x.flags().col_contiguous);
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(a);
}
}
auto flags = x.flags();
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(x.shape(), strides);
flags.col_contiguous = is_row_contiguous;
flags.row_contiguous = is_col_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(in);
// real to complex: n -> (n/2)+1
// complex to real: (n/2)+1 -> n
auto out_strides = in_contiguous.strides();
size_t out_data_size = in_contiguous.data_size();
if (in.shape(axis) != out.shape(axis)) {
for (int i = 0; i < out_strides.size(); i++) {
if (out_strides[i] != 1) {
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
}
}
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
}
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());
}
auto radices = supported_radices();
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
// Setup function constants
bool power_of_2 = is_power_of_2(fft_size);
auto make_int = [](int* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
};
auto make_bool = [](bool* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
};
std::vector<MTLFC> func_consts = {
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
// Start of radix/rader step constants
int index = 4;
for (int i = 0; i < plan.stockham.size(); i++) {
func_consts.push_back(make_int(&plan.stockham[i], index));
index += 1;
}
for (int i = 0; i < plan.rader.size(); i++) {
func_consts.push_back(make_int(&plan.rader[i], index));
index += 1;
}
int elems_per_thread = compute_elems_per_thread(plan);
func_consts.push_back(make_int(&elems_per_thread, 2));
int rader_m = n / plan.rader_n;
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
int total_batch_size = size / n;
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
// We batch among threadgroups for improved efficiency when n is small
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
if (four_step_params.required) {
// Require a threadgroup batch size of at least 4 for four step FFT
// so we can coalesce the memory accesses.
threadgroup_batch_size =
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
}
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
// FFTs up to 2^20 are currently supported
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
// ceil divide
int batch_size =
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
if (real && !four_step_params.required) {
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
// Only required by four step
int step = -1;
{
std::ostringstream kname;
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
std::string inv_string = inverse ? "true" : "false";
std::string real_string = real ? "true" : "false";
std::string func_name;
if (plan.bluestein_n > 0) {
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
<< in_type_str << "_" << out_type_str;
func_name = "bluestein_fft";
} else if (plan.rader_n > 1) {
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str;
func_name = "rader_fft";
} else if (four_step_params.required) {
step = four_step_params.first_step ? 0 : 1;
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str << "_" << step << "_" << real_string;
func_name = "four_step_fft";
} else {
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
<< out_type_str;
func_name = "fft";
}
std::string base_name = kname.str();
// We use a specialized kernel for each FFT size
kname << "_n" << fft_size << "_inv_" << inverse;
std::string hash_name = kname.str();
auto template_def = func_name == "four_step_fft" ? get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real)
: get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str);
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
if (plan.bluestein_n > 0) {
// Precomputed twiddle factors for Bluestein's
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_q);
copies.push_back(w_k);
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
copies.push_back(g_q);
copies.push_back(g_minus_q);
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
bool inplace,
const Stream& s) {
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
}
void nd_fft_op(
const array& in,
array& out,
const std::vector<size_t>& axes,
bool inverse,
bool real,
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2};
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
bool inplace = reverse_index >= 3 && i != 0;
// Opposite order for fft vs ifft
int index = inverse ? reverse_index : i;
size_t axis = axes[index];
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
}
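// Illustrative standalone sketch (not part of this diff): the buffer
// ping-pong used by nd_fft_op() above, printed for a forward 3D FFT
// (axes processed last-to-first, temp1/temp2 alternating, final step
// writing to out). Assumes only the C++ standard library.
#include <cstdio>

int main() {
  const int naxes = 3;
  const char* names[2] = {"temp1", "temp2"};
  for (int i = naxes - 1; i >= 0; i--) {
    const char* src = (i == naxes - 1) ? "in" : names[1 - i % 2];
    const char* dst = (i == 0) ? "out" : names[i % 2];
    std::printf("step on axes[%d]: %s -> %s\n", i, src, dst);
  }
  // Prints: in -> temp1, temp1 -> temp2, temp2 -> out.
  return 0;
}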
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
if (axes_.size() > 1) {
nd_fft_op(in, out, axes_, inverse_, real_, s);
} else {
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,203 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate an O(m^2) hadamard codelet for a given m
// using the hadamard matrices above
//
// e.g. m = 2
// METAL_FUNC void hadamard_m(thread float *x) {
// float tmp[2];
// tmp[0] = + x[0] + x[1];
// tmp[1] = + x[0] - x[1];
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
// }
//
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
std::ostringstream source;
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
if (m == 1) {
source << "}" << std::endl;
return source.str();
}
source << " float tmp[" << m << "];" << std::endl;
auto start = 1;
auto end = matrix.find('\n', start);
int index = 0;
while (end != std::string_view::npos) {
source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
source << " " << row[i] << " x[" << i << "]";
}
source << ";" << std::endl;
start = end + 1;
end = matrix.find('\n', start);
index++;
}
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
<< std::endl;
source << "}" << std::endl;
return source.str();
}
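// Illustrative standalone sketch (not part of this diff): the same
// sign-matrix parsing as gen_hadamard_codelet() above, run on a
// hard-coded 2x2 matrix string ("++\n+-\n") to show the emitted lines.
// Assumes only the C++ standard library.
#include <cstdio>
#include <string>

int main() {
  const std::string matrix = "++\n+-\n";
  int index = 0;
  size_t start = 0, end = matrix.find('\n', start);
  while (end != std::string::npos) {
    std::string row = matrix.substr(start, end - start);
    std::printf("tmp[%d] =", index);
    for (size_t i = 0; i < row.size(); i++) {
      std::printf(" %c x[%zu]", row[i], i);
    }
    std::printf(";\n"); // prints: tmp[0] = + x[0] + x[1]; etc.
    start = end + 1;
    end = matrix.find('\n', start);
    index++;
  }
  return 0;
}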
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
int batch_size = in.size() / n;
int threads_per = n / max_radix;
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@@ -1,24 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include <fmt/format.h>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
std::pair<std::string, std::string> make_index_args(
const std::string& idx_type,
int nidx) {
std::ostringstream idx_args;
std::ostringstream idx_arr;
for (int i = 0; i < nidx; ++i) {
idx_args << fmt::format(
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
idx_arr << fmt::format("idx{0}", i);
if (i < nidx - 1) {
idx_args << "\n";
idx_arr << ",";
}
}
return {idx_args.str(), idx_arr.str()};
}
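// Illustrative standalone sketch (not part of this diff): what
// make_index_args() above produces for nidx = 2 and idx_type "uint32_t",
// rebuilt with plain ostringstreams instead of fmt. Assumes only the
// C++ standard library.
#include <cstdio>
#include <sstream>
#include <string>

int main() {
  const std::string idx_type = "uint32_t";
  const int nidx = 2;
  std::ostringstream idx_args, idx_arr;
  for (int i = 0; i < nidx; ++i) {
    idx_args << "const device " << idx_type << " *idx" << i << " [[buffer("
             << 20 + i << ")]],";
    idx_arr << "idx" << i;
    if (i < nidx - 1) {
      idx_args << "\n";
      idx_arr << ",";
    }
  }
  // idx_args:
  //   const device uint32_t *idx0 [[buffer(20)]],
  //   const device uint32_t *idx1 [[buffer(21)]],
  // idx_arr: idx0,idx1
  std::printf("%s\n%s\n", idx_args.str().c_str(), idx_arr.str().c_str());
  return 0;
}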
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source << fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
idx_type_str,
nidx,
idx_args,
idx_arr,
idx_ndim);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
@@ -102,12 +139,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
// Get kernel name
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
case Scatter::None:
op_name = "none";
break;
case Scatter::Sum:
op_name = "sum";
break;
case Scatter::Prod:
op_name = "prod";
break;
case Scatter::Max:
op_name = "max";
break;
case Scatter::Min:
op_name = "min";
break;
}
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
std::string op_type;
switch (reduce_type_) {
case Scatter::None:
op_type = "None";
break;
case Scatter::Sum:
op_type = "Sum<{0}>";
break;
case Scatter::Prod:
op_type = "Prod<{0}>";
break;
case Scatter::Max:
op_type = "Max<{0}>";
break;
case Scatter::Min:
op_type = "Min<{0}>";
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
idx_type_str,
op_type,
nidx,
idx_args,
idx_arr);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
@@ -206,17 +293,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't complain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
@@ -279,14 +377,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -0,0 +1,38 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();
const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
const char* gemv_masked();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,93 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
src_strides,
src_ndim,
slice_sizes,
axes,
idxs,
index,
grid_dim);
}}
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}}
[[kernel]] void scatter{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
updates,
out,
upd_shape,
upd_strides,
upd_ndim,
upd_size,
out_shape,
out_strides,
out_ndim,
axes,
idxs,
gid);
}}
)";

View File

@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
const device {itype} *A [[buffer(0)]],
const device {itype} *B [[buffer(1)]],
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
device {itype} *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
{itype},
{outmasktype},
{opmasktype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
{itype},
{otype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {otype}* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device {otype}* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,642 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
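// Every getter below follows the same JIT pattern: look up the library by
// name in the device cache; if it is missing, assemble the Metal source from
// the shared headers plus the formatted template instantiations, compile it
// once, and then fetch the requested pipeline state from it.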
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
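// Unary kernels: one library per (output type, op), holding the contiguous
// ("v"), large-array ("v2"), and general strided ("g") instantiations.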
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << u2_def << g_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
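// Emit a template instantiation for each binary variant in the map below:
// the scalar/vector combinations, their "2" variants, and the
// dimension-specialized (1-5 D) plus general N-D strided kernels.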
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
kernel_source << template_def;
}
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
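// Same variant set, but built against the binary_two kernels for ops that
// produce two outputs.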
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
kernel_source << template_def;
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
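// Capitalize the reduce type to form the scan op name, e.g. "sum" -> "CumSum".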
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::sort();
for (bool is_argsort : {true, false}) {
std::string bool_string = is_argsort ? "true" : "false";
std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
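// Multi-block sort: each library bundles the block sort, partition, and
// merge stages for the given (value, index) type pair.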
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
"true",
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
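// int64/uint64 outputs cannot be reduced with Metal atomics, so use the
// non-atomic kernel variants for them.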
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
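// Steel GEMM: tile sizes (bm, bn, bk), warp tiling (wm, wn), and transpose
// flags are baked into the instantiation, so each configuration gets its own
// library; the fused kernel is further specialized at pipeline-creation time
// via the function constants in func_consts.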
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
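// When a mask is absent, the template is instantiated with the sentinel type
// nomask_t, so masked and unmasked paths share one kernel template.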
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
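// FFT pipelines take a precomputed template definition and are specialized
// with Metal function constants; hash_name keys the specialized pipeline.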
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core
