Compare commits

...

277 Commits

Author SHA1 Message Date
Alex Barron
82a956c1d9 fix test 2024-12-06 10:26:54 -08:00
Alex Barron
769704653a cpu fallback 2024-12-06 01:22:50 -08:00
Alex Barron
c89ddf62b4 add checks 2024-12-06 01:09:00 -08:00
Alex Barron
3507c104a5 add test 2024-12-06 00:45:01 -08:00
Alex Barron
12a4d89a7c working qsdpa 2024-12-06 00:21:05 -08:00
Awni Hannun
e047fd977d compile changes if stream changes (#1644) 2024-12-03 14:37:44 -08:00
Jagrit Digani
9d40e521d7 Stop matrix copies with new attention kernel (#1639) 2024-12-02 14:12:38 -08:00
Alex Barron
1445dcaa60 let class predicate specify quantization parameters (#1638) 2024-12-02 14:09:28 -08:00
Jesper Stemann Andersen
e4eeb4e910 Added missing unordered_map includes (#1635)
* Added missing includes in mlx/io.h and mlx/backend/metal/metal.h

* Added additional missing unordered_map includes that fixes build on FreeBSD
2024-12-02 07:03:03 -08:00
Awni Hannun
aa86876813 fix transformer decoder post norm LN (#1637) 2024-12-02 07:02:17 -08:00
Jesper Stemann Andersen
974bb54ab2 CMake: Enabled using Accelerate on x86_64 / x64 (#1625)
* CMake: Enabled using Accelerate on x86_64 / x64

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761

* CMake: Removed superfluous MLX_BUILD_ARM
2024-11-28 10:55:45 -08:00
Ikko Eltociear Ashimine
9bc2183a31 docs: update device.cpp (#1632)
unecessary -> unnecessary
2024-11-27 20:58:26 -08:00
Awni Hannun
d4b222b6d3 Fix some leaks and races (#1629)
* fix leak and fix potential race

* more leak fixes

* fix one more
2024-11-27 20:01:20 -08:00
Jesper Stemann Andersen
af2af818a6 Enables build for *-linux-musl (#1627)
Also contributes to being able to build for *-w64-mingw32.

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:24 -08:00
Jesper Stemann Andersen
698e63a608 CMake: Build with dlfcn-win32 to have dlopen etc. on win32 (#1628)
Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:13 -08:00
Awni Hannun
211411faf2 fix large ops (#1620) 2024-11-24 09:17:10 -08:00
Awni Hannun
bb303c45a5 version (#1617) 2024-11-22 12:00:03 -08:00
Alex Barron
6f7986d592 Cleaner qmv/qvm (#1616) 2024-11-22 11:14:08 -08:00
Awni Hannun
7cbb4aef17 Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Jagrit Digani
02bec0bb6d Matrix Attention kernel (#1610)
* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Alex Barron
c79f6a4a8c 3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
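A minimal sketch of what this adds, assuming the standard `mx.quantize`/`mx.dequantize` signatures; only the new `bits` values come from this commit:

    import mlx.core as mx

    w = mx.random.normal((512, 512))
    # 3 bits per weight (6 is also newly supported), 64 weights per scale/bias group
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=3)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=3)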
Awni Hannun
0c5eea226b Reduce specializations (#1607)
* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
2024-11-21 19:53:00 -08:00
Awni Hannun
dcca0d7477 contiguous op / prim (#1612) 2024-11-21 19:51:49 -08:00
Cocoa
0d5e7716ad fix typo: accross -> across (#1609)
Signed-off-by: Cocoa <i@uwucocoa.moe>
2024-11-20 15:30:51 -08:00
Angelos Katharopoulos
d8c824c594 Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
Awni Hannun
61d787726a Fix view scalar bug segfault (#1603)
* fix view scalar bug

* fix view scalar bug

* one more fix
2024-11-19 10:54:05 -08:00
Angelos Katharopoulos
5e89aace9b Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
2af7e8a9a6 fix cmake version (#1601) 2024-11-19 08:45:05 -08:00
Awni Hannun
2419edd5b2 Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
2024-11-18 19:52:00 -08:00
Awni Hannun
bf481e8e5d Fix sibling leak (#1590)
* add test

* fix + test

* fix fix
2024-11-18 19:17:01 -08:00
Awni Hannun
9d7fa6b8e6 Use osx deployment target to pick Metal version (#1595)
* choose metal based on deployment target rather than system version

* nit

* unused compile def
2024-11-18 19:16:49 -08:00
Angelos Katharopoulos
073076ac7d 2-Pass Sdpa Inference Kernel (#1597) 2024-11-18 17:31:53 -08:00
Awni Hannun
9bd03dd9b4 More buffer donation with no-ops (#1591)
* more donation

* fix test

* fix build
2024-11-18 08:35:41 -08:00
Awni Hannun
6931f84412 fix dispatch threads for a few kernels (#1594) 2024-11-18 08:35:25 -08:00
xnorai
16ec0556a0 Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)
* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB

* Set the upper size limit for the header to 100K as in Rust safetensors
2024-11-18 07:22:51 -08:00
Awni Hannun
610af352d4 Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT

* fix extension

* fix extension build

* fix extension build

* Update utils.h
2024-11-15 16:54:36 -08:00
Awni Hannun
b35f1e3c9c fix donation in sdpa (#1587) 2024-11-13 17:21:13 -08:00
Awni Hannun
dfa0b9aab4 Cpu fast quantize (#1578)
* cpu quantize

* fix
2024-11-08 20:10:39 -08:00
Alex Barron
a4c47b0276 OOB QMV fix (#1579)
* fix oob access in qmv

* skip more

* fix small case
2024-11-08 17:59:45 -08:00
Alex Barron
111fefd5e9 Fix OOB access in qmv (#1577)
* fix oob access in qmv

* skip more
2024-11-08 15:41:30 -08:00
Awni Hannun
c1fe1ef081 Bfs width limit (#1568)
* width limit

* fix

* large limit

* put env vars in env namespace
2024-11-08 15:00:46 -08:00
Awni Hannun
8c34c9dac4 throw for invalid case and remove test (#1575) 2024-11-08 12:04:03 -08:00
Awni Hannun
91c0277356 fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
Awni Hannun
9f0d5c12fc Fully wrap the command encoder (#1572)
* fully wrap the command encoder

* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Awni Hannun
59247c2b62 add groups in conv2d (#1569) 2024-11-07 13:57:53 -08:00
Awni Hannun
9a3842a2d9 fix (#1566) 2024-11-06 17:10:33 -08:00
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
Alex Barron
26be608470 Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c Reductions update (#1351) 2024-11-04 22:25:16 -08:00
Awni Hannun
76f275b4df error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Awni Hannun
f1951d6cce Use fewer barriers (#1561)
* use fewer barriers

* comment
2024-11-04 10:26:49 -08:00
Angelos Katharopoulos
62f297b51d Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Chris Offner
46d8b16ab4 Fix vmap example in docs (#1556) 2024-11-02 17:44:14 -07:00
Chris Offner
42533931fa Fix typo "it's" -> "its" (#1555) 2024-11-02 06:06:34 -07:00
Awni Hannun
9bd3a7102f add python 3.13 to circle (#1553) 2024-11-01 20:55:35 -07:00
Alex Barron
9e516b71ea Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
eac961ddb1 patch (#1550) 2024-10-31 16:10:14 -07:00
Awni Hannun
57c6aa7188 fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
Awni Hannun
cde5b4ad80 patch (#1546) 2024-10-30 19:31:22 -07:00
Awni Hannun
4f72c66911 improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Jagrit Digani
960e3f0f05 Gemm update (#1518) 2024-10-30 19:30:28 -07:00
Awni Hannun
884af42da2 Fix thread group for large arrays (#1543)
* fix thread group for large arrays

* comment

* one more
2024-10-30 16:25:12 -07:00
Alex Barron
048fabdabd Fix vmap constant output size (#1524)
* use inputs to determine output size

* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Léo
917252a5a1 Add favicon to docs (#1545)
* add sphinx's html_favicon config

* removed unneeded newline

* ran pre-commit hooks
2024-10-30 13:54:13 -07:00
Carlo Cabrera
1a992e31e8 Skip using Residency sets in VMs (#1537)
* Skip using Residency sets in VMs

Attempting to use residency sets in a VM throws[^1]

    libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to construct residency set.

Not quite sure if this is the best fix, but it does make the error go
away.

Note that it was previously possible to run simple programs that used
mlx in a VM prior to 0eb56d5be0. See
related discussion at Homebrew/homebrew-core#195627.

[^1]: https://github.com/Homebrew/homebrew-core/actions/runs/11525831492/job/32105148462#step:3:56

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* change residency check

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-29 19:37:23 -07:00
Awni Hannun
d2ff04a4f2 fix format (#1539) 2024-10-28 18:29:14 -07:00
Awni Hannun
015c247393 change wino dispatch conditoin (#1534) 2024-10-28 11:13:44 -07:00
Awni Hannun
d3cd26820e Faster bits and bernoulli (#1535)
* faster bits and bernoulli

* fix bernoulli
2024-10-28 11:11:00 -07:00
Awni Hannun
91f6c499d7 fix (#1529) 2024-10-25 19:25:35 -07:00
Awni Hannun
35e9c87ab9 patch bump (#1528) 2024-10-25 13:13:23 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
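With `__format__` in place, scalar arrays should work directly with Python format specifications; a small sketch:

    import mlx.core as mx

    x = mx.array(3.14159)
    # The format spec is forwarded to the underlying Python scalar
    print(f"{x:.2f}")  # 3.14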
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
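A short usage sketch of the new decompositions; these run on the CPU, so the explicit `stream=mx.cpu` below reflects that assumption:

    import mlx.core as mx

    a = mx.array([[2.0, 1.0], [1.0, 2.0]])
    w = mx.linalg.eigvalsh(a, stream=mx.cpu)  # eigenvalues only
    w, v = mx.linalg.eigh(a, stream=mx.cpu)   # eigenvalues and eigenvectors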
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
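A minimal sketch, assuming the new ops are exposed as `mx.real` and `mx.imag`:

    import mlx.core as mx

    z = mx.array([1 + 2j, 3 - 4j])
    print(mx.real(z))  # [1, 3]
    print(mx.imag(z))  # [2, -4]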
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
Awni Hannun
c21331d47f version bump (#1477) 2024-10-10 13:05:17 -07:00
Awni Hannun
e1c9600da3 Add mx.random.permutation (#1471)
* random permutation

* comment
2024-10-08 19:42:19 -07:00
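A sketch of the new op, assuming the NumPy-style convention that an integer argument permutes `arange(n)` and that an `axis` argument selects the shuffled dimension:

    import mlx.core as mx

    idx = mx.random.permutation(10)          # a shuffled 0..9
    x = mx.arange(20).reshape(4, 5)
    rows = mx.random.permutation(x, axis=0)  # shuffle the rows of an array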
Awni Hannun
1fa0d20a30 consistently handle all -inf in softmax (#1470) 2024-10-08 09:54:02 -07:00
Awni Hannun
3274c6a087 Fix array is_available race cases (#1468) 2024-10-07 19:13:50 -07:00
Angelos Katharopoulos
9b12093739 Add the roll op (#1455) 2024-10-07 17:21:42 -07:00
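A minimal example of the new op, which follows NumPy's circular-shift semantics:

    import mlx.core as mx

    x = mx.arange(5)
    print(mx.roll(x, 2))          # [3, 4, 0, 1, 2]
    y = mx.arange(6).reshape(2, 3)
    print(mx.roll(y, 1, axis=1))  # each row rotated right by one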
Awni Hannun
f374b6ca4d Bump nanobind to 2.2 (#1461)
* bump nanobind

* extension version for tests
2024-10-07 16:52:40 -07:00
Awni Hannun
0070e1db40 Fix deep recursion with siblings (#1462)
* fix recursion with siblings

* fix

* add test

* increase tol
2024-10-07 06:15:33 -07:00
Awni Hannun
95d04805b3 Fix complex power on Metal (#1460) 2024-10-06 19:58:30 -07:00
Awni Hannun
e4534dac17 Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
2024-10-06 07:08:53 -07:00
Angelos Katharopoulos
fef3c4ec1d Fix mpi test in CI (#1456)
* Fix mpi test in CI

* Set bind to none
2024-10-06 06:09:17 -07:00
Awni Hannun
1bdc038bf9 fix argpartition + faster {arg} sorts / partitions (#1453) 2024-10-03 14:21:25 -07:00
Awni Hannun
5523d9c426 faster cpu indexing (#1450) 2024-10-03 13:53:47 -07:00
Angelos Katharopoulos
d878015228 Fix normalization check_input (#1452) 2024-10-03 13:26:56 -07:00
Cheng
5900e3249f Fix building on Linux (#1446) 2024-09-30 07:00:39 -07:00
Angelos Katharopoulos
bacced53d3 Fix row reduce with very few rows (#1447) 2024-09-29 20:00:35 -07:00
Lucas Newman
4a64d4bff1 Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-28 06:41:07 -07:00
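A sketch of the layer-level API this enables; the `groups` argument is the new piece, and MLX convolutions take channels-last input `(N, L, C_in)`:

    import mlx.core as mx
    import mlx.nn as nn

    # 16 input channels split into 4 groups; channel counts must divide evenly
    conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, groups=4)
    x = mx.random.normal((8, 100, 16))  # (batch, length, channels)
    y = conv(x)                         # (8, 98, 32)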
Awni Hannun
b1e2b53c2d bump (#1445) 2024-09-27 13:53:02 -07:00
Awni Hannun
11354d5bff Avoid io timeout for large arrays (#1442) 2024-09-27 13:32:14 -07:00
Awni Hannun
718aea3f1d allow take to work with integer index (#1440) 2024-09-26 15:58:03 -07:00
Awni Hannun
5b6f38df2b Faster cpu ops (#1434)
* faster binary and cleaner copy

* use recursive template for other ops

* more cleanup

* fix from cleanup

* more clean

* fix binary

* use contiguous iterator

* add 3d

* nits

* fix

* fix?

* fix

* fix rebase
2024-09-26 09:19:13 -07:00
Awni Hannun
0b4a58699e Some overhead reductions in mx.fast.metal_kernel (#1437)
* some overhead reductions

* fix

* use +=

* use more +=
2024-09-25 17:25:21 -07:00
Awni Hannun
4f9f9ebb6f Faster Metal unary and binary for general case (#1431)
* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
2024-09-25 12:07:43 -07:00
Awni Hannun
afc9c0ec1b dtype is copy assignable (#1436) 2024-09-25 12:07:13 -07:00
Awni Hannun
195b429d99 Put along axis + fixe for partition grad (#1430)
* put along axis, fixes for partition grad

* zeros for arg reduce
2024-09-23 10:03:38 -07:00
Luke Carlson
2b878e9dd7 Create CITATION.cff (#1425) 2024-09-20 11:39:46 -07:00
Awni Hannun
67b6bf530d Optimization for general ND copies (#1421) 2024-09-17 17:59:51 -07:00
Nripesh Niketan
6af5ca35b2 feat: add cross_product (#1252)
* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-17 13:12:43 -07:00
Awni Hannun
4f46e9c997 More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
2024-09-17 12:46:31 -07:00
Awni Hannun
c6739ba7f3 Faster RNN layers (#1419)
* faster rnn

* use admm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9 Data parallel helper (#1407) 2024-09-16 18:17:21 -07:00
jjuang-apple
8d68a3e805 remove fmt dependencies from MLX install (#1417) 2024-09-16 13:32:28 -07:00
jjuang-apple
6bbcc453ef avoid using find_library to make install truly portable (#1416) 2024-09-16 13:21:32 -07:00
Awni Hannun
d5ed4d7a71 override class function (#1418) 2024-09-16 13:21:04 -07:00
Nripesh Niketan
669c27140d Chore: add pre-commit hook for cmake (#1362)
* reset and lint

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-16 12:53:01 -07:00
Max-Heinrich Laves
adcc88e208 Conv cpu improvements (#1410) 2024-09-15 18:45:10 -07:00
Awni Hannun
d6492b0163 fix clip (#1415) 2024-09-14 16:09:09 -07:00
Awni Hannun
b3f52c9fbe ensure io/comm streams are active before eval (#1412) 2024-09-14 06:17:36 -07:00
c0g
bd8396fad8 Fix typo in transformer docs (#1414) 2024-09-14 06:05:15 -07:00
Angelos Katharopoulos
d0c58841d1 Patch bump (#1408) 2024-09-12 16:44:23 -07:00
Angelos Katharopoulos
881f09b2e2 Allow querying the allocator for the buffer size (#1404) 2024-09-11 21:02:16 -07:00
Awni Hannun
8b30acd7eb fix module attribute set, reset, set (#1403) 2024-09-11 16:30:42 -07:00
Awni Hannun
02efb310ca Xcode 160 (#1384)
* xcode 16.0 with debug tests

* limit nproc for builds

* vmap bug

* assert bug

* run python tests in debug mode

* fix view, bool copies preserve bits'

* actual view fix
2024-09-10 15:15:17 -07:00
Awni Hannun
e7e59c6f05 Fix copying scalars by adding fill_gpu (#1402)
* fix copying scalars by adding fill_gpu

* Another copy scalar changed to fill

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-09-09 15:54:08 -07:00
Awni Hannun
3ae6aabe9f throw for certain cases of non captured inputs in compile (#1401) 2024-09-09 14:54:31 -07:00
xnorai
dc627dcb5e Replace the use of result_of_t with invoke_result_t (#1397)
* Fix C++20 incompatibility

* Fix C++20 incompatibility
2024-09-06 19:52:57 -07:00
Max-Heinrich Laves
efeb9c0f02 Transposed Convolution (#1245)
* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Awni Hannun
ba3e913c7a Simplifications for MLX C (#1396)
* simplifications for MLX C

* use vectors instead of map

* update examples
2024-09-06 19:16:50 -07:00
Awni Hannun
7cca1727af Fix slice data size (#1394)
* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
2024-09-04 19:10:43 -07:00
Bhargav Yagnik
11371fe251 Test to prevent bugs like #1386 (#1391)
* updated test_array for missing ops

* formatting changes
2024-09-04 17:24:30 -07:00
Awni Hannun
41c603d48a fix jit reduce (#1395) 2024-09-04 14:03:10 -07:00
Angelos Katharopoulos
969337345f Fix reduce edge case (#1389) 2024-09-01 21:37:51 -07:00
Awni Hannun
9592766939 add std as method (#1387)
* add std as method

* add std as method
2024-09-01 19:49:16 -07:00
Angelos Katharopoulos
58dca7d846 Fix copy in the sort primitive (#1383) 2024-08-31 08:32:14 -07:00
Awni Hannun
0d302cd25b Fix compiel with byte sized constants (#1381) 2024-08-30 17:24:35 -07:00
Alex Barron
da691257ec Fix overflow in quantize/dequantize (#1379)
* add 2d indices to prevent overflow

* use nthreads not out size
2024-08-30 13:32:41 -07:00
Angelos Katharopoulos
1600092e92 Patch bump (#1376) 2024-08-29 16:54:30 -07:00
Awni Hannun
dba2bd1105 Even Even Faster IO (#1374)
* even more faster io

* make reader pool static

* make python reader thread safe

* one more optimization
2024-08-29 16:05:40 -07:00
Alex Barron
28be4de7c2 Fix JIT reductions (#1373) 2024-08-28 16:39:11 -07:00
Awni Hannun
a6c3b38fba Async load (#1372)
* async load

* async load
2024-08-28 14:21:55 -07:00
Awni Hannun
fcb65a3897 Even Faster I/O (#1369)
* try multithreading for faster IO

* smaller batch size

* Account for pread returning less than size

* nit

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-28 11:49:07 -07:00
Saanidhya
4e22a1dffe In continuation to PR1243 to solve issue #1240 (#1365)
* Solves issue #1240

* Correction

* Update python/mlx/utils.py

* Update python/mlx/utils.py

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-08-28 11:40:41 -07:00
Awni Hannun
291cf40aca Some fixes to typing (#1371)
* some fixes to typing

* fix module reference

* comment
2024-08-28 11:16:19 -07:00
Jeethu Rao
bd47e1f066 Fix neon_fast_exp and add more softmax tests (#1367) 2024-08-27 23:42:42 -07:00
Aditya Dhulipala
e6b223df5f Pinv (#875) 2024-08-27 23:06:12 -07:00
Angelos Katharopoulos
e64349bbdd Make eval just wait if all arrays are scheduled (#1368) 2024-08-27 17:01:22 -07:00
Angelos Katharopoulos
cdb59faea6 Adds send/recv ops in distributed (#1366) 2024-08-26 23:01:37 -07:00
Alex Barron
1d94ac3f90 Add optional headers to `mx.fast.metal_kernel` (#1358) 2024-08-26 21:45:45 -07:00
Awni Hannun
5f7d19d1f5 MPI ops in GPU stream for faster comms (#1356) 2024-08-26 15:12:50 -07:00
Awni Hannun
2fdf9eb535 Fix ternary for large arrays (#1359)
* fix ternary for large arrays

* fix
2024-08-26 11:22:27 -07:00
Awni Hannun
860d3a50d7 fix extension metal library finding (#1361) 2024-08-26 09:18:50 -07:00
Alex Barron
d1183821a7 int() and float() for mx.array (#1360) 2024-08-25 20:41:44 -07:00
Angelos Katharopoulos
8081df79be Fix boolean all reduce bug (#1355) 2024-08-24 10:09:32 -07:00
Nripesh Niketan
64bec4fad7 Chore: update pre-commit hooks (#1353)
* Chore: update pre-commit refs

* run pre-commit
2024-08-24 06:46:36 -07:00
Alex Barron
b96e105244 Add grid_sample example to metal_kernel docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel`

* add grid sample to docs

* zero_outputs -> init_value

* add missing header for linux
2024-08-23 18:24:16 -07:00
Awni Hannun
3b4d5484c7 Bump extension MLX version (#1350)
* Bump extension MLX version

* fix some docs nits
2024-08-23 12:38:34 -07:00
Alex Barron
684e11c664 patch (#1347) 2024-08-23 10:42:02 -07:00
Angelos Katharopoulos
b57a52813b Further reduction tuning (#1349)
* More reduction tuning
* Forgotten pdb
* Small column long row specialization
2024-08-23 10:35:25 -07:00
Alex Barron
da8deb2b62 fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-23 10:06:15 -07:00
Awni Hannun
98b6ce3460 Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-08-22 16:03:31 -07:00
Awni Hannun
f9e00efe31 fix nanobind and stub gen in circle (#1346) 2024-08-22 14:07:27 -07:00
Alex Barron
0fd2a1f4b0 Custom Metal Kernels from Python (#1325)
* start

* simple kernels working

* restructure

* inverse example working

* docs + fixes

* missing file

* fix imports

* address comments

* add docs + fix test

* Review comments + refactor to a single function

* update docs

* remove hashing

* fix contig bug in test

* back to a class

* trailing whitespace

* fix tests

* match c++ and python apis

* add link + make args kw_only
2024-08-22 13:46:29 -07:00
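A condensed sketch of the Python API introduced here, following the keyword-only call style from the final revision noted above:

    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        out[elem] = metal::exp(inp[elem]);
    """
    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    a = mx.random.normal((4096,))
    (out,) = kernel(
        inputs=[a],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )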
Awni Hannun
df3233454d 2d gather specialization (#1339) 2024-08-22 10:48:24 -07:00
Awni Hannun
82db84b899 bump nanobind + fix extension (#1344) 2024-08-21 16:05:07 -07:00
Awni Hannun
8ae751d3da fix io (#1343)
* fix io

* fix io

* comment
2024-08-21 13:14:46 -07:00
Awni Hannun
d40e76809f Fix rope (#1340)
* add test

* fix rope

* fix test
2024-08-20 17:37:52 -07:00
Awni Hannun
bb1b76d9dc RoPE with frequencies as optional input (#1337)
* start rope with freq input

* rope with frequencies

* nits

* fix bug

* fix bug + test

* cleanup

* optional base
2024-08-19 18:30:50 -07:00
Angelos Katharopoulos
9d26441224 Fix contiguity check (#1336)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-08-19 16:05:06 -07:00
Awni Hannun
f12f24a77c fix compiling with space in paths (#1332) 2024-08-15 16:39:24 -07:00
Awni Hannun
ae5b5cabfd Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint

* comment
2024-08-15 07:33:23 -07:00
Awni Hannun
d0630ffe8c Read arrays from files faster (#1330)
* read faster

* faster write as well

* set default permission for linux

* comment
2024-08-14 20:09:56 -07:00
Alex Barron
99bb7d3a58 GPU mx.sign for complex64 (#1326) 2024-08-14 07:54:53 -07:00
Awni Hannun
63ae767232 fix transformer (#1327) 2024-08-13 16:04:26 -07:00
Awni Hannun
eaaea02010 Add isfinite (#1318)
* isfinite

* remove reduce test since fix is not complete
2024-08-13 14:49:28 -07:00
Bhargav Yagnik
a098bc92e0 Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output

- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321

* Update test cases in test_nn.py

- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,

* updated dropout.py
2024-08-13 11:54:21 -07:00
Awni Hannun
1086dc4db0 patch (#1320) 2024-08-12 16:13:33 -07:00
Brian Keene
19fb69e2ed Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
2024-08-12 12:57:09 -07:00
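A hedged sketch of the opt-in described above; the keyword name comes from the commit title, and the exact signature is an assumption:

    import mlx.core as mx

    q = k = v = mx.random.normal((1, 8, 2048, 64))  # (batch, heads, seq, head_dim)
    # Hypothetical threshold: use the memory-efficient shader once seq length >= 1024
    o = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=64**-0.5, memory_efficient_threshold=1024
    )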
Awni Hannun
9231617eb3 Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317 CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
Angelos Katharopoulos
780c197f95 Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
Angelos Katharopoulos
eb8819e91e Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08 Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25 Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9 Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1 add bfloat conv for windograd (#1306)
* add bfloat conv for windograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501 fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
6c8dd307eb faster group norm (#1304) 2024-08-01 12:49:23 -07:00
Awni Hannun
43ffdab172 fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333 Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0 Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Atakan Tekparmak
6e06e3a904 feat: Added "tanh" option to GELU approximation (#1268) 2024-07-28 09:07:56 +02:00
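A one-line sketch of the new option, assuming `"tanh"` is passed through the `approx` argument alongside the existing choices:

    import mlx.nn as nn

    act = nn.GELU(approx="tanh")  # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))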
Yaroslav
8cfb9fc0b8 Update requirements.txt (#1291) 2024-07-26 12:59:52 -07:00
Awni Hannun
7b456fd2c0 Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Awni Hannun
e9e53856d2 patch bump (#1287) 2024-07-25 11:42:09 -07:00
Anton Belov
5029894662 [Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
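A minimal example of the new function, which mirrors its NumPy counterpart:

    import mlx.core as mx

    x = mx.array([float("nan"), float("inf"), -float("inf"), 1.0])
    # NaN -> 0.0; infinities are clamped to large finite values by default
    y = mx.nan_to_num(x, nan=0.0)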
Awni Hannun
baf9fa5f42 Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
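A small sketch of the resulting op, which mirrors `numpy.einsum` (the separate `einsum_path` mentioned in the history above inspects the contraction order):

    import mlx.core as mx

    a = mx.random.normal((4, 5))
    b = mx.random.normal((5, 6))
    c = mx.einsum("ij,jk->ik", a, b)  # matrix multiply
    t = mx.einsum("ii->", mx.eye(3))  # trace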
Jagrit Digani
7f914365fd Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50 Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
fgranqvist
50eff6a10a Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
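A minimal sketch, assuming the shape/loc/scale convention of the other `mx.random` samplers:

    import mlx.core as mx

    samples = mx.random.laplace(shape=(1000,), loc=0.0, scale=1.0)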
Alex Barron
c34a5ae7f7 Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae some fixes (#1281) 2024-07-23 11:49:05 -07:00
toji
6768c6a54a Adding missing type hints (#1243)
* added type hints for `run`, `tree_map` and `tree_map_with_path`

* fix lint

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Tim Gymnich
6307d166eb Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7 Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7 Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Awni Hannun
8c01a7893b minor fix in optimizer + docs (#1264) 2024-07-12 12:18:02 -07:00
Awni Hannun
218047c75a docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Alex Barron
d0da74209b version bump (#1260) 2024-07-11 11:17:55 -07:00
Angelos Katharopoulos
5c1fa64fb0 Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
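A brief sketch of the resulting op, assuming it is exposed as `mx.hadamard_transform`; the `scale` argument comes from the commit bullets, and the `1/sqrt(n)` default is an assumption:

    import mlx.core as mx

    x = mx.random.normal((8, 64))  # last axis must be a supported Hadamard size
    y = mx.hadamard_transform(x)   # Walsh-Hadamard transform along the last axis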
Angelos Katharopoulos
03cf033f82 Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63 add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195 CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Awni Hannun
d6383a1c6a version bump (#1239) 2024-06-27 10:43:13 -07:00
Angelos Katharopoulos
b05bcfd27f Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62 Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1 fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8 Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439 Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
David Koski
4eef1e8a3e fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06 Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
Awni Hannun
af9079cc1f version bump (#1212) 2024-06-14 11:28:51 -07:00
Jagrit Digani
2d6cd47713 Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
Awni Hannun
fe3167d7ea smaller CPU binary (#1203)
* smaller CPU binary

* fix no cpu build
2024-06-14 09:46:55 -07:00
Awni Hannun
31e134be35 Build for macOS 15 (#1208)
* Build for macos 15

* metal32 as well

* comment

---------

Co-authored-by: Awni Hannun <Awni Hannun>
2024-06-13 13:31:44 -07:00
Awni Hannun
e84ba8056d only allow openmpi (#1209) 2024-06-13 12:14:44 -07:00
Fangjun Kuang
f20e97b092 minor fixes (#1194)
* minor fixes

* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a Fix kernel deps to reduce build times (#1205) 2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29 Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Awni Hannun
709ccc6800 install mpi for release build (#1199) 2024-06-10 10:09:32 -07:00
Awni Hannun
cf236fc390 version (#1191) 2024-06-06 17:16:40 -07:00
Alex Barron
27d70c7d9d Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409 Add docstring for scatter (#1189)
* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893 Fix the hard-shrink test (#1185) 2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4 Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
nicolov
81def6ac76 Fix benchmark (#1175) 2024-06-04 07:50:46 -07:00
Angelos Katharopoulos
3de8ce3f3c In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24 Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30 Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Dominik Schlösser
3576b547c5 Doc error for default for scale in SinusoidalPositionalEncoding (#1174) 2024-06-02 13:42:45 -07:00
Awni Hannun
079882495d version bump (#1172) 2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46 Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
372 changed files with 34445 additions and 14360 deletions

.circleci/config.yml

@@ -13,8 +13,62 @@ parameters:
   test_release:
     type: boolean
     default: false
+  linux_release:
+    type: boolean
+    default: false
 
 jobs:
+  build_documentation:
+    parameters:
+      upload-docs:
+        type: boolean
+        default: false
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.medium.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install
+          command: |
+            brew install python@3.9
+            brew install doxygen
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install --upgrade cmake
+            pip install -r docs/requirements.txt
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+      - when:
+          condition:
+            not: << parameters.upload-docs >>
+          steps:
+            - run:
+                name: Build documentation
+                command: |
+                  source env/bin/activate
+                  cd docs && doxygen && make html O=-W
+      - when:
+          condition: << parameters.upload-docs >>
+          steps:
+            - add_ssh_keys:
+                fingerprints:
+                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
+            - run:
+                name: Upload documentation
+                command: |
+                  source env/bin/activate
+                  git config user.email "mlx@group.apple.com"
+                  git config user.name "CircleCI Docs"
+                  git checkout gh-pages
+                  git rebase main
+                  cd docs
+                  git rm -rf build/html
+                  doxygen && make html O=-W
+                  git add -f build/html
+                  git commit -m "rebase"
+                  git push -f origin gh-pages
+
   linux_build_and_test:
     docker:
       - image: cimg/python:3.9
@@ -31,19 +85,24 @@
           name: Install dependencies
           command: |
             pip install --upgrade cmake
-            pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+            pip install nanobind==2.2.0
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
           name: Install Python package
           command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            python3 setup.py build_ext --inplace
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            python3 setup.py develop
       - run:
           name: Generate package stubs
           command: |
             echo "stubs"
+            pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Run Python tests
@@ -52,7 +111,9 @@
       - run:
           name: Build CPP only
           command: |
-            mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
+            mkdir -p build && cd build
+            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
+            make -j `nproc`
       - run:
           name: Run CPP tests
           command: ./build/tests/tests
@@ -70,13 +131,13 @@
       - run:
           name: Install dependencies
           command: |
-            brew install python@3.8
+            brew install python@3.9
             brew install openmpi
-            python3.8 -m venv env
+            python3.9 -m venv env
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+            pip install nanobind==2.2.0
             pip install numpy
             pip install torch
             pip install tensorflow
@@ -85,11 +146,12 @@
           name: Install Python package
           command: |
             source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
             source env/bin/activate
+            pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Run Python tests
@@ -97,7 +159,7 @@
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
-            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
       - run:
           name: Build example extension
           command: |
@@ -111,7 +173,7 @@
           name: Build CPP only
           command: |
             source env/bin/activate
-            mkdir -p build && cd build && cmake .. && make -j
+            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
       - run:
           name: Run CPP tests
           command: |
@@ -121,8 +183,23 @@
           command: |
             source env/bin/activate
             cd build/
-            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
-            make -j
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
+                     -DBUILD_SHARED_LIBS=ON \
+                     -DMLX_BUILD_CPU=OFF \
+                     -DMLX_BUILD_SAFETENSORS=OFF \
+                     -DMLX_BUILD_GGUF=OFF \
+                     -DMLX_METAL_JIT=ON
+            make -j `sysctl -n hw.ncpu`
+      - run:
+          name: Run Python tests with JIT
+          command: |
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+            pip install -e . -v
+            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
+            METAL_DEBUG_ERROR_MODE=0 \
+            python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
 
   build_release:
     parameters:
@@ -144,11 +221,12 @@
           name: Install dependencies
           command: |
             brew install python@<< parameters.python_version >>
+            brew install openmpi
             python<< parameters.python_version >> -m venv env
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install twine
@@ -158,19 +236,20 @@
           command: |
             source env/bin/activate
             DEV_RELEASE=1 \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
             pip install . -v
       - run:
           name: Generate package stubs
           command: |
             source env/bin/activate
+            pip install typing_extensions
             python setup.py generate_stubs
       - run:
           name: Build Python package
           command: |
             source env/bin/activate
             << parameters.build_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
             python -m build -w
       - when:
           condition: << parameters.build_env >>
@@ -183,7 +262,7 @@
       - store_artifacts:
           path: dist/
 
-  build_linux_test_release:
+  build_linux_release:
     parameters:
       python_version:
         type: string
@@ -212,21 +291,28 @@
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install auditwheel
             pip install patchelf
             pip install build
+            pip install twine
             << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             pip install . -v
+            pip install typing_extensions
             python setup.py generate_stubs
             << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python -m build --wheel
             auditwheel show dist/*
             auditwheel repair dist/* --plat manylinux_2_31_x86_64
+      - run:
+          name: Upload package
+          command: |
+            source env/bin/activate
+            twine upload wheelhouse/*
       - store_artifacts:
           path: wheelhouse/
@@ -244,8 +330,9 @@ workflows:
       - mac_build_and_test:
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
       - linux_build_and_test
+      - build_documentation
 
   build_pypi_release:
     when:
@@ -262,9 +349,17 @@
               ignore: /.*/
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
               xcode_version: ["15.0.0", "15.2.0"]
               build_env: ["PYPI_RELEASE=1"]
+      - build_documentation:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          upload-docs: true
   prb:
     when:
       matches:
@@ -279,7 +374,7 @@
           requires: [ hold ]
           matrix:
             parameters:
-              xcode_version: ["15.0.0", "15.2.0"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
       - linux_build_and_test:
           requires: [ hold ]
   nightly_build:
@@ -291,7 +386,7 @@
       - build_release:
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
               xcode_version: ["15.0.0", "15.2.0"]
   weekly_build:
@@ -302,17 +397,17 @@
       - build_release:
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
               build_env: ["DEV_RELEASE=1"]
   linux_test_release:
     when:
       and:
         - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.test_release >>
+        - << pipeline.parameters.linux_release >>
     jobs:
-      - build_linux_test_release:
+      - build_linux_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
               extra_env: ["PYPI_RELEASE=1"]

.pre-commit-config.yaml

@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.4
+    rev: v18.1.8
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.4.2
+    rev: 24.8.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
@@ -14,3 +14,7 @@ repos:
       - id: isort
         args:
          - --profile=black
+  - repo: https://github.com/cheshirekow/cmake-format-precommit
+    rev: v0.6.13
+    hooks:
+      - id: cmake-format

ACKNOWLEDGMENTS.md

@@ -7,16 +7,18 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
+- Paul Paczuski: Improved stability of BCE loss calculation
+- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

CITATION.cff (new file, 24 lines)

@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT
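Because CITATION.cff is machine-readable YAML, reference entries can be generated from it directly. The snippet below is an illustrative sketch only (it assumes PyYAML is installed; the BibTeX field mapping is ours, not something shipped with MLX):

import yaml  # pip install pyyaml

with open("CITATION.cff") as f:
    meta = yaml.safe_load(f)

# Build a BibTeX-style author list: "Hannun, Awni and Digani, Jagrit and ..."
authors = " and ".join(
    f"{a['family-names']}, {a['given-names']}" for a in meta["authors"]
)
print(f"@software{{mlx,\n  title = {{{meta['title']}}},\n"
      f"  author = {{{authors}}},\n  url = {{{meta['repository-code']}}}\n}}")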


@@ -24,32 +24,34 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.14.0)
+  set(MLX_VERSION 0.21.0)
 endif()
 
 # --------------------- Processor tests -------------------------
-message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
+message(
+  STATUS
+    "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
+)
 
-set(MLX_BUILD_ARM OFF)
-if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     if(NOT MLX_ENABLE_X64_MAC)
-      message(FATAL_ERROR
-        "Building for x86_64 on macOS is not supported."
-        " If you are on an Apple silicon system, check the build"
-        " documentation for possible fixes: "
-        "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
+      message(
+        FATAL_ERROR
+          "Building for x86_64 on macOS is not supported."
+          " If you are on an Apple silicon system, check the build"
+          " documentation for possible fixes: "
+          "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
+      )
     else()
+      set(MLX_BUILD_METAL OFF)
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
-    set(MLX_BUILD_METAL OFF)
-  elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
-    set(MLX_BUILD_ARM ON)
   endif()
 else()
   set(MLX_BUILD_METAL OFF)
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
@@ -61,66 +63,59 @@ cmake_policy(SET CMP0135 NEW)
 
 add_library(mlx)
 
-if (MLX_BUILD_METAL)
-  find_library(METAL_LIB Metal)
-  find_library(FOUNDATION_LIB Foundation)
-  find_library(QUARTZ_LIB QuartzCore)
+if(MLX_BUILD_METAL)
+  set(METAL_LIB "-framework Metal")
+  set(FOUNDATION_LIB "-framework Foundation")
+  set(QUARTZ_LIB "-framework QuartzCore")
 endif()
 
-if (MLX_BUILD_METAL AND NOT METAL_LIB)
+if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)
   set(MLX_METAL_DEBUG OFF)
-elseif (MLX_BUILD_METAL)
+elseif(MLX_BUILD_METAL)
   message(STATUS "Building METAL sources")
 
-  if (MLX_METAL_DEBUG)
+  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
   endif()
 
   # Throw an error if xcrun not found
-  execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-                  OUTPUT_VARIABLE MACOS_VERSION
-                  COMMAND_ERROR_IS_FATAL ANY)
-  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
-
-  if (${MACOS_VERSION} GREATER_EQUAL 14.2)
-    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
-    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
-    set(MLX_METAL_VERSION METAL_3_1)
-  elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
-    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
-    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
-    set(MLX_METAL_VERSION METAL_3_0)
-  else()
-    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
+  execute_process(
+    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
+    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
+
+  if(${MACOS_SDK_VERSION} LESS 14.0)
+    message(
+      FATAL_ERROR
+        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
+  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
 
-  FetchContent_Declare(
-    metal_cpp
-    URL ${METAL_CPP_URL}
-    PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
-  )
+  set(METAL_CPP_URL
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
+  )
+
+  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
+  endif()
+  execute_process(
+    COMMAND
+      zsh "-c"
+      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
+    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 
   FetchContent_MakeAvailable(metal_cpp)
   target_include_directories(
-    mlx PUBLIC
-    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
-    $<INSTALL_INTERFACE:include/metal_cpp>
-  )
-  target_link_libraries(
-    mlx PUBLIC
-    ${METAL_LIB}
-    ${FOUNDATION_LIB}
-    ${QUARTZ_LIB})
-  add_compile_definitions(${MLX_METAL_VERSION})
+    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
+               $<INSTALL_INTERFACE:include/metal_cpp>)
+  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
-if (MLX_BUILD_CPU)
+if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
-  if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+  if(ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
     target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -132,120 +127,135 @@ if (MLX_BUILD_CPU)
     # The blas shipped in macOS SDK is not supported, search homebrew for
     # openblas instead.
     set(BLA_VENDOR OpenBLAS)
-    set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
+    set(LAPACK_ROOT
+        "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
   endif()
   # Search and link with lapack.
   find_package(LAPACK REQUIRED)
-  if (NOT LAPACK_FOUND)
+  if(NOT LAPACK_FOUND)
     message(FATAL_ERROR "Must have LAPACK installed")
   endif()
-  find_path(LAPACK_INCLUDE_DIRS lapacke.h
-    /usr/include
-    /usr/local/include
-    /usr/local/opt/openblas/include)
+  find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
+            /usr/local/opt/openblas/include)
   message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
   message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
   target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
   target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
-  # List blas after lapack otherwise we may accidentally incldue an old version
-  # of lapack.h from the include dirs of blas.
+  # List blas after lapack otherwise we may accidentally incldue an old
+  # version of lapack.h from the include dirs of blas.
   find_package(BLAS REQUIRED)
-  if (NOT BLAS_FOUND)
+  if(NOT BLAS_FOUND)
     message(FATAL_ERROR "Must have BLAS installed")
   endif()
   # TODO find a cleaner way to do this
-  find_path(BLAS_INCLUDE_DIRS cblas.h
-    /usr/include
-    /usr/local/include
-    $ENV{BLAS_HOME}/include)
+  find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
+            $ENV{BLAS_HOME}/include)
   message(STATUS "Blas lib " ${BLAS_LIBRARIES})
   message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
   target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
   target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
+  if(WIN32)
+    find_package(dlfcn-win32 REQUIRED)
+    message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
+    message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
+    target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
+  endif()
  endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()
 
 find_package(MPI)
-if (MPI_FOUND)
+if(MPI_FOUND)
+  execute_process(
+    COMMAND zsh "-c" "mpirun --version"
+    OUTPUT_VARIABLE MPI_VERSION
+    ERROR_QUIET)
+  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
     target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
+  elseif(MPI_VERSION STREQUAL "")
+    set(MPI_FOUND FALSE)
+    message(
+      WARNING "MPI found but mpirun is not available. Building without MPI.")
+  else()
+    set(MPI_FOUND FALSE)
+    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
+  endif()
 endif()
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 
 target_include_directories(
-  mlx
-  PUBLIC
-  $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
-  $<INSTALL_INTERFACE:include>
-)
+  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
+             $<INSTALL_INTERFACE:include>)
 
-FetchContent_Declare(fmt
+FetchContent_Declare(
+  fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt.git
   GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL
-)
+  EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(fmt)
-target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
+target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 
-if (MLX_BUILD_PYTHON_BINDINGS)
+if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
-  find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+  find_package(
+    Python 3.8
+    COMPONENTS Interpreter Development.Module
+    REQUIRED)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
-    OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+    OUTPUT_VARIABLE NB_DIR)
   list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
   find_package(nanobind CONFIG REQUIRED)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
 endif()
 
-if (MLX_BUILD_TESTS)
+if(MLX_BUILD_TESTS)
   include(CTest)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
 endif()
 
-if (MLX_BUILD_EXAMPLES)
+if(MLX_BUILD_EXAMPLES)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
 endif()
 
-if (MLX_BUILD_BENCHMARKS)
+if(MLX_BUILD_BENCHMARKS)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
 endif()
 
 # ----------------------------- Installation -----------------------------
 include(GNUInstallDirs)
 
 # Install library
 install(
   TARGETS mlx
   EXPORT MLXTargets
   LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
   ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
   RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-  INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
-)
+  INCLUDES
+  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
 
 # Install headers
 install(
   DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
   DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
   COMPONENT headers
-  FILES_MATCHING PATTERN "*.h"
-)
+  FILES_MATCHING
+  PATTERN "*.h"
+  PATTERN "backend/metal/kernels.h" EXCLUDE)
 
 # Install metal dependencies
-if (MLX_BUILD_METAL)
+if(MLX_BUILD_METAL)
 
   # Install metal cpp
   install(
     DIRECTORY ${metal_cpp_SOURCE_DIR}/
     DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
-    COMPONENT metal_cpp_source
-  )
+    COMPONENT metal_cpp_source)
 endif()
@@ -257,31 +267,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
 install(
   EXPORT MLXTargets
   FILE MLXTargets.cmake
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
 
 include(CMakePackageConfigHelpers)
 write_basic_package_version_file(
   ${MLX_CMAKE_BUILD_VERSION_CONFIG}
   COMPATIBILITY SameMajorVersion
-  VERSION ${MLX_VERSION}
-)
+  VERSION ${MLX_VERSION})
 
 configure_package_config_file(
-  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
-  ${MLX_CMAKE_BUILD_CONFIG}
+  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
   INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
   NO_CHECK_REQUIRED_COMPONENTS_MACRO
-  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
-)
+  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
+            MLX_CMAKE_INSTALL_MODULE_DIR)
 
-install(
-  FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
+        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
 
-install(
-  DIRECTORY ${CMAKE_MODULE_PATH}/
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(DIRECTORY ${CMAKE_MODULE_PATH}/
+        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})


@@ -6,7 +6,7 @@
 [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
 
-MLX is an array framework for machine learning research on Apple silicon,
+MLX is an array framework for machine learning on Apple silicon,
 brought to you by Apple machine learning research.
 
 Some key features of MLX include:


@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)
 
 
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")
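The new `sum_and_add` case repeatedly feeds a reduction back into a broadcast add. A self-contained version of the same computation, outside the benchmark harness, looks like this (a sketch; the shapes are illustrative):

import mlx.core as mx

x = mx.random.uniform(shape=(1024, 1024))
y = mx.random.uniform(shape=(1024, 1))
z = x.sum(axis=1, keepdims=True)  # (1024, 1)
for _ in range(50):
    z = (z + y).sum(axis=1, keepdims=True)  # broadcast add, then reduce again
mx.eval(z)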


@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
 def mish(x: torch.Tensor) -> torch.Tensor:
     y = x
     for _ in range(100):
-        return torch.nn.functional.mish(y)
+        y = torch.nn.functional.mish(y)
     sync_if_needed(x)
 
 
@@ -283,6 +283,14 @@ def topk(axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def step_function(x):
+    y = x
+    for i in range(100):
+        y = torch.where(y < 0, 0, 1)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def selu(x):
     y = x
@@ -446,5 +454,11 @@ if __name__ == "__main__":
     elif args.benchmark == "topk":
         print(bench(topk, axis, x))
 
+    elif args.benchmark == "step":
+        print(bench(step_function, x))
+
+    elif args.benchmark == "selu":
+        print(bench(selu, x))
+
     else:
-        raise ValueError("Unknown benchmark")
+        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")


@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
         result = run(*args, capture_output=True, **kwargs)
         return float(result.stdout)
     except ValueError:
-        raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
+        raise ValueError(
+            f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
+        )
 
 
 def compare(args):
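The `.decode()` calls are the substance of this fix: with `capture_output=True` and no `text=True`, `subprocess.run` hands back raw bytes, which render as `b'...'` inside an f-string. A standalone sketch of the difference:

from subprocess import run

result = run(["echo", "42"], capture_output=True)
print(result.stdout)           # b'42\n' -- the bytes repr leaks into the message
print(result.stdout.decode())  # 42 -- what the error text should contain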


@@ -9,7 +9,6 @@ from time_utils import time_fn
 
 
 def bench_gelu():
     def gelu(x):
         return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -51,7 +50,6 @@ def bench_gelu():
 
 
 def bench_layernorm():
     weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
     mx.eval(weight, bias)


@@ -0,0 +1,127 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()
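One detail worth noting in the MLX half of this benchmark: MLX is lazy, so the optimizer step only builds a graph, and the per-step `mx.eval(state)` is what actually executes it, keeping the timed region honest. A stripped-down sketch of the same train-step pattern (a minimal linear model stands in for the conv net):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(16, 16)
optim = opt.Adam(learning_rate=1e-3)
optim.init(model.parameters())
x = mx.random.normal((4, 16))

def loss_fn(params, x):
    model.update(params)
    return (model(x) - x).abs().mean()

loss, grads = mx.value_and_grad(loss_fn)(model.parameters(), x)
optim.update(model, grads)
mx.eval(model.state, optim.state)  # nothing runs until this point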


@@ -0,0 +1,129 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,110 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,143 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()


@@ -0,0 +1,116 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
 
 def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
     scale = 1.0 / math.sqrt(kH * kH * C)
     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
     b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(


@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()
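As the docstring says, the script above only does something interesting under an MPI launcher. For orientation, a minimal sketch of what each rank runs (assumes an MPI-enabled MLX build and `mpirun` on PATH):

import mlx.core as mx

# Launch with: mpirun -n 2 python this_script.py
world = mx.distributed.init()
x = mx.ones((4,))
y = mx.distributed.all_sum(x)  # summed across ranks: every entry becomes world.size()
print(world.rank(), world.size(), y)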


@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()
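`einsum_path` follows NumPy's API: it returns the chosen contraction order plus a printable cost report, and that planning step is what the first two timers above measure. A small sketch using NumPy for the output format (the MLX call is assumed to mirror it):

import numpy as np

a, b, c = np.ones((32, 64)), np.ones((64, 128)), np.ones((128, 16))
path, report = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")
print(path)    # e.g. ['einsum_path', (1, 2), (0, 1)]
print(report)  # human-readable summary of FLOPs and intermediate sizes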


@@ -3,6 +3,8 @@
 import matplotlib
 import mlx.core as mx
 import numpy as np
+import sympy
+import torch
 from time_utils import measure_runtime
 
 matplotlib.use("Agg")
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
     return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
 
 
-def run_bench(system_size):
-    def fft(x):
-        out = mx.fft.fft(x)
+def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
+    def fft_mlx(x):
+        if dim == 1:
+            out = mx.fft.fft(x)
+        elif dim == 2:
+            out = mx.fft.fft2(x)
         mx.eval(out)
         return out
 
-    bandwidths = []
-    for k in range(4, 12):
-        n = 2**k
-        x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
-        x = x.astype(mx.complex64)
-        mx.eval(x)
-        runtime_ms = measure_runtime(fft, x=x)
-        bandwidths.append(bandwidth_gb(runtime_ms, system_size))
+    def fft_mps(x):
+        if dim == 1:
+            out = torch.fft.fft(x)
+        elif dim == 2:
+            out = torch.fft.fft2(x)
+        torch.mps.synchronize()
+        return out
 
-    return bandwidths
+    bandwidths = []
+    for n in fft_sizes:
+        batch_size = system_size // n**dim
+        shape = [batch_size] + [n for _ in range(dim)]
+        if backend == "mlx":
+            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
+            x = mx.array(x_np)
+            mx.eval(x)
+            fft = fft_mlx
+        elif backend == "mps":
+            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
+            x = torch.tensor(x_np, device="mps")
+            torch.mps.synchronize()
+            fft = fft_mps
+        else:
+            raise NotImplementedError()
+        runtime_ms = measure_runtime(fft, x=x)
+        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
+        print(n, bandwidth)
+        bandwidths.append(bandwidth)
+
+    return np.array(bandwidths)
 
 
 def time_fft():
+    x = np.array(range(2, 512))
+    system_size = int(2**26)
 
-    with mx.stream(mx.cpu):
-        cpu_bandwidths = run_bench(system_size=int(2**22))
-
+    print("MLX GPU")
     with mx.stream(mx.gpu):
-        gpu_bandwidths = run_bench(system_size=int(2**29))
+        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
 
-    # plot bandwidths
-    x = [2**k for k in range(4, 12)]
-    plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
-    plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
-    plt.title("MLX FFT Benchmark")
-    plt.xlabel("N")
-    plt.ylabel("Bandwidth (GB/s)")
-    plt.legend()
-    plt.savefig("fft_plot.png")
+    print("MPS GPU")
+    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
+
+    print("CPU")
+    system_size = int(2**20)
+    with mx.stream(mx.cpu):
+        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
+
+    x = np.array(x)
+    all_indices = x - x[0]
+    radix_2to13 = (
+        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
+    )
+    bluesteins = (
+        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
+    )
+
+    for indices, name in [
+        (all_indices, "All"),
+        (radix_2to13, "Radix 2-13"),
+        (bluesteins, "Bluestein's"),
+    ]:
+        # plot bandwidths
+        print(name)
+        plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
+        plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
+        plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
+        plt.title(f"MLX FFT Benchmark -- {name}")
+        plt.xlabel("N")
+        plt.ylabel("Bandwidth (GB/s)")
+        plt.legend()
+        plt.savefig(f"{name}.png")
+        plt.clf()
+
+    av_gpu_bandwidth = np.mean(gpu_bandwidths)
+    av_mps_bandwidth = np.mean(mps_bandwidths)
+    av_cpu_bandwidth = np.mean(cpu_bandwidths)
+    print("Average bandwidths:")
+    print("GPU:", av_gpu_bandwidth)
+    print("MPS:", av_mps_bandwidth)
+    print("CPU:", av_cpu_bandwidth)
+
+    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
+    print("Percent MLX faster than MPS: ", portion_faster * 100)
 
 
 if __name__ == "__main__":
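The radix/Bluestein split above keys on the largest prime factor of the FFT size: sizes whose factors are all at most 13 hit the mixed-radix kernels, anything with a larger factor falls back to Bluestein's algorithm. The classification in isolation (assumes `sympy` is installed):

import sympy

def uses_bluestein(n: int) -> bool:
    # Any prime factor above 13 is outside the radix-2..13 kernels.
    return any(p > 13 for p in sympy.primefactors(n))

print([n for n in range(2, 32) if uses_bluestein(n)])  # [17, 19, 23, 29, 31]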


@@ -0,0 +1,70 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)
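The sweep covers both pure powers of two (m = 1) and m·2^k sizes, matching the two cases the output keys distinguish. A minimal call, for orientation (the transform runs along the last axis and the output shape matches the input):

import mlx.core as mx

x = mx.random.normal((8, 1024))
y = mx.hadamard_transform(x)  # Walsh-Hadamard transform over the last axis
mx.eval(y)
print(y.shape)  # (8, 1024)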


@@ -9,7 +9,7 @@ from time_utils import measure_runtime
 
 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)
 
     idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
 
 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()
 
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")
 
@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)


@@ -0,0 +1,189 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def bench(f, *args):
for i in range(N_warmup):
f(*args)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads
    B = q.shape[0]
    L = q.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)
    if f32softmax:
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
    else:
        scores = mx.softmax(scores, axis=-1)

    out = scores @ v

    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def mlx_sdpa_unfused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def mlx_sdpa_fused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
    shape_q = (
        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
    )
    shape_kv = (
        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
    )

    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)

    scale = math.sqrt(1.0 / head_dim)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    time_mlx_unfused = bench(mlx_sdpa_unfused, q_mx, k_mx, v_mx, scale, transpose)
    time_mlx_fused = bench(mlx_sdpa_fused, q_mx, k_mx, v_mx, scale, transpose)

    if transpose:
        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))

    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
        print(
            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
        )

    return time_mlx_fused, time_mlx_unfused


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float16", "float32")[:1]
    transposes = (False,)

    # fmt: off
    shapes_64 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 32, 32, 64, 32, 32),
        ( 1, 64, 64, 64, 32, 32),
        ( 1, 128, 128, 64, 32, 32),
        ( 1, 256, 256, 64, 32, 32),
        ( 1, 512, 512, 64, 32, 32),
        ( 1, 1024, 1024, 64, 32, 32),
        ( 1, 2048, 2048, 64, 32, 32),
        ( 1, 4096, 4096, 64, 32, 32),
    )

    shapes_80 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 1024, 1024, 80, 32, 32),
        ( 1, 2048, 2048, 80, 32, 32),
        ( 1, 4096, 4096, 80, 32, 32),
    )

    shapes_128 = (
        # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
        ( 1, 1024, 1024, 128, 32, 32),
        ( 1, 2048, 2048, 128, 32, 32),
        ( 1, 4096, 4096, 128, 32, 32),
    )
    # fmt: on

    shapes = shapes_64 + shapes_80 + shapes_128

    print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")

    for dtype in dtypes:
        for transpose in transposes:
            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx_fused, time_mlx_unfused = bench_shape(
                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
                )

                diff = time_mlx_unfused / time_mlx_fused - 1.0
                t_str = 1 if transpose else 0
                print(
                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                )
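
# For reference: the GQA broadcast step used in mlx_sdpa_unfused_inner, shown
# in isolation with toy sizes. This is a sketch for illustration only; the
# benchmark never calls it.
def gqa_broadcast_sketch():
    B, n_q_heads, n_kv_heads, L, D = 1, 8, 2, 16, 64
    q = mx.random.normal((B, n_q_heads, L, D))
    k = mx.random.normal((B, n_kv_heads, L, D))
    # Group the query heads so that each group shares a single K head.
    q = mx.reshape(q, [B, n_kv_heads, n_q_heads // n_kv_heads, L, D])
    scores = q @ mx.swapaxes(mx.expand_dims(k, 2), -1, -2)
    return scores.shape  # (1, 2, 4, 16, 16)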

View File

@@ -0,0 +1,94 @@
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map

from time_utils import time_fn

L = 32768
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
bits = 8
loops = 20


def attention(q, k, v):
    for _ in range(loops):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        ke = k[:, :, None, :, :]
        ve = v[:, :, None, :, :]
        s = q @ ke.transpose(0, 1, 2, 4, 3)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        q = p @ ve
        q = q.reshape(B, Hq, L, D)
    return q


def sdpa(q, k, v):
    for _ in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
    return q


def quant_sdpa(q, k, v, bits=4):
    for _ in range(loops):
        # k and v are (weights, scales, biases) tuples from mx.quantize
        q = mx.fast.quantized_scaled_dot_product_attention(
            q, *k, *v, scale=1.0, mask=None, bits=bits
        )
    return q


def quant_attention(q, k, v, bits=4):
    for _ in range(loops):
        B, Hq, L, D = q.shape
        Hk = k[0].shape[1]

        q = q.reshape((B, Hk, Hq // Hk, L, D))
        ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
        ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

        scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
        scores = mx.softmax(scores, axis=-1)

        q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
        q = q.reshape((B, Hq, L, D))
    return q


def time_self_attention_primitives(q, k, v):
    time_fn(attention, q, k, v)


def time_self_attention_sdpa(q, k, v):
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v, bits=4):
    time_fn(quant_sdpa, q, k, v, bits)


def time_self_attention_quant_primitives(q, k, v, bits=4):
    time_fn(quant_attention, q, k, v, bits)


if __name__ == "__main__":
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
    mx.eval(q, k, v)

    k_quant = mx.quantize(k, bits=bits)
    v_quant = mx.quantize(v, bits=bits)
    mx.eval(k_quant, v_quant)

    k = mx.dequantize(*k_quant, bits=bits)
    v = mx.dequantize(*v_quant, bits=bits)

    time_self_attention_sdpa(q, k, v)
    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
    time_self_attention_primitives(q, k, v)
    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
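
    # Quick correctness spot-check along the same lines (a sketch for
    # illustration; reuses the arrays set up above).
    out_ref = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
    out_q = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_quant, *v_quant, scale=1.0, mask=None, bits=bits
    )
    print(mx.abs(out_ref - out_q).max())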

View File

@@ -1,30 +1,21 @@
include(CMakeParseArguments) include(CMakeParseArguments)
############################################################################### # ##############################################################################
# Build metal library # Build metal library
# #
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
# #
# Args: # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# TARGET: Custom target to be added for the metal library # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# TITLE: Name of the .metallib # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib # files (like headers)
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
# #
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments( cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output # Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
@@ -35,22 +26,16 @@ macro(mlx_build_metallib)
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET} OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal COMMAND
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>" xcrun -sdk macosx metal
${MTLLIB_COMPILE_OPTIONS} "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_SOURCES} ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib" COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM VERBATIM)
)
# Add metallib custom target # Add metallib custom target
add_custom_target( add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib) endmacro(mlx_build_metallib)

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,3 +1,4 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
mlx

View File

@@ -60,6 +60,7 @@ html_theme_options = {
}, },
} }
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output --------------------------------------------- # -- Options for HTMLHelp output ---------------------------------------------
@@ -83,3 +84,15 @@ def setup(app):
# -- Options for LaTeX output ------------------------------------------------ # -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")] latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -0,0 +1,427 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python

    def exp_elementwise(a: mx.array):
        source = """
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        """

        kernel = mx.fast.metal_kernel(
            name="myexp",
            input_names=["inp"],
            output_names=["out"],
            source=source,
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
.. note::
   We are only required to pass the body of the Metal kernel in ``source``.
   The full function signature will be generated using:

   * The shapes/dtypes of ``inputs``

     In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass
     it with the key ``inp`` so we will add ``const device float16_t* inp`` to
     the signature. ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also
     added for convenience if they are present in ``source``.

   * The list of ``output_dtypes``

     In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` so we add
     ``device float16_t* out``.

   * Template parameters passed using ``template``

     In the above, ``template=[("T", mx.float32)]`` adds a template of
     ``template <typename T>`` to the function and instantiates the template
     with ``custom_kernel_myexp_float<float>``. Template parameters can be
     ``mx.core.Dtype``, ``int`` or ``bool``.

   * Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``

     These will be added as function arguments. All the attributes defined in
     Table 5.8 of the `Metal Shading Language Specification
     <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_
     are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp

    template <typename T>
    [[kernel]] void custom_kernel_myexp_float(
        const device float16_t* inp [[buffer(0)]],
        device float16_t* out [[buffer(1)]],
        uint3 thread_position_in_grid [[thread_position_in_grid]]) {
      uint elem = thread_position_in_grid.x;
      T tmp = inp[elem];
      out[elem] = metal::exp(tmp);
    }

    template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
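For example, ``grid=(4096, 1, 1)`` with ``threadgroup=(256, 1, 1)`` launches 4096 threads split into 16 threadgroups of 256 threads each.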
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
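For instance, a sketch that reuses the ``myexp`` kernel object from above to inspect the generated code:

.. code-block:: python

    outputs = kernel(
        inputs=[a],
        template=[("T", mx.float32)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
        verbose=True,  # print the full generated kernel source
    )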
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python

    def exp_elementwise(a: mx.array):
        source = """
            uint elem = thread_position_in_grid.x;
            // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
            uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
            T tmp = inp[loc];
            // Output arrays are always row contiguous
            out[elem] = metal::exp(tmp);
        """

        kernel = mx.fast.metal_kernel(
            name="myexp_strided",
            input_names=["inp"],
            output_names=["out"],
            source=source,
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
            ensure_row_contiguous=False,
        )
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    # make non-contiguous
    a = a[::2]
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python

    def grid_sample_ref(x, grid):
        N, H_in, W_in, _ = x.shape
        ix = ((grid[..., 0] + 1) * W_in - 1) / 2
        iy = ((grid[..., 1] + 1) * H_in - 1) / 2

        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)

        ix_ne = ix_nw + 1
        iy_ne = iy_nw

        ix_sw = ix_nw
        iy_sw = iy_nw + 1

        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        nw = (ix_se - ix) * (iy_se - iy)
        ne = (ix - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix) * (iy - iy_ne)
        se = (ix - ix_nw) * (iy - iy_nw)

        I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
        I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
        I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
        I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]

        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

        return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python

    @mx.custom_function
    def grid_sample(x, grid):
        assert x.ndim == 4, "`x` must be 4D."
        assert grid.ndim == 4, "`grid` must be 4D."

        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape
        out_shape = (B, gN, gM, C)

        assert D == 2, "Last dim of `grid` must be size 2."

        source = """
            uint elem = thread_position_in_grid.x;
            int H = x_shape[1];
            int W = x_shape[2];
            int C = x_shape[3];
            int gH = grid_shape[1];
            int gW = grid_shape[2];

            int w_stride = C;
            int h_stride = W * w_stride;
            int b_stride = H * h_stride;

            uint grid_idx = elem / C * 2;
            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

            int ix_nw = floor(ix);
            int iy_nw = floor(iy);

            int ix_ne = ix_nw + 1;
            int iy_ne = iy_nw;

            int ix_sw = ix_nw;
            int iy_sw = iy_nw + 1;

            int ix_se = ix_nw + 1;
            int iy_se = iy_nw + 1;

            T nw = (ix_se - ix) * (iy_se - iy);
            T ne = (ix - ix_sw) * (iy_sw - iy);
            T sw = (ix_ne - ix) * (iy - iy_ne);
            T se = (ix - ix_nw) * (iy - iy_nw);

            int batch_idx = elem / C / gH / gW * b_stride;
            int channel_idx = elem % C;
            int base_idx = batch_idx + channel_idx;

            T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
            T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
            T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
            T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

            I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
            I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
            I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
            I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

            out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
        """

        kernel = mx.fast.metal_kernel(
            name="grid_sample",
            input_names=["x", "grid"],
            output_names=["out"],
            source=source,
        )
        outputs = kernel(
            inputs=[x, grid],
            template=[("T", x.dtype)],
            output_shapes=[out_shape],
            output_dtypes=[x.dtype],
            grid=(np.prod(out_shape), 1, 1),
            threadgroup=(256, 1, 1),
        )
        return outputs[0]
For a reasonably sized input such as:
.. code-block:: python

    x.shape = (8, 1024, 1024, 64)
    grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
The backward pass requires atomically updating ``x_grad``/``grid_grad`` and so
uses a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``

  Initialize all of the kernel's outputs to this value before it runs. This
  allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``

  Designate all of the kernel outputs as ``atomic`` in the function signature.
  This means we can use Metal's ``atomic`` features to simultaneously update
  the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups. See
  section 6.15 of the `Metal Shading Language Specification
  <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_
  for more details.

We can then implement the backward pass as follows:
.. code-block:: python

    @grid_sample.vjp
    def grid_sample_vjp(primals, cotangent, _):
        x, grid = primals
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape

        assert D == 2, "Last dim of `grid` must be size 2."

        source = """
            uint elem = thread_position_in_grid.x;
            int H = x_shape[1];
            int W = x_shape[2];
            int C = x_shape[3];
            // Pad C to the nearest larger simdgroup size multiple
            int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

            int gH = grid_shape[1];
            int gW = grid_shape[2];

            int w_stride = C;
            int h_stride = W * w_stride;
            int b_stride = H * h_stride;

            uint grid_idx = elem / C_padded * 2;
            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

            int ix_nw = floor(ix);
            int iy_nw = floor(iy);

            int ix_ne = ix_nw + 1;
            int iy_ne = iy_nw;

            int ix_sw = ix_nw;
            int iy_sw = iy_nw + 1;

            int ix_se = ix_nw + 1;
            int iy_se = iy_nw + 1;

            T nw = (ix_se - ix) * (iy_se - iy);
            T ne = (ix - ix_sw) * (iy_sw - iy);
            T sw = (ix_ne - ix) * (iy - iy_ne);
            T se = (ix - ix_nw) * (iy - iy_nw);

            int batch_idx = elem / C_padded / gH / gW * b_stride;
            int channel_idx = elem % C_padded;
            int base_idx = batch_idx + channel_idx;

            T gix = T(0);
            T giy = T(0);
            if (channel_idx < C) {
                int cot_index = elem / C_padded * C + channel_idx;
                T cot = cotangent[cot_index];
                if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                    int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                    T I_nw = x[offset];

                    gix -= I_nw * (iy_se - iy) * cot;
                    giy -= I_nw * (ix_se - ix) * cot;
                }
                if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                    int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                    T I_ne = x[offset];

                    gix += I_ne * (iy_sw - iy) * cot;
                    giy -= I_ne * (ix - ix_sw) * cot;
                }
                if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                    int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                    T I_sw = x[offset];

                    gix -= I_sw * (iy - iy_ne) * cot;
                    giy += I_sw * (ix_ne - ix) * cot;
                }
                if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                    int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                    atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                    T I_se = x[offset];

                    gix += I_se * (iy - iy_nw) * cot;
                    giy += I_se * (ix - ix_nw) * cot;
                }
            }

            T gix_mult = W / 2;
            T giy_mult = H / 2;

            // Reduce across each simdgroup first.
            // This is much faster than relying purely on atomics.
            gix = simd_sum(gix);
            giy = simd_sum(giy);

            if (thread_index_in_simdgroup == 0) {
                atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
                atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
            }
        """
        kernel = mx.fast.metal_kernel(
            name="grid_sample_grad",
            input_names=["x", "grid", "cotangent"],
            output_names=["x_grad", "grid_grad"],
            source=source,
            atomic_outputs=True,
        )
        # pad the output channels to simd group size
        # so that our `simd_sum`s don't overlap.
        simdgroup_size = 32
        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
        grid_size = B * gN * gM * C_padded
        outputs = kernel(
            inputs=[x, grid, cotangent],
            template=[("T", x.dtype)],
            output_shapes=[x.shape, grid.shape],
            output_dtypes=[x.dtype, x.dtype],
            grid=(grid_size, 1, 1),
            threadgroup=(256, 1, 1),
            init_value=0,
        )
        return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``
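As a quick end-to-end check that gradients flow through the custom vjp (a sketch; the shapes here are chosen arbitrarily):

.. code-block:: python

    x = mx.random.normal(shape=(2, 32, 32, 8))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))

    loss = lambda x, grid: grid_sample(x, grid).sum()
    dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)
    print(dx.shape, dgrid.shape)  # (2, 32, 32, 8) (2, 16, 16, 2)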

View File

@@ -486,16 +486,15 @@ below.
std::ostringstream kname; std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out); kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available and look for it // Make sure the metal library is available
// in the same folder as this executable if needed d.register_library("mlx_ext");
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to // Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal // those in the kernel declaration at axpby.metal
@@ -510,14 +509,14 @@ below.
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode alpha and beta // Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3); compute_encoder.set_bytes(alpha_, 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4); compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim // Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); compute_encoder.set_bytes(y.strides(), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8); compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of // We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed // threads in any given threadgroup is not higher than the max allowed
@@ -531,7 +530,7 @@ below.
// Launch the grid with the given number of threads divided among // Launch the grid with the given number of threads divided among
// the given threadgroups // the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
We can now call the :meth:`axpby` operation on both the CPU and the GPU! We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer Attention layer
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^
We will start with the llama attention layer which notably uses the RoPE We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to key/value cache that will be concatenated with the provided keys and values to
support efficient inference. support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader `mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which <https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as `mnist`. we will import as ``mnist``.
.. code-block:: python .. code-block:: python

View File

@@ -43,6 +43,7 @@ are the CPU and GPU.
usage/function_transforms usage/function_transforms
usage/compile usage/compile
usage/numpy usage/numpy
usage/distributed
usage/using_streams usage/using_streams
.. toctree:: .. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
python/metal python/metal
python/nn python/nn
python/optimizers python/optimizers
python/distributed
python/tree_utils python/tree_utils
.. toctree:: .. toctree::
@@ -83,3 +85,4 @@ are the CPU and GPU.
dev/extensions dev/extensions
dev/metal_debugger dev/metal_debugger
dev/custom_metal_kernels

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements: To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.8 - Using a native Python >= 3.9
- macOS >= 13.5 - macOS >= 13.5
.. note:: .. note::
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
Then simply build and install MLX using pip: Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install . CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
For developing use an editable install: For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell .. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
To make sure the install is working run the tests with: Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with:
.. code-block:: shell .. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your IDE: Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell .. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs python setup.py generate_stubs
C++ API C++ API
@@ -186,8 +186,8 @@ should point to the path to the built metal library.
Binary Size Minimization Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel` To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and `BUILD_SHARED_LIBS=ON`. and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries. The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and For example, if you don't need the CPU backend or support for safetensors and
@@ -195,7 +195,7 @@ GGUF, you can do:
.. code-block:: shell .. code-block:: shell
cmake .. cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \ -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \ -DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_CPU=OFF \
@@ -203,13 +203,13 @@ GGUF, you can do:
-DMLX_BUILD_GGUF=OFF \ -DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON -DMLX_METAL_JIT=ON
The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots. Metal kernel cache persists across reboots.
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell: .. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively. Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm, To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"`` If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported." but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again. wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -24,6 +24,7 @@ Array
array.any array.any
array.argmax array.argmax
array.argmin array.argmin
array.conj
array.cos array.cos
array.cummax array.cummax
array.cummin array.cummin
@@ -52,8 +53,10 @@ Array
array.sqrt array.sqrt
array.square array.square
array.squeeze array.squeeze
array.swapaxes array.std
array.sum array.sum
array.swapaxes
array.transpose array.transpose
array.T array.T
array.var array.var
array.view

View File

@@ -0,0 +1,22 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
   send
   recv
   recv_like

View File

@@ -12,3 +12,4 @@ Fast
layer_norm layer_norm
rope rope
scaled_dot_product_attention scaled_dot_product_attention
metal_kernel

View File

@@ -9,7 +9,12 @@ Linear Algebra
:toctree: _autosummary :toctree: _autosummary
inv inv
tri_inv
norm norm
cholesky cholesky
cholesky_inv
cross
qr qr
svd svd
eigvalsh
eigh

View File

@@ -14,6 +14,7 @@ Metal
get_cache_memory get_cache_memory
set_memory_limit set_memory_limit
set_cache_limit set_cache_limit
set_wired_limit
clear_cache clear_cache
start_capture start_capture
stop_capture stop_capture

View File

@@ -13,10 +13,13 @@ simple functions.
:template: nn-module-template.rst :template: nn-module-template.rst
elu elu
celu
gelu gelu
gelu_approx gelu_approx
gelu_fast_approx gelu_fast_approx
glu glu
hard_shrink
hard_tanh
hardswish hardswish
leaky_relu leaky_relu
log_sigmoid log_sigmoid
@@ -29,6 +32,7 @@ simple functions.
sigmoid sigmoid
silu silu
softmax softmax
softmin
softplus softplus
softshrink softshrink
step step

View File

@@ -12,23 +12,37 @@ Layers
ALiBi ALiBi
AvgPool1d AvgPool1d
AvgPool2d AvgPool2d
AvgPool3d
BatchNorm BatchNorm
CELU
Conv1d Conv1d
Conv2d Conv2d
Conv3d Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout Dropout
Dropout2d Dropout2d
Dropout3d Dropout3d
Embedding Embedding
ELU
GELU GELU
GLU
GroupNorm GroupNorm
GRU GRU
HardShrink
HardTanh
Hardswish
InstanceNorm InstanceNorm
LayerNorm LayerNorm
LeakyReLU
Linear Linear
LogSigmoid
LogSoftmax
LSTM LSTM
MaxPool1d MaxPool1d
MaxPool2d MaxPool2d
MaxPool3d
Mish Mish
MultiHeadAttention MultiHeadAttention
PReLU PReLU
@@ -36,13 +50,20 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU6
RNN RNN
RoPE RoPE
SELU SELU
Sequential Sequential
Sigmoid
SiLU SiLU
SinusoidalPositionalEncoding SinusoidalPositionalEncoding
Softmin
Softshrink Softshrink
Softsign
Softmax
Softplus
Step Step
Tanh
Transformer Transformer
Upsample Upsample

View File

@@ -44,6 +44,10 @@ Operations
convolve convolve
conv1d conv1d
conv2d conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general conv_general
cos cos
cosh cosh
@@ -57,6 +61,8 @@ Operations
diagonal diagonal
divide divide
divmod divmod
einsum
einsum_path
equal equal
erf erf
erfinv erfinv
@@ -72,8 +78,11 @@ Operations
gather_qmm gather_qmm
greater greater
greater_equal greater_equal
hadamard_transform
identity identity
imag
inner inner
isfinite
isclose isclose
isinf isinf
isnan isnan
@@ -103,6 +112,7 @@ Operations
minimum minimum
moveaxis moveaxis
multiply multiply
nan_to_num
negative negative
not_equal not_equal
ones ones
@@ -112,14 +122,17 @@ Operations
pad pad
power power
prod prod
put_along_axis
quantize quantize
quantized_matmul quantized_matmul
radians radians
real
reciprocal reciprocal
remainder remainder
repeat repeat
reshape reshape
right_shift right_shift
roll
round round
rsqrt rsqrt
save save
@@ -156,6 +169,7 @@ Operations
tril tril
triu triu
var var
view
where where
zeros zeros
zeros_like zeros_like

View File

@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state. # Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state) mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten
   import mlx.optimizers as optim

   optimizer = optim.Adam(learning_rate=1e-2)

   # Perform some updates with the optimizer
   model = {"w" : mx.zeros((5, 5))}
   grads = {"w" : mx.ones((5, 5))}
   optimizer.update(model, grads)

   # Save the state
   state = tree_flatten(optimizer.state)
   mx.save_safetensors("optimizer.safetensors", dict(state))

   # Later on, for example when loading from a checkpoint,
   # recreate the optimizer and load the state
   optimizer = optim.Adam(learning_rate=1e-2)

   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
   optimizer.state = state
Note that not every optimizer configuration parameter is saved in the state.
For example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is that if a parameter can be
scheduled then it will be included in the optimizer state.
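As a quick check (a sketch reusing the ``optimizer`` from the example above):

.. code-block:: python

   print(optimizer.state.keys())  # includes "step" and "learning_rate"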
.. toctree:: .. toctree::
optimizers/optimizer optimizers/optimizer

View File

@@ -44,3 +44,5 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
split split
truncated_normal truncated_normal
uniform uniform
laplace
permutation

View File

@@ -10,6 +10,7 @@ Transforms
eval eval
compile compile
custom_function
disable_compile disable_compile
enable_compile enable_compile
grad grad

View File

@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging Debugging
--------- ---------

View File

@@ -0,0 +1,166 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
   Some operations may not be supported yet or may not be as fast as they
   should be. We are adding more, and tuning the ones we have, as we figure
   out the best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell

   $ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.
.. note::
   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::

   host1 slots=1
   host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, i.e.,
one process per host. The host file also needs to contain the current host if
you want to run on the local host. Passing the host file to ``mpirun`` is
simply done using the ``--hostfile`` command line argument.
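For example, assuming the host file above is saved as ``hosts.txt`` (the name is arbitrary):

.. code:: shell

   $ mpirun -np 2 --hostfile hosts.txt python test.py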
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.
.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss
Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX, but for now
the two main things one can do to extract the most out of distributed training
with MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each host to improve bandwidth, as shown in the launch
   command sketched below
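For instance, a launch command combining the host file from earlier with the TCP tuning flag (``train.py`` stands in for your training script):

.. code:: shell

   $ mpirun -np 2 --hostfile hosts.txt --mca btl_tcp_links 4 python train.py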

View File

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096)) ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys): def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])] return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition: Instead you can use :func:`vmap` to automatically vectorize the addition:
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
# Vectorize over the second dimension of x and the # Vectorize over the second dimension of x and the
# first dimension of y # first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0)) vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter can be used to specify which dimensions of the The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster. vectorized version takes only ``0.024`` seconds, more than 200 times faster.
Of course, this operation is quite contrived. A better approach is to simply do Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy. ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient. kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of *shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`. single input version of :func:`numpy.where`.

View File

@@ -109,7 +109,7 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`, :obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array. saving functions) will also evaluate the array.

View File

@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks Conversion to NumPy and Other Frameworks
======================================== ========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_. MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back. Let's convert an array to NumPy and back.
.. code-block:: python .. code-block:: python

View File

@@ -16,7 +16,7 @@ int main() {
std::cout << global_group.rank() << " / " << global_group.size() << std::endl; std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10}); array x = ones({10});
array out = distributed::all_reduce_sum(x, global_group); array out = distributed::all_sum(x, global_group);
std::cout << out << std::endl; std::cout << out << std::endl;
} }

View File

@@ -11,10 +11,14 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies ----------------------------- # ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
@@ -24,16 +28,10 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext) add_library(mlx_ext)
# Add sources # Add sources
target_sources( target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers # Add include headers
target_include_directories( target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx # Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx) target_link_libraries(mlx_ext PUBLIC mlx)
@@ -43,27 +41,32 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib # Build metallib
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
mlx_build_metallib( mlx_build_metallib(
TARGET mlx_ext_metallib TARGET
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib mlx_ext_metallib
) TITLE
mlx_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib)
endif() endif()
# ----------------------------- Python Bindings ----------------------------- # ----------------------------- Python Bindings -----------------------------
nanobind_add_module( nanobind_add_module(
_ext _ext
NB_STATIC STABLE_ABI LTO NOMINSIZE NB_STATIC
NB_DOMAIN mlx STABLE_ABI
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp LTO
) NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext) target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)

View File

@@ -249,16 +249,15 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_"); kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out); kname << type_to_name(out);
// Make sure the metal library is available and look for it // Make sure the metal library is available
// in the same folder as this executable if needed d.register_library("mlx_ext");
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to // Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal // those in the kernel declaration at axpby.metal
@@ -273,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode alpha and beta // Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3); compute_encoder.set_bytes(alpha_, 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4); compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim if needed // Encode shape, strides and ndim if needed
if (!contiguous_kernel) { if (!contiguous_kernel) {
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); compute_encoder.set_bytes(y.strides(), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8); compute_encoder.set_bytes(ndim, 8);
} }
// We launch 1 thread for each input and make sure that the number of // We launch 1 thread for each input and make sure that the number of
@@ -296,7 +295,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among // Launch the grid with the given number of threads divided among
// the given threadgroups // the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
#else // Metal is not available #else // Metal is not available

View File

@@ -2,7 +2,6 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T> template <typename T>

View File

@@ -2,7 +2,7 @@
requires = [ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.24", "cmake>=3.24",
"mlx>=0.9.0", "mlx>=0.18.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4", "nanobind==2.2.0",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.24 cmake>=3.24
mlx>=0.9.0 mlx>=0.21.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 nanobind==2.2.0

View File

@@ -13,7 +13,6 @@ if __name__ == "__main__":
cmdclass={"build_ext": extension.CMakeBuild}, cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"], packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False, zip_safe=False,
python_requires=">=3.8", python_requires=">=3.8",
) )

View File

@@ -28,10 +28,19 @@ endif()
if (@MLX_BUILD_METAL@) if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@) set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set_and_check(MLX_INCLUDE_DIRS set(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS} "${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
) )
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif() endif()
set_target_properties(mlx PROPERTIES set_target_properties(mlx PROPERTIES

View File

@@ -1,25 +1,24 @@
 target_sources(
   mlx
-  PRIVATE
-  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
-)
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
-if (MLX_BUILD_CPU)
+if(MLX_BUILD_CPU)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
@@ -27,17 +26,15 @@ endif()
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
-if (MLX_BUILD_ACCELERATE)
+if(MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
 elseif(MLX_BUILD_CPU)
   target_sources(
     mlx
-    PRIVATE
-    ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
-  )
+    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
 endif()
-if (MLX_BUILD_METAL)
+if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@@ -19,15 +19,26 @@ Buffer malloc(size_t size) {
 }

 void free(Buffer buffer) {
-  return allocator().free(buffer);
+  allocator().free(buffer);
 }

 Buffer CommonAllocator::malloc(size_t size, bool) {
-  return Buffer{std::malloc(size)};
+  void* ptr = std::malloc(size + sizeof(size_t));
+  if (ptr != nullptr) {
+    *static_cast<size_t*>(ptr) = size;
+  }
+  return Buffer{ptr};
 }

 void CommonAllocator::free(Buffer buffer) {
-  std::free(buffer.raw_ptr());
+  std::free(buffer.ptr());
+}
+
+size_t CommonAllocator::size(Buffer buffer) const {
+  if (buffer.ptr() == nullptr) {
+    return 0;
+  }
+  return *static_cast<size_t*>(buffer.ptr());
 }

 Buffer malloc_or_wait(size_t size) {
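
The new CommonAllocator above co-allocates a size_t header with every payload so that size(Buffer) can be answered without a side table. A minimal standalone sketch of the same trick, with hypothetical sized_* helpers rather than the MLX API (MLX keeps the header at offset zero behind Buffer instead of offsetting the returned pointer):

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Size-prefixed allocation: stash the requested size just before the pointer
// the caller sees, then walk back to read it or to free the block.
void* sized_malloc(std::size_t size) {
  void* base = std::malloc(size + sizeof(std::size_t));
  if (base == nullptr) {
    return nullptr;
  }
  *static_cast<std::size_t*>(base) = size; // header holds the payload size
  return static_cast<char*>(base) + sizeof(std::size_t);
}

std::size_t sized_size(void* ptr) {
  if (ptr == nullptr) {
    return 0;
  }
  return *reinterpret_cast<std::size_t*>(
      static_cast<char*>(ptr) - sizeof(std::size_t));
}

void sized_free(void* ptr) {
  if (ptr != nullptr) {
    std::free(static_cast<char*>(ptr) - sizeof(std::size_t)); // one block
  }
}

int main() {
  void* p = sized_malloc(128);
  std::printf("%zu\n", sized_size(p)); // prints 128
  sized_free(p);
}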

View File

@@ -41,6 +41,7 @@ class Allocator {
 public:
   virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
   virtual void free(Buffer buffer) = 0;
+  virtual size_t size(Buffer buffer) const = 0;

   Allocator() = default;
   Allocator(const Allocator& other) = delete;
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
 public:
   virtual Buffer malloc(size_t size, bool allow_swap = false) override;
   virtual void free(Buffer buffer) override;
+  virtual size_t size(Buffer buffer) const override;

 private:
   CommonAllocator() = default;

View File

@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.

 #include <functional>
+#include <unordered_map>

 #include "mlx/array.h"
 #include "mlx/ops.h"
@@ -17,6 +18,10 @@ bool in_tracing() {
   return detail::InTracing::in_tracing();
 }

+bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
+
 } // namespace

 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -91,18 +96,34 @@
   array_desc_->primitive = nullptr;
 }

-void array::eval() {
-  // Ensure the array is ready to be read
-  if (status() == Status::scheduled) {
-    event().wait();
-    set_status(Status::available);
-  } else if (status() == Status::unscheduled) {
-    mlx::core::eval({*this});
+bool array::is_available() const {
+  if (status() == Status::available) {
+    return true;
+  } else if (status() == Status::evaluated && event().is_signaled()) {
+    set_status(Status::available);
+    return true;
+  }
+  return false;
+}
+
+void array::wait() {
+  if (!is_available()) {
+    event().wait();
+    set_status(Status::available);
+  }
+}
+
+void array::eval() {
+  // Ensure the array is ready to be read
+  if (status() == Status::unscheduled) {
+    mlx::core::eval({*this});
+  } else {
+    wait();
   }
 }

 bool array::is_tracer() const {
-  return array_desc_->is_tracer && in_tracing();
+  return array_desc_->is_tracer && in_tracing() || retain_graph();
 }

 void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -158,8 +179,10 @@
   array_desc_->flags = flags;
   array_desc_->data_size = data_size;
   auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+  auto data_ptr = other.array_desc_->data_ptr;
+  other.array_desc_->data_ptr = nullptr;
+  array_desc_->data_ptr =
+      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
 }

 void array::move_shared_buffer(array other) {
@@ -171,10 +194,11 @@ array::~array() {
     return;
   }

-  // Ignore arrays that will be detached
-  if (status() != array::Status::unscheduled) {
+  // Ignore arrays that might be detached during eval
+  if (status() == array::Status::scheduled) {
     return;
   }

   // Break circular reference for non-detached arrays with siblings
   if (auto n = siblings().size(); n > 0) {
     bool do_detach = true;
@@ -191,6 +215,8 @@ array::~array() {
   if (do_detach) {
     for (auto& s : siblings()) {
       for (auto& ss : s.siblings()) {
+        // Set to null here to avoid descending into array destructor
+        // for siblings
         ss.array_desc_ = nullptr;
       }
       s.array_desc_->siblings.clear();
@@ -206,7 +232,7 @@ void array::ArrayDesc::init() {
     strides[i] = size;
     size *= shape[i];
   }
-  for (auto& in : inputs) {
+  for (const auto& in : inputs) {
     is_tracer |= in.is_tracer();
   }
 }
@@ -231,31 +257,52 @@
 array::ArrayDesc::~ArrayDesc() {
   // When an array description is destroyed it will delete a bunch of arrays
-  // that may also destory their corresponding descriptions and so on and so
+  // that may also destroy their corresponding descriptions and so on and so
   // forth.
   //
   // This calls recursively the destructor and can result in stack overflow, we
   // instead put them in a vector and destroy them one at a time resulting in a
   // max stack depth of 2.
+  if (inputs.empty()) {
+    return;
+  }
   std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
-  for (array& a : inputs) {
-    if (a.array_desc_.use_count() == 1) {
-      for_deletion.push_back(std::move(a.array_desc_));
+  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
+    std::unordered_map<std::uintptr_t, array> input_map;
+    for (array& a : ad.inputs) {
+      if (a.array_desc_) {
+        input_map.insert({a.id(), a});
+        for (auto& s : a.siblings()) {
+          input_map.insert({s.id(), s});
+        }
+      }
     }
-  }
+    ad.inputs.clear();
+    for (auto& [_, a] : input_map) {
+      if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
+        for_deletion.push_back(std::move(a.array_desc_));
+      }
+    }
+  };
+  append_deletable_inputs(*this);

   while (!for_deletion.empty()) {
     // top is going to be deleted at the end of the block *after* the arrays
     // with inputs have been moved into the vector
     auto top = std::move(for_deletion.back());
     for_deletion.pop_back();
+    append_deletable_inputs(*top);

-    for (array& a : top->inputs) {
-      if (a.array_desc_.use_count() == 1) {
-        for_deletion.push_back(std::move(a.array_desc_));
-      }
+    // Clear out possible siblings to break circular references
+    for (auto& s : top->siblings) {
+      // Set to null here to avoid descending into top-level
+      // array destructor for siblings
+      s.array_desc_ = nullptr;
     }
+    top->siblings.clear();
   }
 }
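
The rewritten ArrayDesc destructor above replaces recursion with a worklist so that freeing a long chain of arrays cannot overflow the stack, and it now also nulls out siblings to break the circular references created by multi-output primitives. A reduced sketch of just the worklist pattern on a toy node type (hypothetical Node, no siblings):

#include <memory>
#include <vector>

// Iterative teardown of a shared_ptr chain: any input whose refcount is about
// to reach zero is moved into a local vector and destroyed one at a time, so
// the destructor nesting depth stays at 2 no matter how long the chain is.
struct Node {
  std::vector<std::shared_ptr<Node>> inputs;
  ~Node() {
    std::vector<std::shared_ptr<Node>> for_deletion;
    auto take_sole_inputs = [&](Node& n) {
      for (auto& in : n.inputs) {
        if (in.use_count() == 1) {
          for_deletion.push_back(std::move(in)); // we were the only owner
        }
      }
      n.inputs.clear();
    };
    take_sole_inputs(*this);
    while (!for_deletion.empty()) {
      auto top = std::move(for_deletion.back());
      for_deletion.pop_back();
      take_sole_inputs(*top); // queue grandchildren before `top` dies
    } // each `top` is destroyed here with empty inputs, hence no recursion
  }
};

int main() {
  // A 100000-node chain would overflow the stack under naive recursive
  // destruction; with the worklist it tears down iteratively.
  auto head = std::make_shared<Node>();
  for (int i = 0; i < 100000; ++i) {
    auto next = std::make_shared<Node>();
    next->inputs.push_back(std::move(head));
    head = std::move(next);
  }
  head.reset();
}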

View File

@@ -73,32 +73,32 @@ class array {
     this->array_desc_ = other.array_desc_;
   }
   return *this;
-};
+}

 /** The size of the array's datatype in bytes. */
 size_t itemsize() const {
   return size_of(dtype());
-};
+}

 /** The number of elements in the array. */
 size_t size() const {
   return array_desc_->size;
-};
+}

 /** The number of bytes in the array. */
 size_t nbytes() const {
   return size() * itemsize();
-};
+}

 /** The number of dimensions of the array. */
 size_t ndim() const {
   return array_desc_->shape.size();
-};
+}

 /** The shape of the array as a vector of integers. */
 const std::vector<int>& shape() const {
   return array_desc_->shape;
-};
+}

 /**
  * Get the size of the corresponding dimension.
@@ -107,12 +107,12 @@ class array {
  * bounds checking. */
 int shape(int dim) const {
   return shape().at(dim < 0 ? dim + ndim() : dim);
-};
+}

 /** The strides of the array. */
 const std::vector<size_t>& strides() const {
   return array_desc_->strides;
-};
+}

 /**
  * Get the stride of the corresponding dimension.
@@ -121,12 +121,12 @@ class array {
  * bounds checking. */
 size_t strides(int dim) const {
   return strides().at(dim < 0 ? dim + ndim() : dim);
-};
+}

 /** Get the arrays data type. */
 Dtype dtype() const {
   return array_desc_->dtype;
-};
+}

 /** Evaluate the array. */
 void eval();
@@ -160,10 +160,10 @@ class array {
 friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
   return a.arr.id() == b.arr.id() && a.idx == b.idx;
-};
+}
 friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
   return !(a == b);
-};
+}

 private:
   const array& arr;
@@ -209,7 +209,7 @@ class array {
   allocator::Buffer buffer;
   deleter_t d;
   Data(allocator::Buffer buffer, deleter_t d = allocator::free)
-      : buffer(buffer), d(d) {};
+      : buffer(buffer), d(d) {}
   // Not copyable
   Data(const Data& d) = delete;
   Data& operator=(const Data& d) = delete;
@@ -219,33 +219,45 @@ class array {
 };

 struct Flags {
-    // True if there are no gaps in the underlying data. Each item
+    // True iff there are no gaps in the underlying data. Each item
     // in the underlying data buffer belongs to at least one index.
+    //
+    // True iff:
+    // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
     bool contiguous : 1;

+    // True iff:
+    // strides[-1] == 1 and
+    // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
+    //   range(ndim - 1))
     bool row_contiguous : 1;

+    // True iff:
+    // strides[0] == 1 and
+    // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
+    //   range(1, ndim))
     bool col_contiguous : 1;
 };
 /** The array's primitive. */
 Primitive& primitive() const {
   return *(array_desc_->primitive);
-};
+}

 /** A shared pointer to the array's primitive. */
 std::shared_ptr<Primitive>& primitive_ptr() const {
   return array_desc_->primitive;
-};
+}

 /** Check if the array has an attached primitive or is a leaf node. */
 bool has_primitive() const {
   return array_desc_->primitive != nullptr;
-};
+}

 /** The array's inputs. */
 const std::vector<array>& inputs() const {
   return array_desc_->inputs;
-};
+}

 std::vector<array>& inputs() {
   return array_desc_->inputs;
@@ -259,12 +271,12 @@ class array {
 /** The array's siblings. */
 const std::vector<array>& siblings() const {
   return array_desc_->siblings;
-};
+}

 /** The array's siblings. */
 std::vector<array>& siblings() {
   return array_desc_->siblings;
-};
+}

 void set_siblings(std::vector<array> siblings, uint16_t position) {
   array_desc_->siblings = std::move(siblings);
@@ -281,7 +293,7 @@ class array {
   outputs.push_back(*this);
   outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
   return outputs;
-};
+}

 /** Detach the array from the graph. */
 void detach();
@@ -289,19 +301,32 @@ class array {
 /** Get the Flags bit-field. */
 const Flags& flags() const {
   return array_desc_->flags;
-};
+}

-/** The size (in elements) of the underlying buffer the array points to. */
+/** The size (in elements) of the underlying buffer the array points to.
+ *
+ * This can be different than the actual size of the array if the array has
+ * been broadcast or irregularly strided. If ``first`` is the offset into
+ * the data buffer of the first element of the array (i.e. the offset
+ * corresponding to ``arr[0, 0, ...]``) and last is the offset into the
+ * data buffer of the last element of the array (i.e. the offset
+ * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
+ * Note, ``data_size`` is in units of ``item_size`` (not bytes).
+ **/
 size_t data_size() const {
   return array_desc_->data_size;
-};
+}

 allocator::Buffer& buffer() {
   return array_desc_->data->buffer;
-};
+}
 const allocator::Buffer& buffer() const {
   return array_desc_->data->buffer;
-};
+}

+size_t buffer_size() const {
+  return allocator::allocator().size(buffer());
+}

 // Return a copy of the shared pointer
 // to the array::Data struct
@@ -312,19 +337,42 @@ class array {
 template <typename T>
 T* data() {
   return static_cast<T*>(array_desc_->data_ptr);
-};
+}

 template <typename T>
 const T* data() const {
   return static_cast<T*>(array_desc_->data_ptr);
-};
+}

-enum Status { unscheduled, scheduled, available };
+enum Status {
+  // The output of a computation which has not been scheduled.
+  // For example, the status of `x` in `auto x = a + b`.
+  unscheduled,
+
+  // The output of a computation which has been scheduled but `eval_*` has
+  // not yet been called on the array's primitive. A possible
+  // status of `x` in `auto x = a + b; eval(x);`
+  scheduled,
+
+  // The array's `eval_*` function has been run, but the computation is not
+  // necessarily complete. The array will have memory allocated and if it is
+  // not a tracer then it will be detached from the graph.
+  evaluated,
+
+  // If the array is the output of a computation then the computation
+  // is complete. Constant arrays are always available (e.g. `array({1, 2,
+  // 3})`)
+  available
+};

-bool is_available() const {
-  return status() == Status::available;
-}
+// Check if the array is safe to read.
+bool is_available() const;
+
+// Wait on the array to be available. After this `is_available` returns
+// `true`.
+void wait();

-const Status status() const {
+Status status() const {
   return array_desc_->status;
 }
@@ -411,8 +459,6 @@ class array {
 void* data_ptr{nullptr};

 // The size in elements of the data buffer the array accesses
-// This can be different than the actual size of the array if it
-// has been broadcast or irregularly strided.
 size_t data_size;

 // Contains useful meta data about the array
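
The expanded Flags comments pin the three contiguity bits down to concrete predicates over shape and strides. A standalone checker for the row and column predicates exactly as documented above (my own helper names, for illustration only):

#include <cstddef>
#include <vector>

// row contiguous: strides[-1] == 1 and each stride equals the product of the
// next shape/stride pair unless that axis has size 1.
bool is_row_contiguous(
    const std::vector<int>& shape,
    const std::vector<std::size_t>& strides) {
  int ndim = shape.size();
  if (ndim == 0) {
    return true;
  }
  if (strides[ndim - 1] != 1) {
    return false;
  }
  for (int i = 0; i < ndim - 1; ++i) {
    if (strides[i] != shape[i + 1] * strides[i + 1] && shape[i] != 1) {
      return false;
    }
  }
  return true;
}

// col contiguous is the mirror image: strides[0] == 1 and each stride equals
// the product of the previous shape/stride pair unless that axis has size 1.
bool is_col_contiguous(
    const std::vector<int>& shape,
    const std::vector<std::size_t>& strides) {
  int ndim = shape.size();
  if (ndim == 0) {
    return true;
  }
  if (strides[0] != 1) {
    return false;
  }
  for (int i = 1; i < ndim; ++i) {
    if (strides[i] != shape[i - 1] * strides[i - 1] && shape[i] != 1) {
      return false;
    }
  }
  return true;
}

// e.g. shape {2, 3} with strides {3, 1} is row contiguous, while its
// transpose, shape {3, 2} with strides {1, 3}, is column contiguous.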

View File

@@ -1,10 +1,8 @@
 target_sources(
   mlx
-  PRIVATE
-  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
-)
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

View File

@@ -1,9 +1,9 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #include <cassert>

+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>

 #include "mlx/backend/common/copy.h"
 #include "mlx/primitives.h"

View File

@@ -2,8 +2,7 @@
 #include <cassert>

-#include <vecLib/BNNS/bnns.h>
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>

 #include "mlx/backend/accelerate/utils.h"
 #include "mlx/backend/common/copy.h"

View File

@@ -3,8 +3,7 @@
 #include <cassert>
 #include <cmath>

-#include <vecLib/vDSP.h>
-#include <vecLib/vForce.h>
+#include <Accelerate/Accelerate.h>

 #include "mlx/allocator.h"
 #include "mlx/backend/common/binary.h"
@@ -37,7 +36,7 @@ DEFAULT(Ceil)
 DEFAULT(Concatenate)
 DEFAULT(Conjugate)
 DEFAULT(Copy)
-DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(CustomTransforms)
 DEFAULT_MULTI(Depends)
 DEFAULT_MULTI(DivMod)
 DEFAULT(NumberOfElements)
@@ -51,6 +50,7 @@ DEFAULT(GatherMM)
 DEFAULT(GatherQMM)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
+DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Eigh)

 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
@@ -102,7 +103,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -117,7 +118,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -132,7 +133,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x + y; });
+    eval(inputs, out);
   }
 }
@@ -287,7 +288,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -300,7 +301,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
         });
   } else if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -315,7 +316,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x / y; });
+    eval(inputs, out);
   }
 }
@@ -326,12 +327,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     auto size = in.data_size();
     vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::exp(x); });
   } else {
-    throw std::invalid_argument(
-        "[exp] Cannot exponentiate elements in array"
-        " with non floating point type.");
+    eval(inputs, out);
   }
 }
@@ -393,12 +390,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vvlog1pf(
         out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::log1p(x); });
   } else {
-    throw std::invalid_argument(
-        "[log1p] Cannot compute log of elements in array with"
-        " non floating point type.");
+    eval(inputs, out);
   }
 }
@@ -408,7 +401,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -423,7 +416,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x * y; });
+    eval(inputs, out);
   }
 }
@@ -434,7 +427,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else {
-    unary(in, out, [](auto x) { return -x; });
+    eval(inputs, out);
   }
 }
@@ -521,7 +514,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
   } else {
-    unary(in, out, [](auto x) { return x * x; });
+    eval(inputs, out);
  }
 }
@@ -547,7 +540,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -565,7 +558,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -577,7 +570,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
         },
         UseDefaultBinaryOp());
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x - y; });
+    eval(inputs, out);
   }
 }

View File

@@ -18,49 +18,61 @@ void _qmm_t_4_64(
     const float* biases,
     int M,
     int N,
-    int K) {
+    int K,
+    int B,
+    bool batched_w) {
   constexpr int bits = 4;
   constexpr int group_size = 64;
   constexpr int bitmask = (1 << bits) - 1;
   constexpr int pack_factor = 32 / bits;
   constexpr int packs_in_group = group_size / pack_factor;
+  int w_els = N * K / pack_factor;
+  int g_els = w_els * pack_factor / group_size;

+  for (int i = 0; i < B; i++) {
     for (int m = 0; m < M; m++) {
       const uint32_t* w_local = w;
       const float* scales_local = scales;
       const float* biases_local = biases;

       for (int n = 0; n < N; n++) {
         const simd_float16* x_local = (simd_float16*)x;
         simd_float16 sum = 0;
         for (int k = 0; k < K; k += group_size) {
           float scale = *scales_local++;
           float bias = *biases_local++;

           for (int kw = 0; kw < packs_in_group; kw += 2) {
             // TODO: vectorize this properly
             simd_uint16 wi;
             for (int e = 0; e < 2; e++) {
               uint32_t wii = *w_local++;
               for (int p = 0; p < 8; p++) {
                 wi[e * 8 + p] = wii & bitmask;
                 wii >>= bits;
               }
             }
             simd_float16 wf = simd_float(wi);
             wf *= scale;
             wf += bias;

             sum += (*x_local) * wf;
             x_local++;
           }
         }
         *result = simd_reduce_add(sum);
         result++;
       }
       x += K;
     }
+    if (batched_w) {
+      w += w_els;
+      scales += g_els;
+      biases += g_els;
+    }
+  }
 }
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (condition) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
-    int M = x.size() / K;
+    int M = x.shape(-2);
     int N = out.shape(-1);
+    int B = x.size() / K / M;
+    bool batched_w = w.ndim() > 2;
     _qmm_t_4_64(
         out.data<float>(),
         x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
         biases.data<float>(),
         M,
         N,
-        K);
+        K,
+        B,
+        batched_w);
   } else {
     eval(inputs, out);
   }
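
For reference, the layout _qmm_t_4_64 consumes is: every uint32_t packs eight 4-bit weights, and each group of 64 weights along K shares one scale and one bias, so a weight dequantizes as w_q * scale + bias. A scalar model of the per-output dot product (hypothetical qdot_4_64; assumes K is a multiple of 64, where the real kernel processes 16 lanes at a time with simd_float16):

#include <cstdint>

// Scalar reference for one output element: unpack nibbles, dequantize with
// the group's scale/bias, and accumulate against the activations.
float qdot_4_64(
    const float* x, // K activations
    const std::uint32_t* w, // K / 8 packed words for one output column
    const float* scales, // K / 64 group scales
    const float* biases, // K / 64 group biases
    int K) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr std::uint32_t bitmask = (1 << bits) - 1;
  float sum = 0.0f;
  for (int k = 0; k < K; k += group_size) {
    float scale = *scales++;
    float bias = *biases++;
    for (int kw = 0; kw < group_size; kw += 8) {
      std::uint32_t wi = *w++; // eight packed 4-bit values
      for (int p = 0; p < 8; p++) {
        float wf = float(wi & bitmask) * scale + bias;
        sum += x[k + kw + p] * wf;
        wi >>= bits;
      }
    }
  }
  return sum;
}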

View File

@@ -2,8 +2,8 @@
 #include <cassert>

+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>

 #include "mlx/backend/common/reduce.h"
 #include "mlx/primitives.h"

View File

@@ -3,7 +3,10 @@
 #include <cassert>
 #include <limits>

+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 #include <arm_neon.h>
+#endif

 #include <simd/math.h>
 #include <simd/vector.h>
@@ -30,8 +33,8 @@ namespace {
  * Note: The implementation below is a general fast exp. There could be faster
  * implementations for numbers strictly < 0.
  */
-inline simd_float16 simd_fast_exp(simd_float16 x) {
-  x *= 1.442695; // multiply with log_2(e)
+inline simd_float16 simd_fast_exp(simd_float16 x_init) {
+  auto x = x_init * 1.442695; // multiply with log_2(e)
   simd_float16 ipart, fpart;
   simd_int16 epart;
   x = simd_clamp(x, -80, 80);
@@ -50,28 +53,30 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
   // bitshifting
   epart = (simd_int(ipart) + 127) << 23;

-  return (*(simd_float16*)&epart) * x;
+  // Avoid suppressing NaNs
+  simd_int16 eq = (x_init == x_init);
+  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
 }

+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
 /**
  * The ARM neon equivalent of the fast exp above.
  */
 inline float16x8_t neon_fast_exp(float16x8_t x) {
-  x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
-  x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
-  x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
+  x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
+  x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
+  x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14

-  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
+  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
   float16x8_t fpart = vsubq_f16(x, ipart);

-  x = vdupq_n_f16(1.535336188319500e-4f);
-  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
+  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
+  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);

   // generate 2**ipart in the floating point representation using integer
   // bitshifting
@@ -107,53 +112,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
   return vget_lane_f16(y, 0);
 }

-template <typename T, typename VT>
-struct AccelerateSimdOps {
-  VT init(T a) {
-    return a;
-  }
-  VT load(const T* a) {
-    return *(VT*)a;
-  }
-  void store(T* dst, VT x) {
-    *(VT*)dst = x;
-  }
-  VT max(VT a, VT b) {
-    return simd_max(a, b);
-  };
-  VT exp(VT x) {
-    return simd_fast_exp(x);
-  }
-  VT add(VT a, VT b) {
-    return a + b;
-  }
-  VT sub(VT a, T b) {
-    return a - b;
-  }
-  VT mul(VT a, VT b) {
-    return a * b;
-  }
-  VT mul(VT a, T b) {
-    return a * b;
-  }
-  T reduce_max(VT x) {
-    return simd_reduce_max(x);
-  }
-  T reduce_add(VT x) {
-    return simd_reduce_add(x);
-  }
-};
-
 template <typename T, typename VT>
 struct NeonFp16SimdOps {
   VT init(T a) {
@@ -170,7 +128,7 @@ struct NeonFp16SimdOps {
   VT max(VT a, VT b) {
     return vmaxq_f16(a, b);
-  };
+  }

   VT exp(VT x) {
     return neon_fast_exp(x);
@@ -201,6 +159,55 @@ struct NeonFp16SimdOps {
   }
 };

+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+template <typename T, typename VT>
+struct AccelerateSimdOps {
+  VT init(T a) {
+    return a;
+  }
+  VT load(const T* a) {
+    return *(VT*)a;
+  }
+  void store(T* dst, VT x) {
+    *(VT*)dst = x;
+  }
+  VT max(VT a, VT b) {
+    return simd_max(a, b);
+  }
+  VT exp(VT x) {
+    return simd_fast_exp(x);
+  }
+  VT add(VT a, VT b) {
+    return a + b;
+  }
+  VT sub(VT a, T b) {
+    return a - b;
+  }
+  VT mul(VT a, VT b) {
+    return a * b;
+  }
+  VT mul(VT a, T b) {
+    return a * b;
+  }
+  T reduce_max(VT x) {
+    return simd_reduce_max(x);
+  }
+  T reduce_add(VT x) {
+    return simd_reduce_add(x);
+  }
+};
+
 template <typename T, typename AccT, typename VT, typename Ops, int N>
 void softmax(const array& in, array& out) {
   Ops ops;
@@ -362,12 +369,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
           AccelerateSimdOps<float, simd_float16>,
           16>(in, out);
     } else {
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
       softmax<
           float16_t,
           float16_t,
           float16x8_t,
           NeonFp16SimdOps<float16_t, float16x8_t>,
           8>(in, out);
+#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+      eval(inputs, out); // Redirect to common backend for consistency
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     }
     break;
   case bfloat16:
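
Both fast-exp kernels above use the same decomposition: exp(x) = 2^(x * log2(e)), split into an integer part i, realized by writing i + 127 into the IEEE-754 exponent field, and a fractional part f, approximated with a short polynomial; the simd_bitselect change keeps NaN inputs from being clamped into finite outputs. A scalar model using the kernel's coefficients (illustrative only, not part of MLX):

#include <cmath>
#include <cstdint>
#include <cstring>

float fast_exp(float x_init) {
  float x = x_init * 1.442695f; // log2(e)
  if (x > 80.f) x = 80.f; // clamp as the SIMD version does
  if (x < -80.f) x = -80.f;
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;
  // Polynomial approximation of 2^fpart (same coefficients as the kernel)
  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;
  // 2^ipart by placing (ipart + 127) into the float exponent bits
  std::uint32_t ebits = std::uint32_t(std::int32_t(ipart) + 127) << 23;
  float e2;
  std::memcpy(&e2, &ebits, sizeof(e2));
  // Propagate NaN instead of returning a clamped finite value
  return (x_init == x_init) ? e2 * p : x_init;
}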

View File

@@ -1,8 +1,8 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #pragma once

-#include <vecLib/BNNS/bnns.h>
+#include <Accelerate/Accelerate.h>

 #include "mlx/dtype.h"

 namespace mlx::core {

View File

@@ -1,5 +1,4 @@
-if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   set(COMPILER ${CMAKE_C_COMPILER})
   set(CLANG TRUE)
 else()
@@ -7,69 +6,57 @@ endif()

 add_custom_command(
   OUTPUT compiled_preamble.cpp
-  COMMAND /bin/bash
-  ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
-  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
-  ${COMPILER}
-  ${PROJECT_SOURCE_DIR}
-  ${CLANG}
-  DEPENDS make_compiled_preamble.sh
-  compiled_preamble.h
-  ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
-  ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
-  ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
-  ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
-  ops.h
-)
+  COMMAND
+    /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
+    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
+    ${PROJECT_SOURCE_DIR} ${CLANG}
+  DEPENDS make_compiled_preamble.sh
+          compiled_preamble.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
+          ops.h)

-add_custom_target(
-  cpu_compiled_preamble
-  DEPENDS compiled_preamble.cpp
-)
+add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)

 add_dependencies(mlx cpu_compiled_preamble)

 target_sources(
   mlx
-  PRIVATE
-  ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
-  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
-)
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
+          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

-if (IOS)
-  target_sources(
-    mlx
-    PRIVATE
-    ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
-  )
+if(IOS)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
 else()
-  target_sources(
-    mlx
-    PRIVATE
-    ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
-  )
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
 endif()

View File

@@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
   }
 }

+void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, detail::LogicalAnd());
+}
+
+void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalOr requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, detail::LogicalOr());
+}
+
 void Maximum::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
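
These two additions follow the backend's usual shape: the primitive's eval only chooses a functor and the shared binary helper does the iteration. A toy model of that dispatch (toy binary and functors, not the MLX machinery):

#include <cassert>
#include <cstddef>
#include <vector>

struct LogicalAndOp {
  bool operator()(bool a, bool b) const { return a && b; }
};
struct LogicalOrOp {
  bool operator()(bool a, bool b) const { return a || b; }
};

// One templated loop serves every elementwise binary primitive.
template <typename T, typename Op>
std::vector<T> binary(const std::vector<T>& a, const std::vector<T>& b, Op op) {
  assert(a.size() == b.size());
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]);
  }
  return out;
}

// Usage: auto r = binary(x, y, LogicalAndOp{});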

View File

@@ -43,13 +43,15 @@ void set_binary_op_output_data(
     array& out,
     BinaryOpType bopt,
     bool donate_with_move = false) {
+  bool b_donatable = is_donatable(b, out);
+  bool a_donatable = is_donatable(a, out);
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
-      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+      if (b_donatable) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
         } else {
@@ -64,7 +66,7 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::VectorScalar:
-      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+      if (a_donatable) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
         } else {
@@ -79,13 +81,13 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::VectorVector:
-      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
+      if (a_donatable) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
         } else {
           out.copy_shared_buffer(a);
         }
-      } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
+      } else if (b_donatable) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
         } else {
@@ -100,16 +102,14 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::General:
-      if (a.is_donatable() && a.flags().row_contiguous &&
-          a.itemsize() == out.itemsize() && a.size() == out.size()) {
+      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
         } else {
           out.copy_shared_buffer(a);
         }
       } else if (
-          b.is_donatable() && b.flags().row_contiguous &&
-          b.itemsize() == out.itemsize() && b.size() == out.size()) {
+          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
         } else {
@@ -122,19 +122,7 @@ void set_binary_op_output_data(
       }
   }
 }
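
Hoisting the donation tests into a_donatable and b_donatable centralizes the question of whether the output may reuse an input's buffer. A toy model of such a predicate (my own ToyArray; the real is_donatable lives in MLX's utilities and plausibly also leans on the new buffer_size() accessor from the allocator change):

#include <cstddef>

struct ToyArray {
  long refcount; // owners of the underlying data
  std::size_t buffer_bytes; // capacity of the allocation
  std::size_t nbytes; // bytes this view needs
};

// An input can donate when nothing else references its data and its buffer
// is large enough to hold the output.
bool is_donatable(const ToyArray& in, const ToyArray& out) {
  return in.refcount == 1 && in.buffer_bytes >= out.nbytes;
}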
-struct UseDefaultBinaryOp {
-  template <typename T, typename U>
-  void operator()(const T* a, const T* b, U* dst, int size) {
-    // Should we throw? This should normally never be called.
-    assert(false);
-  }
-
-  template <typename T, typename U>
-  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
-    // Should we throw? This should normally never be called.
-    assert(false);
-  }
-};
+struct UseDefaultBinaryOp {};

 template <typename T, typename U, typename Op>
 struct DefaultVectorScalar {
@@ -150,18 +138,6 @@ struct DefaultVectorScalar {
       a++;
     }
   }
-
-  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
-    T scalar = *b;
-    while (size-- > 0) {
-      auto dst = op(*a, scalar);
-      *dst_a = dst.first;
-      *dst_b = dst.second;
-      dst_a++;
-      dst_b++;
-      a++;
-    }
-  }
 };

 template <typename T, typename U, typename Op>
@@ -178,18 +154,6 @@ struct DefaultScalarVector {
       b++;
     }
   }
-
-  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
-    T scalar = *a;
-    while (size-- > 0) {
-      auto dst = op(scalar, *b);
-      *dst_a = dst.first;
-      *dst_b = dst.second;
-      dst_a++;
-      dst_b++;
-      b++;
-    }
-  }
 };

 template <typename T, typename U, typename Op>
@@ -206,204 +170,110 @@ struct DefaultVectorVector {
       b++;
     }
   }
-
-  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
-    while (size-- > 0) {
-      auto dst = op(*a, *b);
-      *dst_a = dst.first;
-      *dst_b = dst.second;
-      dst_a++;
-      dst_b++;
-      a++;
-      b++;
-    }
-  }
 };
-template <typename T, typename U, typename Op>
-void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < out.size(); ++i) {
-    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
-    a_idx += a.strides()[0];
-    b_idx += b.strides()[0];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims1(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    int stride) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; i++) {
-    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
-    a_idx += a.strides()[0];
-    b_idx += b.strides()[0];
-    dst += stride;
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
-      a_idx += a.strides()[1];
-      b_idx += b.strides()[1];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims2(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    int stride) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
-      a_idx += a.strides()[1];
-      b_idx += b.strides()[1];
-      dst += stride;
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      for (size_t k = 0; k < a.shape()[2]; ++k) {
-        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
-        a_idx += a.strides()[2];
-        b_idx += b.strides()[2];
-      }
-      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
-      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      for (size_t k = 0; k < a.shape()[2]; ++k) {
-        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
-          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
-          a_idx += a.strides()[3];
-          b_idx += b.strides()[3];
-        }
-        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
-        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
-      }
-      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
-      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dispatch_dims(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op) {
-  switch (out.ndim()) {
-    case 1:
-      binary_op_dims1<T, U, Op>(a, b, out, op);
-      return;
-    case 2:
-      binary_op_dims2<T, U, Op>(a, b, out, op);
-      return;
-    case 3:
-      binary_op_dims3<T, U, Op>(a, b, out, op);
-      return;
-    case 4:
-      binary_op_dims4<T, U, Op>(a, b, out, op);
-      return;
-  }
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  for (size_t i = 0; i < out.size(); i++) {
-    int a_idx = elem_to_loc(i, a.shape(), a.strides());
-    int b_idx = elem_to_loc(i, b.shape(), b.strides());
-    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dispatch_dims(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    int dim,
-    int stride) {
-  // Number of dimensions to loop over for vectorized ops
-  switch (dim) {
-    case 1:
-      binary_op_dims1<T, U, Op>(a, b, out, op, stride);
-      return;
-    case 2:
-      binary_op_dims2<T, U, Op>(a, b, out, op, stride);
-      return;
-  }
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst = out.data<U>();
-  for (size_t i = 0; i < out.size(); i += stride) {
-    int a_idx = elem_to_loc(i, a.shape(), a.strides());
-    int b_idx = elem_to_loc(i, b.shape(), b.strides());
-    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
-    dst += stride;
-  }
-}
+template <typename T, typename U, typename Op, int D, bool Strided>
+void binary_op_dims(
+    const T* a,
+    const T* b,
+    U* out,
+    Op op,
+    const std::vector<int>& shape,
+    const std::vector<size_t>& a_strides,
+    const std::vector<size_t>& b_strides,
+    const std::vector<size_t>& out_strides,
+    int axis) {
+  auto stride_a = a_strides[axis];
+  auto stride_b = b_strides[axis];
+  auto stride_out = out_strides[axis];
+  auto N = shape[axis];
+
+  for (int i = 0; i < N; i++) {
+    if constexpr (D > 1) {
+      binary_op_dims<T, U, Op, D - 1, Strided>(
+          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
+    } else {
+      if constexpr (Strided) {
+        op(a, b, out, stride_out);
+      } else {
+        *out = op(*a, *b);
+      }
+    }
+    out += stride_out;
+    a += stride_a;
+    b += stride_b;
+  }
+}
+
+template <typename T, typename U, bool Strided, typename Op>
+void binary_op_dispatch_dims(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    int dim,
+    const std::vector<int>& shape,
+    const std::vector<size_t>& a_strides,
+    const std::vector<size_t>& b_strides,
+    const std::vector<size_t>& out_strides) {
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* out_ptr = out.data<U>();
+  switch (dim) {
+    case 1:
+      binary_op_dims<T, U, Op, 1, Strided>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+    case 2:
+      binary_op_dims<T, U, Op, 2, Strided>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+    case 3:
+      binary_op_dims<T, U, Op, 3, Strided>(
+          a_ptr,
+          b_ptr,
+          out_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+  }
+
+  ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
+  ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
+  size_t stride = out_strides[dim - 4];
+  for (size_t elem = 0; elem < a.size(); elem += stride) {
+    binary_op_dims<T, U, Op, 3, Strided>(
+        a_ptr + a_it.loc,
+        b_ptr + b_it.loc,
+        out_ptr + elem,
+        op,
+        shape,
+        a_strides,
+        b_strides,
+        out_strides,
+        dim - 3);
+    a_it.step();
+    b_it.step();
+  }
+}
@@ -450,29 +320,33 @@ void binary_op(
} }
   // General computation so let's try to optimize
+  auto [new_shape, new_strides] = collapse_contiguous_dims(
+      a.shape(), {a.strides(), b.strides(), out.strides()});
+  const auto& a_strides = new_strides[0];
+  const auto& b_strides = new_strides[1];
+  const auto& strides = new_strides[2];

   // Get the left-most dim such that the array is row contiguous after
-  auto& strides = out.strides();
-  auto leftmost_rc_dim = [&strides](const array& arr) {
-    int d = arr.ndim() - 1;
-    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
+  auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
+    int d = arr_strides.size() - 1;
+    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
     }
     return d + 1;
   };
-  auto a_rc_dim = leftmost_rc_dim(a);
-  auto b_rc_dim = leftmost_rc_dim(b);
+  auto a_rc_dim = leftmost_rc_dim(a_strides);
+  auto b_rc_dim = leftmost_rc_dim(b_strides);

   // Get the left-most dim such that the array is a broadcasted "scalar" after
-  auto leftmost_s_dim = [](const array& arr) {
-    int d = arr.ndim() - 1;
-    for (; d >= 0 && arr.strides()[d] == 0; d--) {
+  auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
+    int d = arr_strides.size() - 1;
+    for (; d >= 0 && arr_strides[d] == 0; d--) {
     }
     return d + 1;
   };
-  auto a_s_dim = leftmost_s_dim(a);
-  auto b_s_dim = leftmost_s_dim(b);
+  auto a_s_dim = leftmost_s_dim(a_strides);
+  auto b_s_dim = leftmost_s_dim(b_strides);

-  auto ndim = out.ndim();
+  auto ndim = new_shape.size();

   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
@@ -494,27 +368,27 @@ void binary_op(
   // Can be sure dim > 0 since otherwise we would have used one of the fully
   // contiguous methods above. Except for the case that the flags do not
   // correspond to the underlying contiguity.
-  size_t stride;
   if (dim == 0 || strides[dim - 1] < 16) {
-    stride = 1;
     bopt = BinaryOpType::General;
     dim = ndim;
-  } else {
-    stride = strides[dim - 1];
   }

   switch (bopt) {
     case BinaryOpType::VectorVector:
-      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
+      binary_op_dispatch_dims<T, U, true>(
+          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
       break;
     case BinaryOpType::VectorScalar:
-      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
+      binary_op_dispatch_dims<T, U, true>(
+          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::ScalarVector:
-      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
+      binary_op_dispatch_dims<T, U, true>(
+          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
       break;
     default:
-      binary_op_dispatch_dims<T, U>(a, b, out, op);
+      binary_op_dispatch_dims<T, U, false>(
+          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
       break;
   }
 }
@@ -531,9 +405,9 @@ void binary_op(
   // TODO: The following mess of constexpr evaluations can probably be achieved
   // with template specializations and overloading. Would it be simpler?
-  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
-    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
-      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+  if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
+    if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+      if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
         // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
         binary_op<T, T>(
             a,
@@ -554,7 +428,8 @@ void binary_op(
             DefaultVectorScalar<T, T, Op>(op),
             opvv);
       }
-    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+    } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
+                             value) {
       // opsv and opvv were UseDefaultBinaryOp
       binary_op<T, T>(
           a,
@@ -569,7 +444,8 @@ void binary_op(
       binary_op<T, T>(
           a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
     }
-  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
+  } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
+                           value) {
     if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
       // opvs and opvv were UseDefaultBinaryOp
       binary_op<T, T>(
@@ -585,7 +461,8 @@ void binary_op(
       binary_op<T, T>(
           a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
     }
-  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
+  } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
+                           value) {
     // opvv was UseDefaultBinaryOp
     binary_op<T, T>(
         a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
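
Note on the `if constexpr` change above: with a runtime `if`, every branch must still compile for every instantiation, even branches that can never be taken for a given set of op types; `if constexpr` discards the untaken branch at instantiation time. A minimal self-contained illustration (not MLX code):

#include <string>
#include <type_traits>

template <typename T>
auto describe(const T& v) {
  if constexpr (std::is_same<T, std::string>::value) {
    return v.size(); // would fail to compile for int under a runtime `if`
  } else {
    return v;
  }
}

int main() {
  describe(std::string("abc")); // returns 3
  describe(7); // returns 7
}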


@@ -9,168 +9,43 @@ namespace mlx::core {
 namespace {

-template <typename T, typename U, typename Op>
-void binary_op_dims1(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < out_a.size(); ++i) {
-    auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
-    dst_a[i] = dst.first;
-    dst_b[i] = dst.second;
-    a_idx += a.strides()[0];
-    b_idx += b.strides()[0];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims1(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op,
-    int stride) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; i++) {
-    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
-    a_idx += a.strides()[0];
-    b_idx += b.strides()[0];
-    dst_a += stride;
-    dst_b += stride;
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims2(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
-      dst_a[out_idx] = dst.first;
-      dst_b[out_idx++] = dst.second;
-      a_idx += a.strides()[1];
-      b_idx += b.strides()[1];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims2(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op,
-    int stride) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
-      a_idx += a.strides()[1];
-      b_idx += b.strides()[1];
-      dst_a += stride;
-      dst_b += stride;
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims3(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      for (size_t k = 0; k < a.shape()[2]; ++k) {
-        auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
-        dst_a[out_idx] = dst.first;
-        dst_b[out_idx++] = dst.second;
-        a_idx += a.strides()[2];
-        b_idx += b.strides()[2];
-      }
-      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
-      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
-
-template <typename T, typename U, typename Op>
-void binary_op_dims4(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op) {
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  size_t a_idx = 0;
-  size_t b_idx = 0;
-  size_t out_idx = 0;
-  for (size_t i = 0; i < a.shape()[0]; ++i) {
-    for (size_t j = 0; j < a.shape()[1]; ++j) {
-      for (size_t k = 0; k < a.shape()[2]; ++k) {
-        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
-          auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
-          dst_a[out_idx] = dst.first;
-          dst_b[out_idx++] = dst.second;
-          a_idx += a.strides()[3];
-          b_idx += b.strides()[3];
-        }
-        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
-        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
-      }
-      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
-      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
-    }
-    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
-    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
-  }
-}
+template <typename T, typename U, typename Op, int D>
+void binary_op_dims(
+    const T* a,
+    const T* b,
+    U* out_a,
+    U* out_b,
+    Op op,
+    const std::vector<int>& shape,
+    const std::vector<size_t>& a_strides,
+    const std::vector<size_t>& b_strides,
+    const std::vector<size_t>& out_strides,
+    int axis) {
+  auto stride_a = a_strides[axis];
+  auto stride_b = b_strides[axis];
+  auto stride_out = out_strides[axis];
+  auto N = shape[axis];
+
+  for (int i = 0; i < N; i++) {
+    if constexpr (D > 1) {
+      binary_op_dims<T, U, Op, D - 1>(
+          a,
+          b,
+          out_a,
+          out_b,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          axis + 1);
+    } else {
+      std::tie(*out_a, *out_b) = op(*a, *b);
+    }
+    a += stride_a;
+    b += stride_b;
+    out_a += stride_out;
+    out_b += stride_out;
+  }
+}
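
Note: the new `binary_op_dims` pins the loop-nest depth with the compile-time parameter `D`, so the recursion fully inlines into nested loops, while `axis` walks the shape/stride vectors at runtime. A minimal self-contained sketch of the same pattern (illustrative only, not MLX code):

#include <cstdio>
#include <vector>

template <int D>
void visit(const float* x, const std::vector<int>& shape,
           const std::vector<size_t>& strides, int axis, float& acc) {
  for (int i = 0; i < shape[axis]; i++) {
    if constexpr (D > 1) {
      visit<D - 1>(x, shape, strides, axis + 1, acc); // inner loop level
    } else {
      acc += *x; // innermost element visit
    }
    x += strides[axis];
  }
}

int main() {
  float data[6] = {1, 2, 3, 4, 5, 6};
  float acc = 0;
  visit<2>(data, {2, 3}, {3, 1}, 0, acc); // row-major 2x3 traversal
  std::printf("%g\n", acc); // 21
}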
@@ -181,352 +56,160 @@ void binary_op_dispatch_dims(
     array& out_a,
     array& out_b,
     Op op) {
-  switch (out_a.ndim()) {
-    case 1:
-      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
-      return;
-    case 2:
-      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
-      return;
-    case 3:
-      binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
-      return;
-    case 4:
-      binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
-      return;
-  }
-
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  for (size_t i = 0; i < out_a.size(); i++) {
-    int a_idx = elem_to_loc(i, a.shape(), a.strides());
-    int b_idx = elem_to_loc(i, b.shape(), b.strides());
-    std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
-  }
+  auto [shape, strides] = collapse_contiguous_dims(
+      a.shape(), {a.strides(), b.strides(), out_a.strides()});
+  const auto& a_strides = strides[0];
+  const auto& b_strides = strides[1];
+  const auto& out_strides = strides[2];
+
+  const T* a_ptr = a.data<T>();
+  const T* b_ptr = b.data<T>();
+  U* out_a_ptr = out_a.data<U>();
+  U* out_b_ptr = out_b.data<U>();
+
+  int ndim = shape.size();
+  switch (ndim) {
+    case 1:
+      binary_op_dims<T, U, Op, 1>(
+          a_ptr,
+          b_ptr,
+          out_a_ptr,
+          out_b_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+    case 2:
+      binary_op_dims<T, U, Op, 2>(
+          a_ptr,
+          b_ptr,
+          out_a_ptr,
+          out_b_ptr,
+          op,
+          shape,
+          a_strides,
+          b_strides,
+          out_strides,
+          0);
+      return;
+  }
+
+  ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
+  ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
+  size_t stride = out_strides[ndim - 3];
+  for (size_t elem = 0; elem < a.size(); elem += stride) {
+    binary_op_dims<T, U, Op, 2>(
+        a_ptr + a_it.loc,
+        b_ptr + b_it.loc,
+        out_a_ptr + elem,
+        out_b_ptr + elem,
+        op,
+        shape,
+        a_strides,
+        b_strides,
+        out_strides,
+        ndim - 2);
+    a_it.step();
+    b_it.step();
+  }
 }

-template <typename T, typename U, typename Op>
-void binary_op_dispatch_dims(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op,
-    int dim,
-    int stride) {
-  // Number of dimensions to loop over for vectorized ops
-  switch (dim) {
-    case 1:
-      binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
-      return;
-    case 2:
-      binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
-      return;
-  }
-
-  const T* a_ptr = a.data<T>();
-  const T* b_ptr = b.data<T>();
-  U* dst_a = out_a.data<U>();
-  U* dst_b = out_b.data<U>();
-  for (size_t i = 0; i < out_a.size(); i += stride) {
-    int a_idx = elem_to_loc(i, a.shape(), a.strides());
-    int b_idx = elem_to_loc(i, b.shape(), b.strides());
-    op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
-    dst_a += stride;
-    dst_b += stride;
-  }
-}
-
-template <
-    typename T,
-    typename U,
-    typename Op,
-    typename OpSV,
-    typename OpVS,
-    typename OpVV>
-void binary_op(
-    const array& a,
-    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op,
-    OpSV opsv,
-    OpVS opvs,
-    OpVV opvv) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out_a, bopt);
-  set_binary_op_output_data(a, b, out_b, bopt);
-
-  // The full computation is scalar scalar so call the base op once
-  if (bopt == BinaryOpType::ScalarScalar) {
-    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
-        op(*a.data<T>(), *b.data<T>());
-    return;
-  }
-
-  // The full computation is scalar vector so delegate to the op
-  if (bopt == BinaryOpType::ScalarVector) {
-    opsv(
-        a.data<T>(),
-        b.data<T>(),
-        out_a.data<U>(),
-        out_b.data<U>(),
-        b.data_size());
-    return;
-  }
-
-  // The full computation is vector scalar so delegate to the op
-  if (bopt == BinaryOpType::VectorScalar) {
-    opvs(
-        a.data<T>(),
-        b.data<T>(),
-        out_a.data<U>(),
-        out_b.data<U>(),
-        a.data_size());
-    return;
-  }
-
-  // The full computation is vector vector so delegate to the op
-  if (bopt == BinaryOpType::VectorVector) {
-    opvv(
-        a.data<T>(),
-        b.data<T>(),
-        out_a.data<U>(),
-        out_b.data<U>(),
-        out_a.size());
-    return;
-  }
-
-  // General computation so let's try to optimize
-  // Get the left-most dim such that the array is row contiguous after
-  auto& strides = out_a.strides();
-  auto leftmost_rc_dim = [&strides](const array& arr) {
-    int d = arr.ndim() - 1;
-    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
-    }
-    return d + 1;
-  };
-  auto a_rc_dim = leftmost_rc_dim(a);
-  auto b_rc_dim = leftmost_rc_dim(b);
-
-  // Get the left-most dim such that the array is a broadcasted "scalar" after
-  auto leftmost_s_dim = [](const array& arr) {
-    int d = arr.ndim() - 1;
-    for (; d >= 0 && arr.strides()[d] == 0; d--) {
-    }
-    return d + 1;
-  };
-  auto a_s_dim = leftmost_s_dim(a);
-  auto b_s_dim = leftmost_s_dim(b);
-
-  auto ndim = out_a.ndim();
-
-  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
-  int dim = ndim;
-  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
-    bopt = BinaryOpType::VectorVector;
-    dim = d;
-    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
-    // contiguous
-  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
-    bopt = BinaryOpType::VectorScalar;
-    dim = d;
-    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
-    // contiguous
-  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
-    bopt = BinaryOpType::ScalarVector;
-    dim = d;
-  }
-
-  // Can be sure dim > 0 since otherwise we would have used one of the fully
-  // contiguous methods above. Except for the case that the flags do not
-  // correspond to the underlying contiguity.
-  size_t stride;
-  if (dim == 0 || strides[dim - 1] < 16) {
-    stride = 1;
-    bopt = BinaryOpType::General;
-    dim = ndim;
-  } else {
-    stride = strides[dim - 1];
-  }
-
-  switch (bopt) {
-    case BinaryOpType::VectorVector:
-      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
-      break;
-    case BinaryOpType::VectorScalar:
-      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
-      break;
-    case BinaryOpType::ScalarVector:
-      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
-      break;
-    default:
-      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
-      break;
-  }
-}
-
-template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
-void binary_op(
-    const array& a,
-    const array& b,
-    std::vector<array>& outputs,
-    Op op,
-    OpSV opsv,
-    OpVS opvs,
-    OpVV opvv) {
-  // TODO: The following mess of constexpr evaluations can probably be achieved
-  // with template specializations and overloading. Would it be simpler?
-
-  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
-    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
-      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
-        binary_op<T, T>(
-            a,
-            b,
-            outputs[0],
-            outputs[1],
-            op,
-            DefaultScalarVector<T, T, Op>(op),
-            DefaultVectorScalar<T, T, Op>(op),
-            DefaultVectorVector<T, T, Op>(op));
-      } else {
-        // opsv and opvs were UseDefaultBinaryOp
-        binary_op<T, T>(
-            a,
-            b,
-            outputs[0],
-            outputs[1],
-            op,
-            DefaultScalarVector<T, T, Op>(op),
-            DefaultVectorScalar<T, T, Op>(op),
-            opvv);
-      }
-    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-      // opsv and opvv were UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          outputs[0],
-          outputs[1],
-          op,
-          DefaultScalarVector<T, T, Op>(op),
-          opvs,
-          DefaultVectorVector<T, T, Op>(op));
-    } else {
-      // opsv was UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          outputs[0],
-          outputs[1],
-          op,
-          DefaultScalarVector<T, T, Op>(op),
-          opvs,
-          opvv);
-    }
-  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
-    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-      // opvs and opvv were UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          outputs[0],
-          outputs[1],
-          op,
-          opsv,
-          DefaultVectorScalar<T, T, Op>(op),
-          DefaultVectorVector<T, T, Op>(op));
-    } else {
-      // opvs was UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          outputs[0],
-          outputs[1],
-          op,
-          opsv,
-          DefaultVectorScalar<T, T, Op>(op),
-          opvv);
-    }
-  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-    // opvv was UseDefaultBinaryOp
-    binary_op<T, T>(
-        a,
-        b,
-        outputs[0],
-        outputs[1],
-        op,
-        opsv,
-        opvs,
-        DefaultVectorVector<T, T, Op>(op));
-  } else {
-    // All ops provided
-    binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
-  }
-}
-
-template <typename T, typename Op>
-void binary_op(
-    const array& a,
-    const array& b,
-    std::vector<array>& outputs,
-    Op op) {
-  DefaultScalarVector<T, T, Op> opsv(op);
-  DefaultVectorScalar<T, T, Op> opvs(op);
-  DefaultVectorVector<T, T, Op> opvv(op);
-  binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
-}
-
-template <typename... Ops>
+template <typename T, typename U = T, typename Op>
+void binary_op(
+    const array& a,
+    const array& b,
+    std::vector<array>& outputs,
+    Op op) {
+  auto bopt = get_binary_op_type(a, b);
+  auto& out_a = outputs[0];
+  auto& out_b = outputs[1];
+  set_binary_op_output_data(a, b, out_a, bopt);
+  set_binary_op_output_data(a, b, out_b, bopt);
+
+  // The full computation is scalar scalar so call the base op once
+  if (bopt == BinaryOpType::General) {
+    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
+    return;
+  }
+
+  auto a_ptr = a.data<T>();
+  auto b_ptr = b.data<T>();
+  auto out_a_ptr = out_a.data<U>();
+  auto out_b_ptr = out_b.data<U>();
+  if (bopt == BinaryOpType::ScalarScalar) {
+    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+  } else if (bopt == BinaryOpType::ScalarVector) {
+    for (size_t i = 0; i < b.size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      b_ptr++;
+    }
+  } else if (bopt == BinaryOpType::VectorScalar) {
+    for (size_t i = 0; i < a.size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      a_ptr++;
+    }
+  } else { // VectorVector
+    for (size_t i = 0; i < a.size(); ++i) {
+      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+      out_a_ptr++;
+      out_b_ptr++;
+      a_ptr++;
+      b_ptr++;
+    }
+  }
+}
+
+template <typename Op>
 void binary(
     const array& a,
     const array& b,
     std::vector<array>& outputs,
-    Ops... ops) {
+    Op op) {
   switch (outputs[0].dtype()) {
     case bool_:
-      binary_op<bool>(a, b, outputs, ops...);
+      binary_op<bool>(a, b, outputs, op);
       break;
     case uint8:
-      binary_op<uint8_t>(a, b, outputs, ops...);
+      binary_op<uint8_t>(a, b, outputs, op);
       break;
     case uint16:
-      binary_op<uint16_t>(a, b, outputs, ops...);
+      binary_op<uint16_t>(a, b, outputs, op);
       break;
     case uint32:
-      binary_op<uint32_t>(a, b, outputs, ops...);
+      binary_op<uint32_t>(a, b, outputs, op);
       break;
     case uint64:
-      binary_op<uint64_t>(a, b, outputs, ops...);
+      binary_op<uint64_t>(a, b, outputs, op);
       break;
     case int8:
-      binary_op<int8_t>(a, b, outputs, ops...);
+      binary_op<int8_t>(a, b, outputs, op);
       break;
     case int16:
-      binary_op<int16_t>(a, b, outputs, ops...);
+      binary_op<int16_t>(a, b, outputs, op);
       break;
     case int32:
-      binary_op<int32_t>(a, b, outputs, ops...);
+      binary_op<int32_t>(a, b, outputs, op);
       break;
     case int64:
-      binary_op<int64_t>(a, b, outputs, ops...);
+      binary_op<int64_t>(a, b, outputs, op);
       break;
     case float16:
-      binary_op<float16_t>(a, b, outputs, ops...);
+      binary_op<float16_t>(a, b, outputs, op);
       break;
     case float32:
-      binary_op<float>(a, b, outputs, ops...);
+      binary_op<float>(a, b, outputs, op);
       break;
     case bfloat16:
-      binary_op<bfloat16_t>(a, b, outputs, ops...);
+      binary_op<bfloat16_t>(a, b, outputs, op);
       break;
     case complex64:
-      binary_op<complex64_t>(a, b, outputs, ops...);
+      binary_op<complex64_t>(a, b, outputs, op);
       break;
   }
 }
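
Note: the two-output path above expects `op` to return a pair, one value per output array (a divmod-style primitive is the natural client). A self-contained sketch of that contract, illustrative only:

#include <cstdio>
#include <tuple>
#include <utility>

template <typename Op>
void apply2(const int* a, const int* b, int* out0, int* out1, int n, Op op) {
  for (int i = 0; i < n; i++) {
    // Each scalar op yields a pair; .first and .second go to the two outputs.
    std::tie(out0[i], out1[i]) = op(a[i], b[i]);
  }
}

int main() {
  int a[3] = {7, 8, 9}, b[3] = {2, 3, 4}, q[3], r[3];
  apply2(a, b, q, r, 3,
         [](int x, int y) { return std::make_pair(x / y, x % y); });
  std::printf("%d %d\n", q[0], r[0]); // 3 1
}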


@@ -2,46 +2,12 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) { void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that // Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric: // the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization. // Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N); int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite // TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how // because throwing an error would result in a crash. If we figure out how
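
Note: MLX_LAPACK_FUNC itself lives in the new lapack.h header, which this view does not show. A plausible sketch of its shape (an assumption, not the verified header contents) is a name-mapping macro so call sites no longer care how the Fortran symbol is spelled or whether hidden string-length arguments are in play:

// Hypothetical sketch only; the real header may differ.
#define MLX_LAPACK_FUNC(f) f##_
// e.g. MLX_LAPACK_FUNC(spotrf)(&uplo, &N, matrix, &N, &info);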


@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   // rely on data_size anyway.
   size_t data_size = out.size();

-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+  return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }

 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,21 +58,21 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   if (out.size() > in.size()) {
     flags.row_contiguous = flags.col_contiguous = false;
   }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
+  move_or_copy(in, out, strides, flags, in.data_size());
 }

 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }

-void CustomVJP::eval(
+void CustomTransforms::eval(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
   for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
        i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
+    move_or_copy(inputs[j], outputs[i]);
   }
 }

@@ -81,7 +81,7 @@ void Depends::eval(
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
   for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
+    move_or_copy(inputs[i], outputs[i]);
   }
 }

@@ -156,8 +156,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
   }

   // Firstly let's collapse all the contiguous dimensions of the input
-  auto [shape, _strides] = collapse_contiguous_dims(in);
-  auto& strides = _strides[0];
+  auto [shape, strides] = collapse_contiguous_dims(in);

   // If shapes fit exactly in the contiguous dims then no copy is necessary so
   // let's check.
@@ -195,7 +194,7 @@ void Reshape::shared_buffer_reshape(
     auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
     flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }

 void Split::eval(
@@ -250,49 +249,6 @@ void Split::eval(
   }
 }

-std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
-    const array& in) {
-  int64_t data_offset = 0;
-  bool copy_needed = false;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
-  for (int i = 0; i < in.ndim(); ++i) {
-    data_offset += start_indices_[i] * in.strides()[i];
-    inp_strides[i] = in.strides()[i] * strides_[i];
-    copy_needed |= strides_[i] < 0;
-  }
-
-  return std::make_tuple(copy_needed, data_offset, inp_strides);
-}
-
-void Slice::shared_buffer_slice(
-    const array& in,
-    const std::vector<size_t>& out_strides,
-    size_t data_offset,
-    array& out) {
-  // Compute row/col contiguity
-  auto [data_size, is_row_contiguous, is_col_contiguous] =
-      check_contiguity(out.shape(), out_strides);
-
-  auto flags = in.flags();
-  flags.row_contiguous = is_row_contiguous;
-  flags.col_contiguous = is_col_contiguous;
-
-  if (data_size == 1) {
-    // Broadcasted scalar array is contiguous.
-    flags.contiguous = true;
-  } else if (data_size == in.data_size()) {
-    // Means we sliced a broadcasted dimension so leave the "no holes" flag
-    // alone.
-  } else {
-    // We sliced something. So either we are row or col contiguous or we
-    // punched a hole.
-    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
-  }
-
-  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
-}
-
 std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
     const array& in) {
   int64_t data_offset = 0;
@@ -307,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(

 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }

 void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -341,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
       b_stride *= out.shape(ri);
     }
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }

 } // namespace mlx::core
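
Note: `move_or_copy` is not shown in this diff. A plausible reading of its role (an assumption, not the verified mlx/backend/common implementation) is that it moves the shared buffer when the input can donate it and shares it otherwise:

// Hypothetical sketch only.
inline void move_or_copy(const array& in, array& out) {
  if (in.is_donatable()) {
    out.move_shared_buffer(in); // take over the input's buffer
  } else {
    out.copy_shared_buffer(in); // alias it without ownership transfer
  }
}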


@@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
   case complex64:
     return print_complex_constant<complex64_t>(os, x);
   case int8:
-    return print_int_constant<int8_t>(os, x);
+    os << static_cast<int32_t>(x.item<int8_t>());
+    return;
   case int16:
     return print_int_constant<int16_t>(os, x);
   case int32:
@@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
   case int64:
     return print_int_constant<int64_t>(os, x);
   case uint8:
-    return print_int_constant<uint8_t>(os, x);
+    os << static_cast<uint32_t>(x.item<uint8_t>());
+    return;
   case uint16:
     return print_int_constant<uint16_t>(os, x);
   case uint32:
@@ -205,8 +207,8 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Correct size
     // - Not a constant
-    if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
-        in.is_donatable() &&
+    if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
+        in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
         constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
       if (move_buffers) {
         outputs[o].move_shared_buffer(
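
Note on the int8/uint8 cast above: 8-bit integer types stream as characters, not numbers, so printing them into generated source would emit the wrong token. Self-contained illustration (not MLX code):

#include <cstdint>
#include <iostream>

int main() {
  int8_t v = 65;
  std::cout << v << "\n";                       // prints 'A'
  std::cout << static_cast<int32_t>(v) << "\n"; // prints 65
}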


@@ -2,7 +2,10 @@
 #include <dlfcn.h>
 #include <filesystem>
+#include <fstream>
 #include <list>
+#include <mutex>
+#include <shared_mutex>

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/compiled_preamble.h"
@@ -11,22 +14,7 @@
 namespace mlx::core {

-// GPU compile is always available if the GPU is available and since we are in
-// this file CPU compile is also available.
-namespace detail {
-bool compile_available_for_device(const Device& device) {
-  return true;
-}
-} // namespace detail
-
-std::string get_temp_file(const std::string& name) {
-  return std::filesystem::temp_directory_path().append(name);
-}
-
-// Return a pointer to a compiled function
-void* compile(
-    const std::string& kernel_name,
-    const std::string& source_code = "") {
+struct CompilerCache {
   struct DLib {
     DLib(const std::string& libname) {
       lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -43,15 +31,41 @@ void* compile(
     void* lib;
   };
   // Statics to cache compiled libraries and functions
-  static std::list<DLib> libs;
-  static std::unordered_map<std::string, void*> kernels;
-  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
-    return it->second;
-  }
-  if (source_code.empty()) {
-    return nullptr;
-  }
+  std::list<DLib> libs;
+  std::unordered_map<std::string, void*> kernels;
+  std::shared_mutex mtx;
+};
+
+static CompilerCache cache{};
+
+// GPU compile is always available if the GPU is available and since we are in
+// this file CPU compile is also available.
+namespace detail {
+bool compile_available_for_device(const Device& device) {
+  return true;
+}
+} // namespace detail
+
+std::string get_temp_file(const std::string& name) {
+  return std::filesystem::temp_directory_path().append(name).string();
+}
+
+// Return a pointer to a compiled function
+void* compile(
+    const std::string& kernel_name,
+    const std::function<std::string(void)>& source_builder) {
+  {
+    std::shared_lock lock(cache.mtx);
+    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+      return it->second;
+    }
+  }
+  std::unique_lock lock(cache.mtx);
+  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+    return it->second;
+  }
+  std::string source_code = source_builder();
   std::string kernel_file_name;

   // Deal with long kernel names. Maximum length for files on macOS is 255
@@ -89,8 +103,8 @@ void* compile(
   source_file.close();

   std::ostringstream build_command;
-  build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
-                << source_file_path << " -o " << shared_lib_path;
+  build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
+                << source_file_path << "' -o '" << shared_lib_path << "'";
   std::string build_command_str = build_command.str();
   auto return_code = system(build_command_str.c_str());
   if (return_code) {
@@ -102,10 +116,10 @@ void* compile(
   }

   // load library
-  libs.emplace_back(shared_lib_path);
+  cache.libs.emplace_back(shared_lib_path);

   // Load function
-  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -113,7 +127,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  kernels.insert({kernel_name, fun});
+  cache.kernels.insert({kernel_name, fun});
   return fun;
 }

@@ -265,7 +279,7 @@ void Compiled::eval_cpu(
   // Figure out which kernel we are using
   auto& shape = outputs[0].shape();
-  bool contiguous = compiled_check_contiguity(inputs, shape);
+  auto contiguous = compiled_check_contiguity(inputs, shape);

   // Handle all broadcasting and collect function input arguments
   std::vector<void*> args;
@@ -315,10 +329,7 @@ void Compiled::eval_cpu(
   }

   // Get the function
-  auto fn_ptr = compile(kernel_name);
-
-  // If it doesn't exist, compile it
-  if (fn_ptr == nullptr) {
+  auto fn_ptr = compile(kernel_name, [&]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -333,10 +344,8 @@ void Compiled::eval_cpu(
         ndim);
     // Close extern "C"
     kernel << "}" << std::endl;
-
-    // Compile and get function pointer
-    fn_ptr = compile(kernel_name, kernel.str());
-  }
+    return kernel.str();
+  });

   compiled_allocate_outputs(
       inputs, outputs, inputs_, constant_ids_, contiguous, false);
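
Note: the compile cache above uses the classic double-checked pattern: a shared lock for the common hit path, then a re-check under the exclusive lock because another thread may have built the entry in between. The pattern in isolation (illustrative only, not MLX code):

#include <shared_mutex>
#include <string>
#include <unordered_map>

std::unordered_map<std::string, int> table;
std::shared_mutex mtx;

int get_or_build(const std::string& key) {
  {
    std::shared_lock lock(mtx); // many readers may hold this concurrently
    if (auto it = table.find(key); it != table.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(mtx); // exclusive: we may need to insert
  if (auto it = table.find(key); it != table.end()) {
    return it->second; // someone else built it while we waited
  }
  int value = static_cast<int>(key.size()); // stand-in for expensive work
  table.insert({key, value});
  return value;
}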


@@ -3,13 +3,8 @@
 #include <cassert>
 #include <numeric>

-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
-
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

@@ -684,6 +679,32 @@ void dispatch_slow_conv_3D(
 // Explicit gemm conv
 ///////////////////////////////////////////////////////////////////////////////

+template <typename T>
+void flip_spatial_dims_inplace(array& wt) {
+  T* x = wt.data<T>();
+  size_t out_channels = wt.shape(0);
+  size_t in_channels = wt.shape(-1);
+
+  // Calculate the total size of the spatial dimensions
+  int spatial_size = 1;
+  for (int d = 1; d < wt.ndim() - 1; ++d) {
+    spatial_size *= wt.shape(d);
+  }
+
+  for (size_t i = 0; i < out_channels; i++) {
+    T* top = x + i * spatial_size * in_channels;
+    T* bottom =
+        x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
+    for (size_t j = 0; j < spatial_size / 2; j++) {
+      for (size_t k = 0; k < in_channels; k++) {
+        std::swap(top[k], bottom[k]);
+      }
+      top += in_channels;
+      bottom -= in_channels;
+    }
+  }
+}
+
 void explicit_gemm_conv_1D_cpu(
     const array& in,
     const array& wt,
@@ -910,7 +931,8 @@ void explicit_gemm_conv_ND_cpu(
     array out,
     const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation) {
+    const std::vector<int>& wt_dilation,
+    const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const auto iDim = std::vector<int>(
       in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
@@ -1000,6 +1022,14 @@ void explicit_gemm_conv_ND_cpu(
     copy(wt, gemm_wt, ctype);
   }

+  if (flip) {
+    auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
+    copy(gemm_wt, gemm_wt_, CopyType::Vector);
+
+    flip_spatial_dims_inplace<float>(gemm_wt_);
+    gemm_wt = gemm_wt_;
+  }
+
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
@@ -1042,10 +1072,15 @@ void conv_1D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
+  const int groups = in.shape().back() / wt.shape().back();
   if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
     return explicit_gemm_conv_1D_cpu(
         in, wt, out, padding, wt_strides, wt_dilation);
   }
+  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
+    return explicit_gemm_conv_ND_cpu(
+        in, wt, out, padding, wt_strides, wt_dilation, flip);
+  }

   return dispatch_slow_conv_1D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
@@ -1060,6 +1095,13 @@ void conv_2D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
+  const int groups = in.shape().back() / wt.shape().back();
+  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
+      in_dilation[1] == 1 && groups == 1) {
+    return explicit_gemm_conv_ND_cpu(
+        in, wt, out, padding, wt_strides, wt_dilation, flip);
+  }
+
   return dispatch_slow_conv_2D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
 }
@@ -1073,6 +1115,14 @@ void conv_3D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
+  const int groups = in.shape().back() / wt.shape().back();
+  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
+      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
+      groups == 1) {
+    return explicit_gemm_conv_ND_cpu(
+        in, wt, out, padding, wt_strides, wt_dilation, flip);
+  }
+
   return dispatch_slow_conv_3D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
 }
@@ -1125,7 +1175,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream msg;
     msg << "[Convolution::eval] Convolution currently only supports"
-        << " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
+        << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
        << " spatial dimensions";
     throw std::invalid_argument(msg.str());
   }
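
Note: `flip_spatial_dims_inplace` reverses only the spatial axis of an (out_channels, spatial..., in_channels) weight, keeping each channel vector intact. A standalone illustration of the same swap for O=1, K=3, C=2 (not MLX code):

#include <algorithm>
#include <cstdio>

int main() {
  // Row-major (K, C) weight: [[1,2],[3,4],[5,6]]
  float w[6] = {1, 2, 3, 4, 5, 6};
  int K = 3, C = 2;
  for (int j = 0; j < K / 2; j++) {
    // Swap whole C-vectors from the two ends toward the middle.
    std::swap_ranges(w + j * C, w + (j + 1) * C, w + (K - 1 - j) * C);
  }
  std::printf("%g %g %g %g %g %g\n", w[0], w[1], w[2], w[3], w[4], w[5]);
  // prints: 5 6 3 4 1 2
}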


@@ -4,6 +4,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -25,252 +26,117 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
} }
template <typename SrcT, typename DstT, typename stride_t> template <typename SrcT, typename DstT, typename StrideT, int D>
void copy_general_dim1( inline void copy_dims(
const array& src, const SrcT* src,
array& dst, DstT* dst,
const std::vector<int>& data_shape, const std::vector<int>& shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
int64_t i_offset) { const std::vector<StrideT>& o_strides,
const SrcT* src_ptr = src.data<SrcT>(); int axis) {
DstT* dst_ptr = dst.data<DstT>(); auto stride_src = i_strides[axis];
stride_t src_idx = i_offset; auto stride_dst = o_strides[axis];
stride_t dst_idx = 0; auto N = shape[axis];
for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
}
}
template <typename SrcT, typename DstT> for (int i = 0; i < N; i++) {
inline void copy_general_dim1(const array& src, array& dst) { if constexpr (D > 1) {
return copy_general_dim1<SrcT, DstT, size_t>( copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, src.shape(), src.strides(), 0); src, dst, shape, i_strides, o_strides, axis + 1);
} } else {
*dst = static_cast<DstT>(*src);
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
} }
src_idx += i_strides[0] - i_strides[1] * data_shape[1]; src += stride_src;
dst += stride_dst;
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename StrideT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
dst_ptr += stride_dst;
}
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<StrideT>& o_strides,
stride_t i_offset, int64_t i_offset,
stride_t o_offset) { int64_t o_offset) {
switch (src.ndim()) { if (data_shape.empty()) {
case 1: auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
copy_general_general_dims<SrcT, DstT, stride_t, 1>( auto dst_ptr = dst.data<DstT>() + o_offset;
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); *dst_ptr = val;
return; return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
} }
auto [shape, strides] = collapse_contiguous_dims(
int size = std::accumulate( data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>()); auto src_ptr = src.data<SrcT>() + i_offset;
for (int i = 0; i < src.size(); i += size) { auto dst_ptr = dst.data<DstT>() + o_offset;
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); int ndim = shape.size();
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); if (ndim == 1) {
copy_general_general_dims<SrcT, DstT, stride_t, 5>( copy_dims<SrcT, DstT, StrideT, 1>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>( copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
} }
template <typename SrcT, typename DstT, typename StrideT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
0,
0);
}
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) { switch (ctype) {
@@ -285,6 +151,7 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...); copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
} }
} }
@@ -385,7 +252,7 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) { void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype); copy_inplace_dispatch(src, dst, ctype);
} }
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype) {
@@ -415,20 +282,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype);
} }
template <typename stride_t> template <typename StrideT>
void copy_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<StrideT>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<StrideT>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype) {
switch (ctype) { switch (ctype) {
case CopyType::General: case CopyType::General:
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
return copy_inplace_dispatch( copy_inplace_dispatch(
src, src,
dst, dst,
ctype, ctype,
@@ -437,15 +304,24 @@ void copy_inplace(
o_strides, o_strides,
i_offset, i_offset,
o_offset); o_offset);
break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::Vector: case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype); copy_inplace_dispatch(src, dst, ctype);
} }
} }
template <> template void copy_inplace<size_t>(
void copy_inplace<int64_t>( const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
@@ -453,24 +329,6 @@ void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides, const std::vector<int64_t>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype);
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core } // namespace mlx::core
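
Note on the tail of this diff: the old code fully specialized `copy_inplace<int64_t>` with its own body, while the new code emits explicit instantiations that reuse the primary template's body. The distinction in isolation (general C++, not MLX-specific):

#include <iostream>

template <typename T>
void show(T v) {
  std::cout << v << "\n";
}

// Explicit instantiation: emits show<long>'s definition in this translation
// unit using the primary template body (what the diff switched to). A full
// specialization, `template <> void show<long>(long) { ... }`, would instead
// supply a separate body.
template void show<long>(long);

int main() {
  show(42L);
}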


@@ -1,15 +1,10 @@
 // Copyright © 2023-2024 Apple Inc.

-#ifdef ACCELERATE_NEW_LAPACK
-#include <Accelerate/Accelerate.h>
-#else
-#include <cblas.h>
-#endif
-
 #include <cstring>

 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"

@@ -53,7 +48,7 @@ DEFAULT(Convolution)
 DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
-DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(CustomTransforms)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
 DEFAULT(NumberOfElements)
@@ -69,6 +64,7 @@ DEFAULT(Full)
 DEFAULT(Gather)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
+DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Eigh)

 namespace {

mlx/backend/common/eigh.cpp Normal file

@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core
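
Note: the block above uses the standard LAPACK workspace-query convention. Calling ssyevd with lwork = -1 and liwork = -1 computes nothing; instead LAPACK writes the optimal buffer sizes into the first elements of work and iwork, which the caller reads back and allocates before the real factorization calls, as done in the loop. Recap of the pattern using this file's own wrapper:

// Query pass: sizes only, no factorization.
float work_size;
int iwork_size;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work_size, -1, &iwork_size, -1);
// Then allocate lwork = (int)work_size floats and liwork = iwork_size ints.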


@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core
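
Note: the index arithmetic in hadamard_n (k = i & (h - 1); j = ((i - k) << 1) + k) enumerates the butterfly pairs of the fast Walsh-Hadamard transform. Worked out for n = 4: at h = 1 it pairs (0,1) and (2,3); at h = 2 it pairs (0,2) and (1,3). A standalone, unnormalized length-4 transform showing the same loop (not MLX code):

#include <cstdio>

int main() {
  float x[4] = {1, 2, 3, 4};
  int n = 4;
  for (int h = 1; h < n; h <<= 1) {
    for (int i = 0; i < n / 2; i++) {
      int k = i & (h - 1);
      int j = ((i - k) << 1) + k;
      float a = x[j], b = x[j + h];
      x[j] = a + b;
      x[j + h] = a - b;
    }
  }
  std::printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]); // 10 -2 -4 0
}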


@@ -0,0 +1,105 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core
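
A quick illustration of the decomposition contract, with values worked out from the code above:

auto [n, m] = decompose_hadamard(24);   // 24 = 12 * 2, so n = 2, m = 12
auto [n2, m2] = decompose_hadamard(64); // power of two: n = 64, m = 1
// decompose_hadamard(6) throws: 6 is not a power of two and has no
// factor in {12, 20, 28}.

Note that after dividing out m the quotient is not re-checked for being a power of two here; hadamard_n assumes it is.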

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@@ -81,11 +80,18 @@ void gather(
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0; size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii]; auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]); auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax)); offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]); src_idx += (idx_val * src.strides()[ax]);
@@ -99,9 +105,10 @@ void gather(
out_idx += slice_size; out_idx += slice_size;
} else { } else {
for (int jj = 0; jj < slice_size; jj++) { for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; src_it.step();
} }
src_it.reset();
} }
} }
} }
@@ -223,21 +230,29 @@ void scatter(
update_size *= us; update_size *= us;
} }
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < nind; ++j) { for (int j = 0; j < nind; ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]); auto idx_loc = its[j].loc;
its[j].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax)); offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]); out_offset += (idx_val * out.strides()[ax]);
} }
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates); op(updates.data<InT>()[update_it.loc],
auto out_loc = elem_to_loc(j, update_shape, out.strides()); out.data<InT>() + out_offset + out_it.loc);
op(updates.data<InT>()[update_loc], update_it.step();
out.data<InT>() + out_offset + out_loc); out_it.step();
} }
out_it.reset();
update_it.reset();
} }
} }
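
The change above replaces per-element elem_to_loc calls, which re-derive a multi-dimensional index from a flat one on every iteration, with iterators that keep a running offset (loc) and advance it incrementally via step()/reset()/seek(). A minimal sketch of that pattern with a hypothetical StridedIter (illustrative only, not MLX's ContiguousIterator):

#include <algorithm>
#include <cstddef>
#include <vector>

// Keeps a running flat offset into a strided buffer and advances it one
// logical element at a time, instead of recomputing it from scratch.
struct StridedIter {
  std::vector<int> shape;
  std::vector<size_t> strides;
  std::vector<int> pos; // current multi-dimensional index
  size_t loc = 0;       // current flat offset

  StridedIter(std::vector<int> s, std::vector<size_t> st)
      : shape(std::move(s)), strides(std::move(st)), pos(shape.size(), 0) {}

  void step() {
    // Bump the innermost axis and carry outward, updating loc by stride
    // deltas along the way.
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      loc += strides[d];
      if (++pos[d] < shape[d]) return;
      loc -= size_t(shape[d]) * strides[d]; // roll this axis back to zero
      pos[d] = 0;
    }
  }

  void reset() {
    std::fill(pos.begin(), pos.end(), 0);
    loc = 0;
  }
};

// StridedIter it({2, 3}, {4, 1}) visits offsets 0, 1, 2, 4, 5, 6; each
// step() is amortized O(1) versus a divide/modulo chain per element.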

View File

@@ -2,17 +2,94 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
#include <Accelerate/Accelerate.h> int info;
#else MLX_LAPACK_FUNC(strtri)
#include <lapack.h> (
#endif /* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
return info;
}
namespace mlx::core { namespace mlx::core {
void inverse_impl(const array& a, array& inv) { void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following // Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see // identity to avoid transposing (see
// https://math.stackexchange.com/a/340234): // https://math.stackexchange.com/a/340234):
@@ -24,63 +101,11 @@ void inverse_impl(const array& a, array& inv) {
const int N = a.shape(-1); const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N); const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization. if (tri) {
sgetrf_( tri_inv(inv, N, i, upper);
/* m = */ &N, } else {
/* n = */ &N, general_inv(inv, N, i);
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
} }
} }
} }
@@ -89,7 +114,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) { if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32."); throw std::runtime_error("[Inverse::eval] only supports float32.");
} }
inverse_impl(inputs[0], output); inverse_impl(inputs[0], output, tri_, upper_);
} }
} // namespace mlx::core } // namespace mlx::core
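
The refactor keeps the original trick intact: LAPACK is column-major, so a row-major buffer is implicitly A^T, and since inv(A^T) = inv(A)^T the result read back row-major is already inv(A). The same implicit transpose also flips triangularity, which is why tri_inv passes 'L' to LAPACK for an upper-triangular input. A tiny self-contained check of the identity on a 2x2 matrix (plain C++, no LAPACK):

#include <cassert>

int main() {
  // A = [a b; c d] with det = 1 for these values.
  double a = 2, b = 1, c = 5, d = 3;
  double det = a * d - b * c;
  // inv(A) = [d -b; -c a] / det, so inv(A)^T = [d -c; -b a] / det.
  // A^T = [a c; b d], so inv(A^T) = [d -c; -b a] / det as well.
  double invT_00 = d / det, invT_01 = -c / det;
  double invAT_00 = d / det, invAT_01 = -c / det;
  assert(invT_00 == invAT_00 && invT_01 == invAT_01); // inv(A)^T == inv(A^T)
}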

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#else #else
#include <cblas.h>
#include <lapack.h> #include <lapack.h>
#endif #endif

View File

@@ -5,11 +5,9 @@
#include <utility> #include <utility>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/io/load.h" #include "mlx/backend/common/load.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core {
namespace { namespace {
template <const uint8_t scalar_size> template <const uint8_t scalar_size>
@@ -29,12 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
} // namespace } // namespace
void Load::eval(const std::vector<array>& inputs, array& out) { namespace mlx::core {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
reader_->seek(offset_, std::ios_base::beg); void load(
reader_->read(out.data<char>(), out.nbytes()); array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) { if (swap_endianness_) {
switch (out.itemsize()) { switch (out.itemsize()) {
@@ -51,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core } // namespace mlx::core
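
The swap_endianness<scalar_size> helper whose body is elided in the hunk above reverses each scalar's bytes in place after the raw read. A straightforward sketch of such a helper (illustrative; the real body is not shown here):

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Reverse the byte order of N scalars of scalar_size bytes each, in place.
template <const uint8_t scalar_size>
void swap_endianness_sketch(uint8_t* data_bytes, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    std::reverse(data_bytes + i * scalar_size,
                 data_bytes + (i + 1) * scalar_size);
  }
}

// Usage for 4-byte elements, mirroring the switch on out.itemsize() above:
// swap_endianness_sketch<4>(reinterpret_cast<uint8_t*>(ptr), nbytes / 4);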

mlx/backend/common/load.h
View File

@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness);
} // namespace mlx::core

View File

@@ -18,16 +18,19 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
EOM EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi fi
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null) CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE" cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() { const char* get_kernel_preamble() {
return R"preamble( return R"preamble(
$INCLUDES $INCLUDES
$CONTENT $CONTENT
using namespace mlx::core;
using namespace mlx::core::detail; using namespace mlx::core::detail;
)preamble"; )preamble";
} }

View File

@@ -1,15 +1,10 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"

View File

@@ -108,105 +108,105 @@ struct Abs {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::abs(x); return std::abs(x);
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
}; };
struct ArcCos { struct ArcCos {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::acos(x); return std::acos(x);
}; }
}; };
struct ArcCosh { struct ArcCosh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::acosh(x); return std::acosh(x);
}; }
}; };
struct ArcSin { struct ArcSin {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::asin(x); return std::asin(x);
}; }
}; };
struct ArcSinh { struct ArcSinh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::asinh(x); return std::asinh(x);
}; }
}; };
struct ArcTan { struct ArcTan {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::atan(x); return std::atan(x);
}; }
}; };
struct ArcTan2 { struct ArcTan2 {
template <typename T> template <typename T>
T operator()(T y, T x) { T operator()(T y, T x) {
return std::atan2(y, x); return std::atan2(y, x);
}; }
}; };
struct ArcTanh { struct ArcTanh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::atanh(x); return std::atanh(x);
}; }
}; };
struct Ceil { struct Ceil {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::ceil(x); return std::ceil(x);
}; }
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; }
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; }
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; }
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
}; };
struct Conjugate { struct Conjugate {
@@ -219,35 +219,35 @@ struct Cos {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::cos(x); return std::cos(x);
}; }
}; };
struct Cosh { struct Cosh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::cosh(x); return std::cosh(x);
}; }
}; };
struct Erf { struct Erf {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x))); return static_cast<T>(fast_erf(static_cast<float>(x)));
}; }
}; };
struct ErfInv { struct ErfInv {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x))); return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}; }
}; };
struct Exp { struct Exp {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return fast_exp(x); return fast_exp(x);
}; }
complex64_t operator()(complex64_t x) { complex64_t operator()(complex64_t x) {
return std::exp(x); return std::exp(x);
@@ -258,83 +258,97 @@ struct Expm1 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return expm1(x); return expm1(x);
}; }
}; };
struct Floor { struct Floor {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::floor(x); return std::floor(x);
}; }
int8_t operator()(int8_t x) { int8_t operator()(int8_t x) {
return x; return x;
}; }
int16_t operator()(int16_t x) { int16_t operator()(int16_t x) {
return x; return x;
}; }
int32_t operator()(int32_t x) { int32_t operator()(int32_t x) {
return x; return x;
}; }
int64_t operator()(int64_t x) { int64_t operator()(int64_t x) {
return x; return x;
}; }
uint8_t operator()(uint8_t x) { uint8_t operator()(uint8_t x) {
return x; return x;
}; }
uint16_t operator()(uint16_t x) { uint16_t operator()(uint16_t x) {
return x; return x;
}; }
uint32_t operator()(uint32_t x) { uint32_t operator()(uint32_t x) {
return x; return x;
}; }
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x; return x;
}; }
bool operator()(bool x) { bool operator()(bool x) {
return x; return x;
}; }
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
}; };
struct Log { struct Log {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log(x); return std::log(x);
}; }
}; };
struct Log2 { struct Log2 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log2(x); return std::log2(x);
}; }
}; };
struct Log10 { struct Log10 {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::log10(x); return std::log10(x);
}; }
}; };
struct Log1p { struct Log1p {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return log1p(x); return log1p(x);
}; }
}; };
struct LogicalNot { struct LogicalNot {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return !x; return !x;
}; }
}; };
struct Negative { struct Negative {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return -x; return -x;
}; }
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
}; };
struct Round { struct Round {
@@ -373,55 +387,59 @@ struct Sign {
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x != 0; return x != 0;
} }
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
}; };
struct Sin { struct Sin {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sin(x); return std::sin(x);
}; }
}; };
struct Sinh { struct Sinh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sinh(x); return std::sinh(x);
}; }
}; };
struct Square { struct Square {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return x * x; return x * x;
}; }
}; };
struct Sqrt { struct Sqrt {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::sqrt(x); return std::sqrt(x);
}; }
}; };
struct Rsqrt { struct Rsqrt {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x); return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}; }
}; };
struct Tan { struct Tan {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::tan(x); return std::tan(x);
}; }
}; };
struct Tanh { struct Tanh {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
return std::tanh(x); return std::tanh(x);
}; }
}; };
struct Add { struct Add {
@@ -554,7 +572,7 @@ struct LogAddExp {
? maxval ? maxval
: static_cast<decltype(x)>( : static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval))); maxval + std::log1p(fast_exp(minval - maxval)));
}; }
}; };
struct Multiply { struct Multiply {
@@ -602,14 +620,14 @@ struct LogicalAnd {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x && y; return x && y;
}; }
}; };
struct LogicalOr { struct LogicalOr {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x || y; return x || y;
}; }
}; };
struct Select { struct Select {
@@ -623,35 +641,35 @@ struct BitwiseAnd {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x & y; return x & y;
}; }
}; };
struct BitwiseOr { struct BitwiseOr {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x | y; return x | y;
}; }
}; };
struct BitwiseXor { struct BitwiseXor {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x ^ y; return x ^ y;
}; }
}; };
struct LeftShift { struct LeftShift {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x << y; return x << y;
}; }
}; };
struct RightShift { struct RightShift {
template <typename T> template <typename T>
T operator()(T x, T y) { T operator()(T x, T y) {
return x >> y; return x >> y;
}; }
}; };
} // namespace mlx::core::detail } // namespace mlx::core::detail
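
Besides dropping the stray semicolons after member-function bodies, this hunk adds Real/Imag ops and gives Sign the standard complex generalization, z / |z| with sign(0) = 0, which has unit modulus for any nonzero z. A self-contained check:

#include <complex>
#include <cstdio>

// Same rule as the complex64_t overload added to Sign above.
std::complex<float> csign(std::complex<float> z) {
  return z == std::complex<float>(0) ? z : z / std::abs(z);
}

int main() {
  auto s = csign(std::complex<float>(3.0f, 4.0f)); // |z| = 5
  std::printf("%g %g\n", s.real(), s.imag());      // 0.6 0.8
}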

View File

@@ -8,9 +8,9 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/arange.h" #include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h" #include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h" #include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h" #include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -273,6 +284,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype); copy(in, out, ctype);
} }
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) { void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -313,20 +328,6 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
unary(in, out, detail::LogicalNot()); unary(in, out, detail::LogicalNot());
} }
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Negative::eval(const std::vector<array>& inputs, array& out) { void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
@@ -412,6 +413,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) { void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -419,7 +424,8 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General); out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
@@ -492,7 +498,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in); auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed // Do copy if needed
if (copy_needed) { if (copy_needed) {
@@ -508,8 +515,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
/* int64_t o_offset = */ 0, /* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General); /* CopyType ctype = */ CopyType::General);
} else { } else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()}; std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out); shared_buffer_slice(in, ostrides, data_offset, data_size, out);
} }
} }
@@ -590,4 +605,43 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for sharing the underlying buffer without a copy (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core } // namespace mlx::core
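
In the shared-buffer branch of View::eval_cpu, every stride except the innermost is rescaled by ibytes / obytes so that the outer axes keep addressing the same byte offsets under the new element size. Worked numbers (illustrative):

#include <cstdio>
#include <vector>

int main() {
  // A contiguous (4, 8) float32 array viewed as int16: the shape becomes
  // (4, 16) and the row stride is rescaled from 8 to 8 * 4 / 2 = 16.
  size_t ibytes = 4, obytes = 2;
  std::vector<size_t> strides = {8, 1};
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  std::printf("%zu\n", strides[0]); // 16
}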

View File

@@ -2,14 +2,9 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>

View File

@@ -2,13 +2,38 @@
#include <cassert> #include <cassert>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
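
extract_bits is the decoder for the 3- and 6-bit layouts, which pack 8 (or 4) values into 3 bytes rather than padding each value to a power-of-two width. A round-trip check for the 3-bit case, packing little-endian and decoding with the exact bit arithmetic above:

#include <cassert>
#include <cstdint>

int main() {
  // Pack eight 3-bit values into a 24-bit word, lowest value first.
  uint8_t vals[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  uint32_t word = 0;
  for (int i = 0; i < 8; ++i)
    word |= uint32_t(vals[i] & 0x7) << (3 * i);
  uint8_t w[3] = {uint8_t(word), uint8_t(word >> 8), uint8_t(word >> 16)};
  // Decode exactly as extract_bits<T, 3> does.
  uint8_t out[8];
  out[0] = w[0] & 0x7;
  out[1] = (w[0] & 0x38) >> 3;
  out[2] = ((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2);
  out[3] = (w[1] & 0xe) >> 1;
  out[4] = (w[1] & 0x70) >> 4;
  out[5] = ((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1);
  out[6] = (w[2] & 0x1c) >> 2;
  out[7] = (w[2] & 0xe0) >> 5;
  for (int i = 0; i < 8; ++i)
    assert(out[i] == vals[i]); // lossless round trip
}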
template <typename T, int bits, int group_size> template <typename T, int bits, int group_size>
void _qmm( void _qmm(
T* result, T* result,
@@ -20,13 +45,12 @@ void _qmm(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint32_t* w_local = w; const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -40,13 +64,25 @@ void _qmm(
T scale = *scales_local++; T scale = *scales_local++;
T bias = *biases_local++; T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) { for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++; if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) { for (int p = 0; p < pack_factor; p++) {
(*result_local++) += (*result_local++) += xi * (scale * wl[p] + bias);
xi * (scale * static_cast<T>(wi & bitmask) + bias); }
wi >>= bits; w_local += bytes_per_pack;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
} }
} }
} }
@@ -67,13 +103,12 @@ void _qmm_t(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint32_t* w_local = w; const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -85,12 +120,26 @@ void _qmm_t(
T bias = *biases_local++; T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) { for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++; if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full) #pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) { for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias); sum += x_local[p] * (scale * wl[p] + bias);
wi >>= bits; }
w_local += bytes_per_pack;
x_local += pack_factor;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
} }
} }
} }
@@ -102,6 +151,55 @@ void _qmm_t(
} }
} }
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T> template <typename T>
void _qmm_dispatch_typed( void _qmm_dispatch_typed(
T* result, T* result,
@@ -116,79 +214,29 @@ void _qmm_dispatch_typed(
int bits, int bits,
bool transposed_w) { bool transposed_w) {
switch (bits) { switch (bits) {
case 2: { case 2:
switch (group_size) { _qmm_dispatch_group<T, 2>(
case 32: result, x, w, scales, biases, M, N, K, group_size, transposed_w);
if (transposed_w) { break;
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K); case 3:
} else { _qmm_dispatch_group<T, 3>(
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K); result, x, w, scales, biases, M, N, K, group_size, transposed_w);
} break;
case 64: case 4:
if (transposed_w) { _qmm_dispatch_group<T, 4>(
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K); result, x, w, scales, biases, M, N, K, group_size, transposed_w);
} else { break;
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K); case 6:
} _qmm_dispatch_group<T, 6>(
case 128: result, x, w, scales, biases, M, N, K, group_size, transposed_w);
if (transposed_w) { break;
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K); case 8:
} else { _qmm_dispatch_group<T, 8>(
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K); result, x, w, scales, biases, M, N, K, group_size, transposed_w);
} break;
} default:
} throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
} }
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
} }
void _qmm_dispatch( void _qmm_dispatch(
@@ -201,55 +249,61 @@ void _qmm_dispatch(
int group_size, int group_size,
bool transposed_w) { bool transposed_w) {
int K = x.shape(-1); int K = x.shape(-1);
int M = x.size() / K; int M = x.shape(-2);
int N = out.shape(-1); int N = out.shape(-1);
switch (x.dtype()) { int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
case float32: int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
_qmm_dispatch_typed<float>(
out.data<float>(), int batch_size = x.size() / x.shape(-1) / x.shape(-2);
x.data<float>(), for (int i = 0; i < batch_size; i++) {
w.data<uint32_t>(), switch (x.dtype()) {
scales.data<float>(), case float32:
biases.data<float>(), _qmm_dispatch_typed<float>(
M, out.data<float>() + i * M * N,
N, x.data<float>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<float>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<float>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
case float16: K,
_qmm_dispatch_typed<float16_t>( bits,
out.data<float16_t>(), group_size,
x.data<float16_t>(), transposed_w);
w.data<uint32_t>(), break;
scales.data<float16_t>(), case float16:
biases.data<float16_t>(), _qmm_dispatch_typed<float16_t>(
M, out.data<float16_t>() + i * M * N,
N, x.data<float16_t>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
case bfloat16: K,
_qmm_dispatch_typed<bfloat16_t>( bits,
out.data<bfloat16_t>(), group_size,
x.data<bfloat16_t>(), transposed_w);
w.data<uint32_t>(), break;
scales.data<bfloat16_t>(), case bfloat16:
biases.data<bfloat16_t>(), _qmm_dispatch_typed<bfloat16_t>(
M, out.data<bfloat16_t>() + i * M * N,
N, x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
K, w.data<uint32_t>() + elem_to_loc(i * w_els, w),
bits, scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
group_size, biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
transposed_w); M,
break; N,
default: K,
throw std::invalid_argument( bits,
"[quantized_matmul] only floating types are supported"); group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
} }
} }
@@ -398,4 +452,114 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_); transpose_);
} }
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
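
The scale/bias selection above does more than min-max scaling: with q0 = rint(edge / scale) it snaps the larger-magnitude endpoint onto the quantization grid, which also puts an exact grid point at zero whenever zero lies inside the group's range. A worked example tracing the same arithmetic (values illustrative):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  float w_min = -1.0f, w_max = 2.0f;
  int bits = 2;
  float n_bins = (1 << bits) - 1;                          // 3
  bool mask = std::abs(w_min) > std::abs(w_max);           // false
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f); // 1
  scale = mask ? scale : -scale;                           // -1
  float edge = mask ? w_min : w_max;                       // 2
  float q0 = std::rint(edge / scale);                      // -2
  float s = (q0 == 0) ? scale : edge / q0;                 // -1
  float b = (q0 == 0) ? 0.0f : edge;                       // 2
  // Quantize: q = clamp(rint((w - b) / s), 0, n_bins); dequantize: s*q + b.
  float q_zero = std::rint((0.0f - b) / s);                   // 2
  std::printf("dequant(%g) = %g\n", q_zero, s * q_zero + b); // exactly 0
}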
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -87,130 +87,225 @@ struct OrReduce {
} }
}; };
struct MaxReduce {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
    (*y) = (*y > x) ? *y : x;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
    if (std::isnan(x)) {
      *y = x;
    } else if (!std::isnan(*y)) {
      (*y) = (*y > x) ? *y : x;
    }
  }
};

struct MinReduce {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
    (*y) = (*y < x) ? *y : x;
  }

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
    if (std::isnan(x)) {
      *y = x;
    } else if (!std::isnan(*y)) {
      (*y) = (*y < x) ? *y : x;
    }
  }
};
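
The NaN handling matters on both operands because every ordered comparison against NaN is false: with only the incoming x checked, a NaN that enters the accumulator early would be silently overwritten by the next ordinary element. A minimal demonstration of the failure mode this prevents:

#include <cmath>
#include <cstdio>

int main() {
  float vals[] = {std::nanf(""), 5.0f};
  float y = -INFINITY;      // init, as with Limits<float>::min
  for (float x : vals)
    y = (y > x) ? y : x;    // naive max update, no NaN handling
  std::printf("%f\n", y);   // prints 5.000000; the NaN was dropped
}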
template <typename InT> template <typename InT>
void reduce_dispatch_out( void reduce_dispatch_and_or(
const array& in, const array& in,
array& out, array& out,
Reduce::ReduceType rtype, Reduce::ReduceType rtype,
const std::vector<int>& axes) { const std::vector<int>& axes) {
switch (rtype) { if (rtype == Reduce::And) {
case Reduce::And: { reduction_op<InT, bool>(in, out, axes, true, AndReduce());
reduction_op<InT, bool>(in, out, axes, true, AndReduce()); } else {
break; reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
} }
case Reduce::Or: { } else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce()); auto op = [](auto y, auto x) { (*y) *= x; };
break; if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
} reduction_op<InT, int32_t>(in, out, axes, 1, op);
case Reduce::Sum: { } else {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
switch (out.dtype()) {
case bool_:
reduction_op<InT, bool>(in, out, axes, false, op);
break;
case uint8:
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
break;
case uint16:
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
break;
case uint32:
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
break;
case uint64:
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
break;
case int8:
reduction_op<InT, int8_t>(in, out, axes, 0, op);
break;
case int16:
reduction_op<InT, int16_t>(in, out, axes, 0, op);
break;
case int32:
reduction_op<InT, int32_t>(in, out, axes, 0, op);
break;
case int64:
reduction_op<InT, int64_t>(in, out, axes, 0, op);
break;
case float16:
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
break;
case float32:
reduction_op<InT, float>(in, out, axes, 0.0f, op);
break;
case bfloat16:
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
break;
case complex64:
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
break;
}
} break;
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op); reduction_op<InT, InT>(in, out, axes, 1, op);
break;
}
case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, op);
break;
}
case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, op);
break;
} }
} }
} }
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
} // namespace } // namespace
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
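
nd_loop recursively walks an N-d shape and hands the callback the flat offset implied by the strides; it is the generic path when a layout has no contiguous fast case. Usage sketch (assuming the nd_loop defined just above, namespace qualifiers omitted):

#include <cstdio>
#include <vector>

int main() {
  // Visit a 2 x 3 array stored with a padded row stride of 4.
  std::vector<int> shape = {2, 3};
  std::vector<size_t> strides = {4, 1};
  nd_loop([](int offset) { std::printf("%d ", offset); }, shape, strides);
  // prints: 0 1 2 4 5 6
}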
void Reduce::eval(const std::vector<array>& inputs, array& out) { void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
switch (in.dtype()) { switch (reduce_type_) {
case bool_: case Reduce::And:
reduce_dispatch_out<bool>(in, out, reduce_type_, axes_); case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
case uint8: }
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_); case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
case uint16: }
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_); case Reduce::Max:
break; case Reduce::Min: {
case uint32: switch (in.dtype()) {
reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_); case bool_:
break; reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
case uint64: break;
reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_); case uint8:
break; reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
case int8: break;
reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_); case uint16:
break; reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
case int16: break;
reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_); case uint32:
break; reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
case int32: break;
reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_); case uint64:
break; reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
case int64: break;
reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_); case int8:
break; reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
case float16: break;
reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_); case int16:
break; reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
case float32: break;
reduce_dispatch_out<float>(in, out, reduce_type_, axes_); case int32:
break; reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
case bfloat16: break;
reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_); case int64:
break; reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
case complex64: break;
reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_); case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break; break;
}
} }
} }

Some files were not shown because too many files have changed in this diff.