Compare commits

..

294 Commits

Author SHA1 Message Date
Awni Hannun
bf7cd29970 version bump (#698) 2024-02-16 08:44:08 -08:00
Nripesh Niketan
a000d2288c feat: update black pre-commit hook to 24.2.0 (#696) 2024-02-16 06:01:59 -08:00
Mike Drob
165abf0e4c Auto-run PRs from contributors (#692) 2024-02-15 17:30:35 -08:00
Srimukh Sripada
818cda16bc Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
toji
85143fecdd improved error msg for invalid axis(mx.split) (#685)
* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8 Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995 Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32 Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Awni Hannun
1eb04aa23f Fix empty array construction in cpp (#684) 2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91 Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0

* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00
Vijay Krish
2fdc2462c3 Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
there kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Hinrik Snær Guðmundsson
be6e9d6a9f Fixed wording in extensions.rst (#678)
changed "learn how add" -> "learn how to add"
2024-02-13 08:39:02 -08:00
Gabrijel Boduljak
e54cbb7ba6 Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
Angelos Katharopoulos
40c108766b Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Mike Drob
4cc70290f7 PR Builder Workflow (#659) 2024-02-12 17:47:21 -08:00
Awni Hannun
74caa68d02 nit in readme (#675) 2024-02-12 12:25:04 -08:00
Awni Hannun
3756381358 Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6 quote file name (#670) 2024-02-11 10:33:30 -08:00
Nripesh Niketan
0dbc4c7547 feat: Update pre-commit-config.yaml (#667) 2024-02-11 06:08:20 -08:00
Vijay Krish
06072601ce Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Angelos Katharopoulos
11d2c8f7a1 Linux build for CI of other packages (#660) 2024-02-09 18:17:04 -08:00
Awni Hannun
7f3f8d8f8d Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185 Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Angelos Katharopoulos
221f8d3fc2 Bump the version to 0.2 (#656) 2024-02-08 11:27:12 -08:00
Awni Hannun
5c03efaf29 Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
LeonEricsson
7dccd42133 updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Angelos Katharopoulos
28eac18571 Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Noah Farr
5fd11c347d Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Aryan Gupta
ef73393a19 Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
Angelos Katharopoulos
ea406d5e33 CI change (#645)
* CI update

* Skip large binary test for now

* Upgrade pip

* Add proper env variable skipping

* Update the CI

* Fix workflow name

* Set the low memory flag for the tests

* Change build process

* Add pip upgrade

* Use a venv

* Add a missing env activate

* Add setuptools

* Add twine upload back

* Re-enable automatic release builds
2024-02-07 06:04:34 -08:00
Awni Hannun
146bd69470 Skip compile when transforming (#635)
* skip compile when transforming

* simplify message
2024-02-05 21:28:37 -08:00
Jagrit Digani
316ff490b3 Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Avikant Srivastava
31fea3758e feat: enhancement of the error message for mlx.core.mean (#608)
* add error message
2024-02-05 01:21:49 -08:00
Awni Hannun
e319383ef9 Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
Awni Hannun
5c3ac52dd7 fix test (#627) 2024-02-04 16:18:03 -08:00
David Koski
ebfd3618b0 fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Avikant Srivastava
11a9fd40f0 fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
Daniel Strobusch
4fd2fb84a6 make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19 fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
minghuaw
16750f3c51 Fix typo in CMakeLists.txt (#616) 2024-02-03 05:59:26 -08:00
Awni Hannun
95b5fb8245 minor changes (#613) 2024-02-02 11:48:35 -08:00
AtomicVar
83f63f2184 Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
Awni Hannun
cb6156d35d Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Piotr Rybiec
506d43035c typo fix (#607) 2024-02-01 17:39:55 -08:00
Angelos Katharopoulos
36cff34701 Bump the version (#604) 2024-02-01 11:41:38 -08:00
Awni Hannun
e88e474fd1 Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00
David Koski
601c6d6aa8 Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365 Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
Vijay Krish
fcc5ac1c64 Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
nathan
bad67fec37 Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
199aebcf77 Change the variance computation (#319) 2024-01-30 19:28:56 -08:00
Angelos Katharopoulos
0de5988f92 Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5 Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Jagrit Digani
375446453e Update Compute Pipeline Creation API (#581)
* Add option to specialize metal functions on function constants
* Update Compute Pipeline Creation API
* Add options to make libraries from source and stitching
* Update function specialization name options
2024-01-30 15:42:36 -08:00
Angelos Katharopoulos
1895d34c20 Fix log1p with inf inputs (#592) 2024-01-30 14:02:50 -08:00
Awni Hannun
09b9275027 Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454 Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jacket
3f7aba8498 Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
Angelos Katharopoulos
65d0b8df9f Fix binary op dispatch (#584) 2024-01-29 19:36:17 -08:00
Awni Hannun
3c2f192345 Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Angelos Katharopoulos
37d98ba6ff No gil eval (#565) 2024-01-26 22:03:52 -08:00
Awni Hannun
8993382aaa Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
Awni Hannun
07f35c9d8a Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002 Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9 Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
David Koski
874b739f3c Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
taher
077c1ee64a QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Rifur13
2463496471 [Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
Angelos Katharopoulos
87b7fa9ba2 Bump the version (#554) 2024-01-25 11:01:05 -08:00
Danilo Peixoto
624065c074 Fix package installation for CI (#521)
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-25 09:43:34 -08:00
Awni Hannun
f27ec5e097 More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
2024-01-24 09:58:33 -08:00
Awni Hannun
f30e63353a Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Juarez Bochi
4fe2fa2a64 GGUF: Avoid dequantization when format is compatible (#426)
* GGUF: Don't dequantize q4_1

* Fix weight order. First in low bits

* Add unpacking for q4_0

* Don't dequantize q8_0

* rebase quants and split file

* don't quantize every weight

* reapply patch

* error handling

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:43:57 -08:00
Hazem Essam
37fc9db82c Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statment

* Removed parameter rms from the optimizer state because its not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137 Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename packakge to nn.innit

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
Awni Hannun
86e0c79467 remove stale benchmarks (#527) 2024-01-22 22:17:58 -08:00
Awni Hannun
98c37d3a22 use axes in tensordot (#525) 2024-01-22 21:17:00 -08:00
Sugato Ray
f326dd8334 Update README.md (#524)
Add conda install option in docs.
2024-01-22 20:53:54 -08:00
Jagrit Digani
6d3bee3364 Fix oob reads in gemv kernel (#523) 2024-01-22 12:06:04 -08:00
Danilo Peixoto
ecb174ca9d Type annotations for mlx.core module (#512) 2024-01-21 12:53:12 -08:00
Awni Hannun
7a34e46677 Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Nripesh Niketan
92c22c1ea3 feat: Update isort version to 5.13.2 (#514) 2024-01-21 06:11:48 -08:00
Awni Hannun
d52383367a format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d Add ValuError message for Adamax (#508)
* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Awni Hannun
b207c2c86b Power VJP fix for 0 (#505) 2024-01-20 01:17:40 -08:00
Awni Hannun
6bf779e72b fix array from list for > 32 bit types (#501) 2024-01-19 15:49:25 -08:00
Juarez Bochi
ddf50113c5 GGUF: Load and save metadata (#446)
* gguf metadata
---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-19 14:06:05 -08:00
Arda Orçun
6589c869d6 Added MSE message (#500)
* Added MSE message

* changed wrong line.

* Update examples/python/linear_regression.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:27:50 -08:00
Anchen
f6feb61f92 feat: add support for saving safetensors in the save_weights (#497)
* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:19:33 -08:00
Awni Hannun
c4ec836523 fix isinf for integer types (#494) 2024-01-19 05:31:10 -08:00
AtomicVar
550d4bf7c0 Update binary_cross_entropy function to handle both logits and probabilities (#492) 2024-01-18 19:22:23 -08:00
Awni Hannun
f6e911ced0 version bump (#490)
* version bump

* Fix the dev version string

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-18 12:00:24 -08:00
Awni Hannun
3d99a8d31d Fix format / build (#489) 2024-01-18 10:01:59 -08:00
Ethan
a749a91c75 Support disable metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens

* Run format and add "cache_enabled" feature tests
2024-01-18 08:33:34 -08:00
toji
49a52610b7 Added formatter structure and a boolean value formatter (#354)
* added formatter structure and a boolean value formatter

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 07:49:41 -08:00
AtomicVar
d1fef34138 Add Gaussian NLL loss function (#477)
* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 06:44:44 -08:00
Angelos Katharopoulos
9c111f176d Fix split optimization for array iterator (#484) 2024-01-18 05:50:25 -08:00
Awni Hannun
78e5f2d17d usage doc for function transformations (#481) 2024-01-17 17:10:53 -08:00
Angelos Katharopoulos
90c234b7ac Fix round to round half-cases to even (#482) 2024-01-17 15:27:23 -08:00
Angelos Katharopoulos
135fd796d2 Fix detach for multi-output primitives (#480) 2024-01-17 14:08:07 -08:00
Jagrit Digani
78102a47ad Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Diogo
556cdf0e06 Resolves build issues with the extension example (#419)
* resolved extension build issues and added test to ci

* missing gguflib

* rebased

* force mlx install from fix branch

* linux build issue

* point to git install and comment out ci tests
2024-01-17 12:07:05 -08:00
Awni Hannun
275db7221a Command buffer reports errors (#479)
* command buffer reports errors

* typo

* simplify
2024-01-17 11:53:30 -08:00
AtomicVar
4a9012cba0 Sort some APIs docs by names (a-z) (#472) 2024-01-16 19:37:50 -08:00
Awni Hannun
a2bf7693dd Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-16 19:03:53 -08:00
Angelos Katharopoulos
d8fabaa12b Split multi output (#461)
* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
2024-01-16 13:33:55 -08:00
Avikant Srivastava
4e290d282f feat: add time based seed to random.h (#457)
* random seed from time

* fix: chrono

* refactor: snake case
2024-01-16 07:32:28 -08:00
Yashraj Singh
e72458a3fa implemented isposinf and isneginf in one PR (#470)
* ran precommit

* updated docs
2024-01-16 06:48:07 -08:00
Awni Hannun
a2ffea683a Fix eye for larger matrices (#463)
* fix eye
* fix scatter for <32bit (non native atomic) types
* fix int overflow
2024-01-16 00:51:24 -08:00
Angelos Katharopoulos
c15fe3e61b Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
2024-01-16 00:46:21 -08:00
Tristan Bilot
f44c132f4a Add scatter_min VJP (#462) 2024-01-16 00:37:40 -08:00
Matthew Ernst
92a2fdd577 Adds isinf (#445)
* adds isinf

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>

* use stream + nits

* typo

---------

Signed-off-by: matthewfernst <matthew.f.ernst@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-15 19:50:44 -08:00
Tristan Bilot
6022d4129e scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
2024-01-14 14:12:15 -08:00
Awni Hannun
4bc446be08 Use a dummy primitive to only sync with one output (#453)
* Use a dummy primitive to only sync with one output
* Fix test and choose stream with slight care
2024-01-14 14:09:40 -08:00
Awni Hannun
41cc7bdfdb Fix stub generation, change graph exporting for arrows to go to outputs (#455) 2024-01-14 14:06:16 -08:00
Awni Hannun
6e81c3e164 Sync only with outputs we need to sync with (#447) 2024-01-13 01:47:25 -08:00
Diogo
2e29d0815b Add tile op (#438) 2024-01-12 23:03:16 -08:00
Awni Hannun
1b71487e1f docs (#444) 2024-01-12 13:34:16 -08:00
Ayush Shridhar
1416e7b664 Add isnan (#423) 2024-01-12 11:16:48 -08:00
davidkoski
29081204d1 array.swapaxes should point to swapaxes free function (#441) 2024-01-12 11:06:16 -08:00
Angelos Katharopoulos
006d01ba42 Fix packaging of gguflib (#435) 2024-01-11 13:56:03 -08:00
Awni Hannun
46dc24d835 version bump (#433) 2024-01-11 12:29:35 -08:00
Awni Hannun
c9934fe8a4 Metal validation (#432)
* tests clear metal validation

* add cpp test with metal validation to circleci

* nit
2024-01-11 11:57:24 -08:00
Avikant Srivastava
975e265f74 feat: Add numpy constants (#428)
* add numpy constants

* feat: add unittests

* add newaxis

* add test for newaxis transformation

* refactor
2024-01-11 06:47:29 -08:00
Awni Hannun
c92a134b0d more docs (#421)
* more docs

* fix link

* nits + comments
2024-01-10 14:04:12 -08:00
Awni Hannun
3b4f066dac Correct types for vjp + tests (#418)
* correct types for vjp + tests

* fix build + comment
2024-01-10 13:32:37 -08:00
Juarez Bochi
b7f905787e GGUF support (#350)
* Initial GGUF support for tensor fields.

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-10 13:22:48 -08:00
Chunyang Wen
e3e933c6bc Add type hint for Module (#412) 2024-01-10 11:23:42 -08:00
Awni Hannun
1d90a76d63 in place ops behave in place, fix some overloads (#411) 2024-01-09 16:05:38 -08:00
Angelos Katharopoulos
961435a243 Scatter vjp (#394)
* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
2024-01-09 13:36:51 -08:00
Awni Hannun
e9ca65c939 Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape

* nit
2024-01-09 11:54:51 -08:00
Dwayne Robinson
753867123d Fix data_types.rst uint64 (#406)
uint64 correctly says 8 bytes, but the description is copy pasta.
2024-01-09 06:40:10 -08:00
Awni Hannun
f099ebe535 Multi output primitives (#330)
* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-08 16:39:08 -08:00
BigsnarfDude
f45f70f133 Update mlx-example link for llms llama in llama-inference.rst (#405) 2024-01-08 16:29:53 -08:00
YUN, Junwoo
0b8aeddac6 Additoinal losses (#336)
* cosine similarity loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>

* Docstring nits
2024-01-08 14:01:13 -08:00
Jagrit Digani
432ee5650b Update cpp tests with allclose and doctest::Approx for numerical tolerance (#401) 2024-01-08 09:35:05 -08:00
Nripesh Niketan
73321b8097 feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 07:00:05 -08:00
Hazem Essam
022a944367 Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 06:13:16 -08:00
Chris Costes
026ef9aae4 Update Install Instructions (#397)
* Add note to install instructions for building from source to ensure native arm64 environment and tools.

* Add troubleshooting info.

* remove cmake bits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-07 19:11:04 -08:00
Angelos Katharopoulos
a611b0bc82 Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
2024-01-07 15:16:51 -08:00
Diogo
449b43762e Add inner / outer op (#348)
* inner / outer impl

* python tests

* ops list and ack

* updated descriptions

* use test helper

* removed dtype check and flatten outer to 1-D

* updated docs

* just use the reshape to flatten
2024-01-07 09:01:09 -08:00
Angelos Katharopoulos
6ea6b4258d Fix style check (#395) 2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a Add theta cache for Rope and mask cache for ALiBi (#375) 2024-01-07 00:22:58 -08:00
Awni Hannun
c6d2878c1a safely divide for 0 size inputs (#388) 2024-01-07 00:19:54 -08:00
Awni Hannun
b34bf5d52b fix saving for non-contiguous arrays (#389) 2024-01-06 12:44:02 -08:00
Angelos Katharopoulos
608bd43604 Move the matmul type check in the op (#384) 2024-01-05 19:10:13 -08:00
Angelos Katharopoulos
4c48f6460d Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests

* Fix tf test
2024-01-05 18:17:44 -08:00
Daniel Strobusch
1331fa19f6 Make array conform to the Python Buffer Protocol (#323) 2024-01-05 15:58:33 -08:00
Daniel Strobusch
dfdb284e16 make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identical when `dtype=None` or no dtype is passed

This important when passing kw args down from a numpy function like:

```
def f(x, dtype=None):
  mx.random.uniform(dtype=dtype)
  # ...
```

NumPy functions behave like this.

It also fixes a minor bug in `tri`: #378

Closes #378
2024-01-05 09:37:46 -08:00
mutexuan
d8f41a5c0f support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: get back Dtype constness and let stack to handle type promotions automatically
2024-01-04 18:53:33 -08:00
Awni Hannun
b9e415d19c bump pre commit and fix format (#373) 2024-01-04 16:28:52 -08:00
davidkoski
c82a8cc526 move all ObjC (via metal-cpp) interaction until post static initializers (#370)
* move all ObjC (via metal-cpp) interaction until post static initializers

- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil

- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed

* Update device.cpp

remove commented code

* Update device.cpp

remove commented out code

* Update scheduler.h

update comment

* per discussion use the pool inside the task() -- this will be metal only, not needed for cpu

* Update allocator.cpp

move pool to release/alloc area
2024-01-04 16:12:00 -08:00
Angelos Katharopoulos
75dc537e44 Fix the sigmoid module (#371) 2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5 revert copy (#366) 2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160 Remove useless pass (#364)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-04 06:34:01 -08:00
Awni Hannun
d752f8e142 Fix CI (#359)
* fix ci

* check for linux for fp16
2024-01-04 06:33:08 -08:00
toji
d2467c320d Added support for python copy (#335)
* Added support for python copy

* precommit changes

* removed `_compiled_call_impl` line

* added tests and suggested changes

* ACK changes
2024-01-03 20:59:40 -08:00
Diogo
0d31128a44 use union instead of | (#358) 2024-01-03 19:33:19 -08:00
Diogo
1ac18eac20 simple numpy helper for tests (#352) 2024-01-03 19:19:19 -08:00
Awni Hannun
526466dd09 version bump (#355)
* version bump

* one more
2024-01-03 14:48:24 -08:00
Angelos Katharopoulos
e7f5059fe4 Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Nripesh Niketan
d7ac050f4b feat: Add contributors graph to README (#332)
* Fix: typo in README.md

* feat: Add contributors graph to README

* Update acknowledgments and contributors
2024-01-03 13:03:11 -08:00
Gabrijel Boduljak
c7edafb729 implemented InstanceNorm (#244)
* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-03 12:21:15 -08:00
Awni Hannun
dff4a3833f Module checks the weight on load_weights (#337)
* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
2024-01-02 18:55:42 -08:00
Diogo
0782a4573a Add Tensordot op (#344) 2024-01-02 17:15:00 -08:00
Diogo
af66a09bde Adds issue template with common questions (#345)
* added template

* remove label
2024-01-02 16:52:20 -08:00
Angelos Katharopoulos
436bec9fd9 Fix the implementation of the Bilinear layer (#347) 2024-01-02 16:46:18 -08:00
Awni Hannun
99c80a2c8b Memory allocation (#292)
* try alternative gc

* try no cache

* add forced swap

* remove cache for now

* add cache back

* change fit crtieria

* remove unused function

* nit in comment

* tune / fix allocation

* increase block limit to original
2024-01-02 11:59:19 -08:00
Asaf Zorea
295ce9db09 Feature expand nn linear (#315)
* Added an identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00
Josh Soref
44c1ce5e6a Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Chunyang Wen
144ecff849 Remove useless import (#340)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-01 19:25:49 -08:00
mutexuan
350095ce6e fix type cast error in item() for bfloat16 (#339)
Co-authored-by: xuan <xuan@apple.com>
2024-01-01 19:02:04 -08:00
Nripesh Niketan
e09bf35b28 feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
2023-12-31 14:01:21 -08:00
Daniel Strobusch
99c20f523e fix typos (#327) 2023-12-31 06:06:47 -08:00
Hazem Essam
e3b8da2a49 Added implementation for Scaled RoPE. (#261)
* Added scale for RoPE

* Ran pre-commit

* Added RoPE scaling test

* Added docstring for scale parameter

* Modified docstrings
2023-12-31 06:06:01 -08:00
Angelos Katharopoulos
a020a2d49d Improve repeat using broadcasting and reshape (#318) 2023-12-29 21:40:20 -08:00
Nripesh Niketan
930b159885 Fix: typo in README.md (#316) 2023-12-29 12:58:00 -08:00
Nripesh Niketan
5ad8fb7268 feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function

* run pre-commit

* Add Softsign activation function

* Add Softsign activation function

* Add documentation for ReLU6, Softplus, and Softsign activations

* Update activation functions in neural network layers

* Add LogSoftmax and Hardswish activations

* run pre-commit

* Update activations.py

* Added acknowledgements

* Fix activation function comments

* Fix activation functions in neural network layers
2023-12-29 11:49:36 -08:00
Chunyang Wen
2aedf3e791 Minor refactor for tree_map and tree_unflatten (#311)
* Minor refact for tree_map and tree_unflatten

* Remove the if statement

---------

Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 20:55:10 -08:00
Chunyang Wen
473b6b43b4 Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 14:46:13 -08:00
Angelos Katharopoulos
d29770eeaa Update batchnorm to have the running stats in parameters (#305) 2023-12-28 14:31:10 -08:00
Chunyang Wen
040c3bafab Add missing f str (#306)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 06:09:34 -08:00
Chunyang Wen
05767b026f Add information for dropout probability (#304)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-27 21:51:30 -08:00
Diogo
a83d5d60bd Addition in acknowledgements (#302) 2023-12-27 13:46:47 -08:00
Bahaa
ff2b58e299 Add support for repeat (#278)
* add repeat function

* fix styling

* optimizing repeat

* fixed minor issues

* not sure why that folder is there xD

* fixed now for sure

* test repeat not repeat test

* Fixed

---------

Co-authored-by: Bahaa Eddin tabbakha <bahaa@Bahaas-MacBook-Pro.local>
2023-12-27 13:11:38 -08:00
YUN, Junwoo
4417e37ede Transformer fix (#167)
* add transformer with dropout, fix transformer ffm, layernorm order

* precommit changes

* precommit changes

* add docstring, activation, norm_first

* run precommit

* run precommit

* add doctstring

* precommit

* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 08:48:36 -08:00
Angelos Katharopoulos
79c95b6919 Fix load compilation (#298) 2023-12-27 06:20:45 -08:00
Diogo
1f6ab6a556 Safetensor support (#215)
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 02:06:55 -08:00
Gabrijel Boduljak
6b0d30bb85 linalg.norm (#187)
* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-26 19:42:04 -08:00
Angelos Katharopoulos
447bc089b9 Fix tolerance in de-/quantization test (#295) 2023-12-26 19:21:05 -08:00
Yutaka Kondo
fc4e5b476b Fix llama link in README.md (#289) 2023-12-25 20:53:20 -08:00
Daniel Strobusch
d58ac083f3 expose itemsize and nbytes as for numpy arrays (#284)
see:
  * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.nbytes.html
  * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.itemsize.html

relates to https://github.com/ml-explore/mlx-examples/pull/174
2023-12-25 10:34:28 -08:00
__mo_san__
a123c3c7d2 implement-batch-norm-layer (#217)
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-25 07:32:53 -08:00
Angelos Katharopoulos
9e6b8c9f48 Refactor the reduction kernels (#277) 2023-12-24 14:47:57 -08:00
Zach Schillaci
22fee5a383 Remove redundant assert in losses.py (#281) 2023-12-24 08:39:08 -08:00
Daniel Strobusch
7365d142a3 random.uniform must respect dtype, even if lower precision than "low" (#280)
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
2023-12-24 07:04:43 -08:00
Awni Hannun
8b227fa9af fix no metal build (#276) 2023-12-23 19:18:10 -08:00
Vidit Agarwal
8c3da54c7d Fix failing test for log cosh loss (#275)
* fix assert statement in log_cosh_loss

* reformatted by pre-commit black
2023-12-23 16:26:46 -08:00
Vidit Agarwal
acf1721b98 Corrected the example of value_and_grad (#274)
* Corrected the example for mx.value_and_grad

* Reformat through pre-commit/black
2023-12-23 11:06:38 -08:00
Finn Voorhees
f91f450141 Fix argmax returns documentation (#263) 2023-12-22 20:33:17 -08:00
Ronan Collobert
cd3616a463 Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Nicholas Santavas
d35fa1db41 Add Hinge, Huber and LogCosh losses (#199) 2023-12-22 10:28:10 -08:00
Justin Deschenaux
e8deca84e0 Add dropout2d (#250) 2023-12-22 08:02:29 -08:00
Angelos Katharopoulos
8385f93cea Bumping the version (#256) 2023-12-21 18:33:14 -08:00
Awni Hannun
2118c3dbfa fix (#255) 2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52 A temporary fix (#254) 2023-12-21 17:59:15 -08:00
Angelos Katharopoulos
1d053e0d1d Fix the alibi test that was left unchanged (#252) 2023-12-21 14:59:25 -08:00
Hazem Essam
0aa65c7a6b Added ALiBi implementation (#232) 2023-12-21 14:36:38 -08:00
Daniel Strobusch
794feb83df support arange for bfloat16 (#245) 2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
2c7df6795e Make sure that arrays are freed when saving (#247) 2023-12-21 14:08:24 -08:00
Angelos Katharopoulos
b3916cbf2b Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8 Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
Justin Deschenaux
4912ff3ec2 Add Lion optimizer (#209)
* Add Lion optimizer
* Update acknowledgements also with past contributions
2023-12-20 13:54:58 -08:00
Awni Hannun
f40d17047d Indexing bug (#233)
* fix

* test
2023-12-20 10:44:01 -08:00
Angelos Katharopoulos
2807c6aff0 Implements divide for integer types and adds floor_divide op (#228)
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
davidkoski
de892cb66c fix for non-macos build issue on cblas.h (#227) 2023-12-19 17:01:59 -08:00
davidkoski
37024d899c fixes for building with swiftpm (#225)
- clbas is part of veclib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Diogo
137f55bf28 fail early if readinto does not exist (#221) 2023-12-19 13:27:17 -08:00
Emircan Erol
e549f84532 Triplet Loss (#211)
* Triplet Loss

* Requested Changes

* Margin to alpha
2023-12-19 12:37:12 -08:00
Angelos Katharopoulos
dfa9f4bc58 An initial quantized matmul implementation (#205)
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Abe Leininger
e6872a4149 Added linspace (#181)
* linspace ops support

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Juarez Bochi
f4f6e17d45 Fix cross-attention (#210)
* Fix cross-attention

With the current code, ln2 is a no-op. Its output should be passed to the cross-attention layer

* Add name to contributors
2023-12-18 12:27:27 -08:00
Angelos Katharopoulos
4d4af12c6f Adds round op and primitive (#203) 2023-12-18 11:32:48 -08:00
Awni Hannun
477397bc98 Citation + Contributor acknowledgment section (#207)
* cite

* nits

* nits

* comment
2023-12-18 10:07:00 -08:00
jojopuppet
18cca64c81 Add smoothed L1 loss and enhancements to cross entropy loss (#166)
* Add smooth_l1_loss
* Add labels moothing for cross entropy loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 07:26:21 -08:00
Awni Hannun
0e5807bbcb include optional (#202) 2023-12-17 22:01:35 -08:00
Cyril Zakka, MD
8eb56beb3a Added clip function (#159)
* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
ee0c2835c5 Docs updates (#198)
Reorganize NN docs + a few other tidbits.
2023-12-17 13:20:55 -08:00
Awni Hannun
90d04072b7 fix build w/ flatten (#195) 2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52 implemented Flatten Module (#149)
* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
YUN, Junwoo
eebd7c275d Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes (#142)
* Add AdaMax, AdaDelta, RMSprop
2023-12-16 21:43:15 -08:00
Austin Liu
a67bbfe745 Update docs (#177) (#190)
* update docs (fix #177)

* reorder

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 06:52:18 -08:00
Awni Hannun
104c34f906 setite negative indexing bug (#189) 2023-12-16 06:44:47 -08:00
Diogo
dc2edc762c added tri / tril / triu (#170)
* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-15 17:30:34 -08:00
Awni Hannun
2e02acdc83 add base kwarg to rope (#186) 2023-12-15 16:47:59 -08:00
Ronan Collobert
83f266c44c Lazy metal_device_ initialization (#185)
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Víctor Aguilar
f24200db2c accross -> across (#183) 2023-12-15 13:46:50 -08:00
Jason
e28b57e371 Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
2023-12-14 13:21:19 -08:00
Awni Hannun
e5851e52b1 Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
2023-12-14 12:59:12 -08:00
Diogo
f55908bc48 Added stubs for python files generated from C++ (#136)
* added pybind11-stubgen

* docs for generating stubs

* added line to readme
2023-12-14 12:58:45 -08:00
Luca Arnaboldi
b93c4cf378 Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Stv.X
1e0c78b970 Fixed typo in some proprietary terms. (#161) 2023-12-13 19:48:00 -08:00
Awni Hannun
76e1af0e02 bump version (#157) 2023-12-13 14:28:26 -08:00
Ikko Eltociear Ashimine
c3272d4917 Update conv.cpp (#145)
Peform -> Perform
2023-12-12 11:27:49 -08:00
SputNikPlop
50f5d14b11 fix: tidy pull request template (#143)
* fix: tidy pull request template

* fix: feedback from awni
2023-12-12 08:14:39 -08:00
noahsmartin
d14a0e4ff9 Docs update (#144) 2023-12-12 07:53:42 -08:00
Diogo
fb675de30d Run lint check for prs (#139) 2023-12-12 00:23:33 -08:00
Awni Hannun
25f70d4ca4 Fix divide types + floor divide (//) (#138)
* divide types

* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0 Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arvix reference to mish

* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
2023-12-11 17:04:07 -08:00
Awni Hannun
b9226c367c Fix CI format + build issue (#137)
* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2023-12-11 15:01:41 -08:00
Angelos Katharopoulos
3214629601 Mlx array accessor (#128)
* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
2023-12-11 13:42:55 -08:00
__mo_san__
072044e28f fix and update binary cross entropy loss tests (#133)
* fix conflicts

* updated tests
2023-12-11 12:42:17 -08:00
Cyril Zakka, MD
e080290ba4 Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
69505b4e9b fixes (#131) 2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44 Add Binary Cross Entropy loss (#122)
* update BCE added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Yiyang(Steven) Yu
0cfbfc9904 Update README.md (#121) 2023-12-10 14:47:37 -08:00
Awni Hannun
2d0130f80f fix loss tests (#118)
* fix loss tests

* use none as default
2023-12-10 10:08:19 -08:00
__mo_san__
c1e1c1443f Added Adagrad optimizer (#102) 2023-12-10 09:22:39 -08:00
Henry Ansah
68bf1d7867 add nn module for sigmoid activation (#111)
* add nn module for sigmoid activation

* update .gitignore with .cache folder generated by jetbrains fleet ide

* remove .cache folder
2023-12-10 07:00:39 -08:00
Angelos Katharopoulos
600db7d754 Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
__mo_san__
ef7b8756c0 Add tanh activation function (#115)
* added Adagrad optimizer ...

* added Tanh activation function ...

* reformatted file ...

* remove unrelated stuff ...

* Update activations.py
2023-12-09 19:25:38 -08:00
Enoch Kan
0b28399638 added mse_loss, nll_loss and kl_div_loss (#98)
* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Joe Barrow
ac6dc5d3eb Adding optional bias param to MultiHeadAttention (#104)
* Adding optional  param to

* Run style-checker
2023-12-09 11:04:28 -08:00
Awni Hannun
89b90dcfec Pr template (#99)
* pr template
* format fix
2023-12-09 09:36:56 -08:00
Angelos Katharopoulos
fd836d891b Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
2023-12-09 09:35:28 -08:00
AtomicVar
976e8babbe Use compiled black as the pre-commit formatter (#94) 2023-12-09 07:06:46 -08:00
Awni Hannun
2520dbcf0a add losses to the docs, fix black failur (#92) 2023-12-09 06:06:52 -08:00
Abe Leininger
430bfb4944 Adds Nesterov momentum to SGD (#87) 2023-12-08 23:23:36 -08:00
ShiJZ
08d51bf232 Make it easier to test new optimizers implemented: no need to change test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
2023-12-08 21:39:08 -08:00
Kai Ma
cb9e585b8e Style fix for loss functions (#91)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none

* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484 MLE and L1 loss functions (#88)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
2023-12-08 20:21:37 -08:00
Angelos Katharopoulos
2b714714e1 Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Joe Barrow
69a24e6a1e AdamW implementation (#72)
* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
2023-12-08 14:45:34 -08:00
Zach Schillaci
5b9be57ac3 Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
Jagrit Digani
e89c571de7 Update cmake to detect and throw warnings if not on a arm system (#81)
* Update cmake to detect and throw warnings if not on a arm system
2023-12-08 11:03:25 -08:00
Angelos Katharopoulos
209404239b Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Awni Hannun
4e3bdb560c random generation fix (#80)
Random generation fix
2023-12-08 10:40:57 -08:00
Gautam krishna R
86b614afcd added long_description for pypi readme (#69) 2023-12-08 03:33:29 -08:00
Awni Hannun
cfc39d84b7 Some docs on unified memory (#62)
* doc on unified memory
2023-12-07 19:42:24 -08:00
Zach Schillaci
d11d77e581 Spelling fixes in transformer.py (#59) 2023-12-07 13:32:09 -08:00
Jagrit Digani
bf410cb85e Update CMake to not try and build metallib if Metal framework not found (#55) 2023-12-07 09:48:42 -08:00
rushyam
2e126aeb7e Feature Addition: Encoder-Decoder Transformer Architecture (#50)
* Implemented decoder-transformer-layer, decoder-transformer  and introduce encoder-decoder transformer

* added relu layer

* add src, tgt, memory mask

---------

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
2023-12-07 07:37:36 -08:00
Awni Hannun
dfbc52ce56 Install docs + python versions (#53)
* install + python versions

* add link in install docs

* add link
2023-12-07 07:29:17 -08:00
275 changed files with 31408 additions and 5910 deletions

View File

@@ -1,5 +1,8 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -7,6 +10,9 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
linux_build_and_test:
@@ -26,18 +32,28 @@ jobs:
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run:
name: Build python package
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Run the python tests
name: Generate package stubs
command: |
python3 -m unittest discover python/tests
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
@@ -47,154 +63,180 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
machine: true
resource_class: ml-explore/m-builder
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Build python package
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
- run:
name: Run the python tests
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
source env/bin/activate
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
xcode_version:
type: string
default: "14"
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
brew install python@<< parameters.python_version >>
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
- run:
name: Build pacakge
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
build_linux_test_release:
parameters:
python_version:
type: string
default: "3.9"
macos_version:
extra_env:
type: string
default: "14"
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Install dependencies
name: Build wheel
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
- run:
name: Build pacakge
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
pip install auditwheel
pip install patchelf
pip install build
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
- run:
name: Build pacakge
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
pip install . -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- store_artifacts:
path: dist/
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- linux_build_and_test
- mac_build_and_test
- linux_build_and_test
- build_release:
filters:
tags:
@@ -204,20 +246,53 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
when: << pipeline.parameters.nightly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_package:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
weekly_build:
when: << pipeline.parameters.weekly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_dev_release:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]

28
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,28 @@
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Include code snippet
```python
```
**Expected behavior**
A clear and concise description of what you expected to happen.
**Desktop (please complete the following information):**
- OS Version: [e.g. MacOS 14.1.2]
- Version [e.g. 0.7.0]
**Additional context**
Add any other context about the problem here.

12
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,12 @@
## Proposed changes
Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
## Checklist
Put an `x` in the boxes that apply.
- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)

20
.github/workflows/pull_request.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files

8
.gitignore vendored
View File

@@ -6,11 +6,16 @@ __pycache__/
# C extensions
*.so
# tensor files
*.safe
*.safetensors
# Metal libraries
*.metallib
venv/
# Distribution / packaging
python/mlx/core
python/mlx/share
python/mlx/include
.Python
@@ -74,3 +79,6 @@ build/
# VSCode
.vscode/
.DS_Store
# Jetbrains
.cache

View File

@@ -1,9 +1,16 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v14.0.6
rev: v17.0.6
hooks:
- id: clang-format
- repo: https://github.com/psf/black
rev: 22.10.0
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args:
- --profile=black

View File

@@ -1,3 +1,24 @@
# Individual Contributors
If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
# Third-Party Software
MLX leverages several third-party software, listed here together with
their license copied verbatim.
@@ -231,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.24)
project(mlx LANGUAGES CXX)
project(mlx LANGUAGES C CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -18,7 +18,34 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.3)
set(MLX_VERSION 0.3.0)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
else()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
# ----------------------------- Lib -----------------------------
@@ -37,15 +64,18 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION)
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
message(STATUS "Building with SDK for MacOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
@@ -53,7 +83,7 @@ elseif (MLX_BUILD_METAL)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires MacOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
@@ -75,13 +105,13 @@ elseif (MLX_BUILD_METAL)
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (ACCELERATE_LIBRARY)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate not found, using default backend.")
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
@@ -93,16 +123,27 @@ else()
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
@@ -128,6 +169,8 @@ if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)
@@ -197,4 +240,4 @@ install(
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
)

View File

@@ -1,3 +1,4 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,8 +6,8 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:
@@ -16,24 +16,24 @@ Some key features of MLX include:
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Composable function transformations**: MLX has composable function
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
and computation graph optimization.
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
materialized when needed.
- **Dynamic graph construction**: Computation graphs in MLX are built
- **Dynamic graph construction**: Computation graphs in MLX are constructed
dynamically. Changing the shapes of function arguments does not trigger
slow compilations, and debugging is simple and intuitive.
- **Multi-device**: Operations can run on any of the supported devices
(currently, the CPU and GPU).
(currently the CPU and the GPU).
- **Unified memory**: A notable difference from MLX and other frameworks
is the *unified memory model*. Arrays in MLX live in shared memory.
Operations on MLX arrays can be performed on any of the supported
device types without moving data.
device types without transferring data.
MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
@@ -53,7 +53,7 @@ variety of examples, including:
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
@@ -61,17 +61,25 @@ variety of examples, including:
## Quickstart
See the [quick start
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
in the documentation.
## Installation
MLX is available on [PyPi](https://pypi.org/project/mlx/). To install the Python API, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
@@ -79,4 +87,28 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
on contributing to MLX.
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.
## Citing MLX
The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTex entry:
```
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
url = {https://github.com/ml-explore},
version = {0.0},
year = {2023},
}
```

View File

@@ -233,6 +233,20 @@ void time_gather_scatter() {
TIME(single_element_add);
}
void time_divmod() {
auto a = random::normal({1000});
auto b = random::normal({1000});
eval({a, b});
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
time_creation_ops();
@@ -246,4 +260,5 @@ int main() {
time_matmul();
time_reductions();
time_gather_scatter();
time_divmod();
}

View File

@@ -1,7 +1,6 @@
# Copyright © 2023 Apple Inc.
import numpy as np
from time_utils import time_fn

View File

@@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
import mlx.core as mx
from time_utils import time_fn
B = 8

View File

@@ -1,13 +1,14 @@
# Copyright © 2023 Apple Inc.
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
@@ -165,13 +166,13 @@ if __name__ == "__main__":
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes:

View File

@@ -1,14 +1,14 @@
# Copyright © 2023 Apple Inc.
import matplotlib.pyplot as plt
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import subprocess
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
results_dir = "./results"
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
ax.legend()
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []

View File

@@ -4,8 +4,10 @@ import argparse
import math
import os
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
def int_or_list(x):
@@ -22,6 +24,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return mx.float32
else:
dt = getattr(mx, x)
if not isinstance(dt, mx.Dtype):
raise ValueError(f"{x} is not an mlx dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -48,6 +60,63 @@ def matmul(x, y):
mx.eval(ys)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
ys = []
for i in range(10):
ys.append(
mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
)
)
mx.eval(ys)
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
"quant_matmul_128_4": partial(
_quant_matmul, transpose=False, group_size=128, bits=4
),
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
"quant_matmul_t_64_4": partial(
_quant_matmul, transpose=True, group_size=64, bits=4
),
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
"quant_matmul_t_128_4": partial(
_quant_matmul, transpose=True, group_size=128, bits=4
),
"quant_matmul_t_128_8": partial(
_quant_matmul, transpose=True, group_size=128, bits=8
),
}
def conv1d(x, y):
ys = []
for i in range(10):
@@ -95,7 +164,77 @@ def softmax_fused(axis, x):
def relu(x):
y = x
for i in range(100):
y = mx.maximum(y, 0)
y = nn.relu(y)
mx.eval(y)
def leaky_relu(x: mx.array):
y = x
for i in range(100):
y = nn.leaky_relu(y)
mx.eval(y)
def prelu(x: mx.array):
y = x
for i in range(100):
y = nn.prelu(y, mx.ones(1))
mx.eval(y)
def softplus(x: mx.array):
y = x
for i in range(100):
y = nn.softplus(y)
mx.eval(y)
def mish(x: mx.array):
y = x
for i in range(100):
y = nn.mish(y)
mx.eval(y)
def leaky_relu(x):
y = x
for i in range(100):
y = nn.leaky_relu(y)
mx.eval(y)
def elu(x):
y = x
for i in range(100):
y = nn.elu(y)
mx.eval(y)
def relu6(x):
y = x
for i in range(100):
y = nn.relu6(y)
mx.eval(y)
def softplus(x):
y = x
for i in range(100):
y = nn.softplus(y)
mx.eval(y)
def celu(x):
y = x
for i in range(100):
y = nn.celu(y)
mx.eval(y)
def log_sigmoid(x):
y = x
for i in range(100):
y = nn.log_sigmoid(y)
mx.eval(y)
@@ -130,6 +269,13 @@ def linear(w, b, x):
mx.eval(ys)
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
@@ -180,6 +326,20 @@ def topk(axis, x):
mx.eval(ys)
def step_function(x):
y = x
for i in range(100):
y = nn.step(x)
mx.eval(y)
def selu(x):
y = x
for i in range(100):
y = nn.selu(x)
mx.eval(y)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -211,9 +371,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument(
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
)
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -230,11 +388,15 @@ if __name__ == "__main__":
mx.set_default_device(mx.cpu)
else:
mx.set_default_device(mx.gpu)
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
args.dtype
]
types = args.dtype
if not types:
types = [mx.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(mx.random.normal(size).astype(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@@ -250,8 +412,14 @@ if __name__ == "__main__":
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark.startswith("quant_matmul"):
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))
@@ -277,6 +445,26 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))
@@ -311,5 +499,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError("Unknown benchmark")

View File

@@ -22,6 +22,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return torch.float32
else:
dt = getattr(torch, x)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{x} is not a torch dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@@ -115,6 +125,70 @@ def relu(x):
sync_if_needed(x)
@torch.no_grad()
def leaky_relu(x):
y = x
for i in range(100):
y = torch.nn.functional.leaky_relu(y)
sync_if_needed(x)
@torch.no_grad()
def elu(x):
y = x
for i in range(100):
y = torch.nn.functional.elu(y)
sync_if_needed(x)
@torch.no_grad()
def celu(x):
y = x
for i in range(100):
y = torch.nn.functional.celu(y)
sync_if_needed(x)
@torch.no_grad()
def relu6(x):
y = x
for i in range(100):
y = torch.nn.functional.relu6(y)
sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x
for i in range(100):
y = torch.nn.functional.softplus(y)
sync_if_needed(x)
@torch.no_grad()
def log_sigmoid(x):
y = x
for i in range(100):
y = torch.nn.functional.logsigmoid(y)
sync_if_needed(x)
@torch.no_grad()
def prelu(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
sync_if_needed(x)
@torch.no_grad()
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
return torch.nn.functional.mish(y)
sync_if_needed(x)
@torch.no_grad()
def scalar_mult(x):
y = x
@@ -209,6 +283,14 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
for i in range(100):
y = torch.nn.functional.selu(y)
sync_if_needed(x)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", help="Choose the benchmark to run")
@@ -240,7 +322,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@@ -255,9 +337,15 @@ if __name__ == "__main__":
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
types = args.dtype
if not types:
types = [torch.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(torch.randn(*size).to(device).to(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@@ -302,6 +390,28 @@ if __name__ == "__main__":
elif args.benchmark == "relu":
print(bench(relu, x))
elif args.benchmark == "leaky_relu":
print(bench(leaky_relu, x))
elif args.benchmark == "elu":
print(bench(elu, x))
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))
elif args.benchmark == "celu":
print(bench(celu, x))
elif args.benchmark == "log_sigmoid":
print(bench(log_sigmoid, x))
elif args.benchmark == "prelu":
print(bench(prelu, x))
elif args.benchmark == "mish":
print(bench(mish, x))
elif args.benchmark == "scalar_mul":
print(bench(scalar_mult, x))

View File

@@ -62,7 +62,7 @@ def make_predicate(positive_filter, negative_filter):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
parser.add_argument(
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
)
@@ -80,10 +80,8 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
if _filter(x)
else None
)
@@ -125,6 +123,14 @@ if __name__ == "__main__":
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
compare_filtered("argmax --size 10x1024x128 --axis 1")
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
@@ -193,6 +199,27 @@ if __name__ == "__main__":
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
compare_filtered("relu --size 32x16x1024")
compare_filtered("relu --size 32x16x1024 --cpu")
compare_filtered("leaky_relu --size 32x16x1024")
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
compare_filtered("elu --size 32x16x1024")
compare_filtered("elu --size 32x16x1024 --cpu")
compare_filtered("relu6 --size 32x16x1024")
compare_filtered("relu6 --size 32x16x1024 --cpu")
compare_filtered("softplus --size 32x16x1024")
compare_filtered("softplus --size 32x16x1024 --cpu")
compare_filtered("celu --size 32x16x1024")
compare_filtered("celu --size 32x16x1024 --cpu")
compare_filtered("log_sigmoid --size 32x16x1024")
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
compare_filtered("step --size 32x16x1024")
compare_filtered("step --size 32x16x1024 --cpu")
compare_filtered("selu --size 32x16x1024")
compare_filtered("selu --size 32x16x1024 --cpu")
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
compare_filtered("mish --size 32x16x1024 --cpu")
compare_filtered("prelu --size 32x16x1024")
compare_filtered("prelu --size 32x16x1024 --cpu")
compare_filtered("scalar_mul --size 32x16x1024")
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
compare_filtered("cross_entropy --size 256x1024")

View File

@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,118 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.nn as nn
import torch.mps
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -0,0 +1,56 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def scatter(dst, x, idx):
dst[idx] = x
mx.eval(dst)
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def gather(dst, x, idx, device):
dst[idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -1,8 +1,8 @@
# Copyright © 2023 Apple Inc.
import argparse
import mlx.core as mx
import mlx.core as mx
from time_utils import time_fn
@@ -44,6 +44,13 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

View File

@@ -12,7 +12,7 @@ include(CMakeParseArguments)
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of depedency files (like headers)
# DEPS: List of dependency files (like headers)
#
macro(mlx_build_metallib)
# Parse args
@@ -32,7 +32,7 @@ macro(mlx_build_metallib)
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metllib build command
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal

2
docs/.gitignore vendored
View File

@@ -1 +1,3 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

View File

@@ -26,7 +26,7 @@ python -m http.server <port>
and point your browser to `http://localhost:<port>`.
### Push to Github Pages
### Push to GitHub Pages
Check-out the `gh-pages` branch (`git switch gh-pages`) and build
the docs. Then force add the `build/html` directory:

View File

@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. add toctree option to make autodoc generate the pages
.. autoclass:: {{ objname }}
{% block attributes %}
{% if attributes %}
.. rubric:: Attributes
.. autosummary::
:toctree: .
{% for item in attributes %}
~{{ fullname }}.{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
{% block methods %}
{% if methods %}
.. rubric:: Methods
.. autosummary::
:toctree: .
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ fullname }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -1,19 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@@ -5,13 +5,15 @@
import os
import subprocess
import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = "0.0.4"
release = "0.0.4"
version = ".".join(mx.__version__.split(".")[:3])
release = version
# -- General configuration ---------------------------------------------------
@@ -24,6 +26,7 @@ extensions = [
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,

View File

@@ -15,7 +15,7 @@ Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficents ``alpha`` and ``beta``
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add
our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
@@ -69,7 +69,7 @@ C++ API:
.. code-block:: C++
/**
* Scale and sum two vectors elementwise
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@@ -150,7 +150,7 @@ back and go to our example to give ourselves a more concrete image.
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself accross
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
@@ -230,7 +230,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
This operation now handles the following:
#. Upcast inputs and resolve the the output data type.
#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
@@ -284,14 +284,14 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additonal mapping
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
@@ -305,7 +305,7 @@ if we encounter an unexpected type.
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while contructing the out array)
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -485,7 +485,7 @@ each data type.
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
@@ -537,7 +537,7 @@ below.
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel decelaration at axpby.metal
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@@ -568,7 +568,7 @@ below.
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divded among
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@@ -581,7 +581,7 @@ to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and commiting the associated
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
@@ -601,8 +601,8 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive
// The jvp transform on the primitive can built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
@@ -642,7 +642,7 @@ own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitve along given axis */
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -666,7 +666,7 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
@@ -697,7 +697,7 @@ are already provided, adding our :meth:`axpby` becomes very simple!
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
@@ -840,7 +840,7 @@ This will result in a directory structure as follows:
| ...
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same strucutre as
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
@@ -927,18 +927,18 @@ Results:
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
The full example code is available in `mlx <code>`_.
.. code: `TODO_LINK/extensions`_
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc

View File

@@ -371,7 +371,7 @@ Scripts
The full example code is available in `mlx-examples`_.
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llama
.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
Roformer: Enhanced transformer with rotary position embedding. arXiv

View File

@@ -61,7 +61,10 @@ set:
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, setup the problem parameters and load the data:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as `mnist`.
.. code-block:: python

View File

@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:
The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
@@ -35,8 +35,15 @@ are the CPU and GPU.
:caption: Usage
:maxdepth: 1
quick_start
using_streams
usage/quick_start
usage/lazy_evaluation
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams
.. toctree::
:caption: Examples
@@ -56,6 +63,7 @@ are the CPU and GPU.
python/random
python/transforms
python/fft
python/linalg
python/nn
python/optimizers
python/tree_utils

View File

@@ -1,8 +1,8 @@
Build and Install
=================
Install from PyPI
-----------------
Python Installation
-------------------
MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is
@@ -11,9 +11,40 @@ silicon computer is
pip install mlx
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.3
.. note::
MLX is only available on devices running MacOS >= 13.3
It is highly recommended to use MacOS 14 (Sonoma)
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
MLX is also available on conda-forge. To install MLX with conda do:
.. code-block:: shell
conda install conda-forge::mlx
Troubleshooting
^^^^^^^^^^^^^^^
*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*
Probably you are using a non-native Python. The output of
.. code-block:: shell
python -c "import platform; print(platform.processor())"
should be ``arm``. If it is ``i386`` (and you have M series machine) then you
are using a non-native Python. Switch your Python to a native Python. A good
way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.
Build from source
-----------------
@@ -23,8 +54,11 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 14.3 (Xcode >= 15.0 for MacOS 14 and above)
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.
Python API
^^^^^^^^^^
@@ -60,9 +94,17 @@ For developing use an editable install:
To make sure the install is working run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
^^^^^^^
@@ -129,8 +171,64 @@ should point to the path to the built metal library.
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
MacOS SDK will be used
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
You see the following error when you try to build:
.. code-block:: shell
error: unable to find utility "metal", not a developer tool or in PATH
To fix this, first make sure you have Xcode installed:
.. code-block:: shell
xcode-select --install
Then set the active developer directory:
.. code-block:: shell
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
~~~~~~~~~
.. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.
Verify the terminal is now running natively the following command:
.. code-block:: shell
$ uname -p
arm
Also check that cmake is using the correct architecture:
.. code-block:: shell
$ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
CMAKE_HOST_SYSTEM_PROCESSOR "arm64"
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again.

View File

@@ -34,6 +34,7 @@ Array
array.prod
array.reciprocal
array.reshape
array.round
array.rsqrt
array.sin
array.split

View File

@@ -29,9 +29,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``uint32``
- 4
- 32-bit unsigned integer
* - ``uint32``
* - ``uint64``
- 8
- 32-bit unsigned integer
- 64-bit unsigned integer
* - ``int8``
- 1
- 8-bit signed integer

View File

@@ -9,9 +9,10 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream

View File

@@ -0,0 +1,12 @@
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm
qr

View File

@@ -64,7 +64,6 @@ Quick Start with Neural Networks
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)
.. _module_class:
The Module Class
@@ -86,20 +85,58 @@ name should not start with ``_``). It can be arbitrarily nested in other
:meth:`Module.parameters` can be used to extract a nested dictionary with all
the parameters of a module and its submodules.
A :class:`Module` can also keep track of "frozen" parameters.
:meth:`Module.trainable_parameters` returns only the subset of
:meth:`Module.parameters` that is not frozen. When using
:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these
trainable parameters.
A :class:`Module` can also keep track of "frozen" parameters. See the
:meth:`Module.freeze` method for more details. :meth:`mlx.nn.value_and_grad`
the gradients returned will be with respect to these trainable parameters.
Updating the parameters
Updating the Parameters
^^^^^^^^^^^^^^^^^^^^^^^
MLX modules allow accessing and updating individual parameters. However, most
times we need to update large subsets of a module's parameters. This action is
performed by :meth:`Module.update`.
performed by :meth:`Module.update`.
Value and grad
Inspecting Modules
^^^^^^^^^^^^^^^^^^
The simplest way to see the model architecture is to print it. Following along with
the above example, you can print the ``MLP`` with:
.. code-block:: python
print(mlp)
This will display:
.. code-block:: shell
MLP(
(layers.0): Linear(input_dims=2, output_dims=128, bias=True)
(layers.1): Linear(input_dims=128, output_dims=128, bias=True)
(layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)
To get more detailed information on the arrays in a :class:`Module` you can use
:func:`mlx.utils.tree_map` on the parameters. For example, to see the shapes of
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
Value and Grad
--------------
Using a :class:`Module` does not preclude using MLX's high order function
@@ -137,36 +174,10 @@ In detail:
value_and_grad
Neural Network Layers
---------------------
.. toctree::
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
Embedding
ReLU
GELU
SiLU
Linear
Conv1d
Conv2d
LayerNorm
RMSNorm
GroupNorm
RoPE
MultiHeadAttention
Sequential
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
gelu
gelu_approx
gelu_fast_approx
relu
silu
nn/module
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -0,0 +1,24 @@
.. _nn_functions:
.. currentmodule:: mlx.nn
Functions
---------
Layers without parameters (e.g. activation functions) are also provided as
simple functions.
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
gelu
gelu_approx
gelu_fast_approx
mish
prelu
relu
selu
softshrink
silu
step

View File

@@ -0,0 +1,45 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -0,0 +1,42 @@
.. _layers:
.. currentmodule:: mlx.nn
Layers
------
.. autosummary::
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
BatchNorm
Conv1d
Conv2d
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
InstanceNorm
LayerNorm
Linear
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU
QuantizedLinear
RMSNorm
ReLU
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer

View File

@@ -0,0 +1,25 @@
.. _losses:
.. currentmodule:: mlx.nn.losses
Loss Functions
--------------
.. autosummary::
:toctree: _autosummary_functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss

View File

@@ -1,7 +1,37 @@
mlx.nn.Module
=============
Module
======
.. currentmodule:: mlx.nn
.. autoclass:: Module
:members:
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Module.apply
Module.apply_to_modules
Module.children
Module.eval
Module.filter_and_map
Module.freeze
Module.leaf_modules
Module.load_weights
Module.modules
Module.named_modules
Module.parameters
Module.save_weights
Module.train
Module.trainable_parameters
Module.unfreeze
Module.update
Module.update_modules

View File

@@ -26,23 +26,40 @@ Operations
argsort
array_equal
broadcast_to
ceil
clip
concatenate
convolve
conv1d
conv2d
cos
cosh
dequantize
diag
diagonal
divide
divmod
equal
erf
erfinv
exp
expand_dims
eye
flatten
floor
floor_divide
full
greater
greater_equal
identity
inner
isnan
isposinf
isneginf
isinf
less
less_equal
linspace
load
log
log2
@@ -50,6 +67,8 @@ Operations
log1p
logaddexp
logical_not
logical_and
logical_or
logsumexp
matmul
max
@@ -57,19 +76,27 @@ Operations
mean
min
minimum
moveaxis
multiply
negative
ones
ones_like
outer
partition
pad
prod
quantize
quantized_matmul
reciprocal
repeat
reshape
round
rsqrt
save
savez
savez_compressed
save_gguf
save_safetensors
sigmoid
sign
sin
@@ -80,14 +107,20 @@ Operations
sqrt
square
squeeze
stack
stop_gradient
subtract
sum
swapaxes
take
take_along_axis
tan
tanh
tensordot
transpose
tri
tril
triu
var
where
zeros

View File

@@ -29,13 +29,8 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers
.. toctree::
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
Adam
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers

View File

@@ -0,0 +1,20 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -0,0 +1,13 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
step_decay
exponential_decay
cosine_decay

View File

@@ -33,13 +33,13 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
:toctree: _autosummary
seed
key
split
bernoulli
categorical
gumbel
key
normal
randint
uniform
seed
split
truncated_normal
uniform

View File

@@ -9,6 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp

430
docs/src/usage/compile.rst Normal file
View File

@@ -0,0 +1,430 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -0,0 +1,191 @@
.. _function_transforms:
Function Transforms
===================
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
Here is a simple example:
.. code-block:: shell
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
.. code-block:: shell
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
Automatic Differentiation
-------------------------
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
.. note::
If you are coming to MLX from PyTorch, you no longer need functions like
``backward``, ``zero_grad``, and ``detach``, or properties like
``requires_grad``.
The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:
.. code-block:: python
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:
.. code-block:: python
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).
Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:
.. code-block:: python
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
Automatic Vectorization
-----------------------
.. _vmap:
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.
.. warning::
Some operations are not yet supported with :func:`vmap`. If you encounter an error
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
We will prioritize including it.
A naive way to add the elements from two sets of vectors is with a loop:
.. code-block:: python
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
Let's time these two different versions:
.. code-block:: python
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

123
docs/src/usage/indexing.rst Normal file
View File

@@ -0,0 +1,123 @@
.. _indexing:
Indexing Arrays
===============
.. currentmodule:: mlx.core
For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> arr[3]
array(3, dtype=int32)
>>> arr[-2] # negative indexing works
array(8, dtype=int32)
>>> arr[2:8:2] # start, stop, stride
array([2, 4, 6], dtype=int32)
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
.. code-block:: shell
>>> arr = mx.arange(8).reshape(2, 2, 2)
>>> arr[:, :, 0]
array(3, dtype=int32)
array([[0, 2],
[4, 6]], dtype=int32
>>> arr[..., 0]
array([[0, 2],
[4, 6]], dtype=int32
You can index with ``None`` to create a new axis:
.. code-block:: shell
>>> arr = mx.arange(8)
>>> arr.shape
[8]
>>> arr[None].shape
[1, 8]
You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.
Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.
Differences from NumPy
----------------------
.. Note::
MLX indexing is different from NumPy indexing in two important ways:
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is not yet supported.
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[2] = 0
>>> a
array([1, 2, 0], dtype=int32)
Just as in NumPy, in place updates will be reflected in all references to the
same array:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> b = a
>>> b[2] = 0
>>> b
array([1, 2, 0], dtype=int32)
>>> a
array([1, 2, 0], dtype=int32)
Transformations of functions which use in-place updates are allowed and work as
expected. For example:
.. code-block:: python
def fun(x, idx):
x[idx] = 2.0
return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.

View File

@@ -0,0 +1,144 @@
.. _lazy eval:
Lazy Evaluation
===============
.. currentmodule:: mlx.core
Why Lazy Evaluation
-------------------
When you perform operations in MLX, no computation actually happens. Instead a
compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.
Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^
In MLX you do not need to worry as much about computing outputs that are never
used. For example:
.. code-block:: python
def fun(x):
a = fun1(x)
b = expensive_fun(a)
return a, b
y, _ = fun(x)
Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built, and
that has some cost associated to it.
Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation was used
instead.
This pattern is simple to do in MLX thanks to lazy computation:
.. code-block:: python
model = Model() # no memory used yet
model.load_weights("weights_fp16.safetensors")
When to Evaluate
----------------
A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.
For example:
.. code-block:: python
for _ in range(100):
a = a + b
mx.eval(a)
b = b * 2
mx.eval(b)
This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.
Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.
Most numerical computations have an iterative outer loop (e.g. the iteration in
stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.
Here is a concrete example:
.. code-block:: python
for batch in dataset:
# Nothing has been evaluated yet
loss, grad = value_and_grad_fn(model, batch)
# Still nothing has been evaluated
optimizer.update(model, grad)
# Evaluate the loss and the new parameters which will
# run the full gradient computation and optimizer update
mx.eval(loss, model.parameters())
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.
.. warning::
Using scalar arrays for control-flow will cause an evaluation.
Here is an example:
.. code-block:: python
def fun(x):
h, y = first_layer(x)
if y > 0: # An evaluation is done here!
z = second_layer_a(h)
else:
z = second_layer_b(h)
return z
Using arrays for control flow should be done with care. The above example works
and can even be used with gradient transformations. However, this can be very
inefficient if evaluations are done too frequently.

108
docs/src/usage/numpy.rst Normal file
View File

@@ -0,0 +1,108 @@
.. _numpy:
Conversion to NumPy and Other Frameworks
========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.
.. code-block:: python
import mlx.core as mx
import numpy as np
a = mx.arange(3)
b = np.array(a) # copy of a
c = mx.array(b) # copy of b
.. note::
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
.. code-block:: python
a = mx.arange(3)
a_view = np.array(a, copy=False)
print(a_view.flags.owndata) # False
a_view[0] = 1
print(a[0].item()) # 1
A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this in an example:
.. code-block:: python
def f(x):
x_view = np.array(x, copy=False)
x_view[:] *= x_view # modify memory without telling mx
return x.sum()
x = mx.array([3.0])
y, df = mx.value_and_grad(f)(x)
print("f(x) = x² =", y.item()) # 9.0
print("f'(x) = 2x !=", df.item()) # 1.0
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
import mlx.core as mx
import torch
a = mx.arange(3)
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
JAX
---
JAX fully supports the buffer protocol.
.. code-block:: python
import mlx.core as mx
import jax.numpy as jnp
a = mx.arange(3)
b = jnp.array(a)
c = mx.array(b)
TensorFlow
----------
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
import mlx.core as mx
import tensorflow as tf
a = mx.arange(3)
b = tf.constant(memoryview(a))
c = mx.array(b)

View File

@@ -40,6 +40,9 @@ automatically evaluate the array.
>> np.array(c) # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)
See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.
Function and Graph Transformations
----------------------------------
@@ -62,10 +65,3 @@ and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
Devices and Streams
-------------------

View File

@@ -0,0 +1,81 @@
.. _saving_and_loading:
Saving and Loading Arrays
=========================
.. currentmodule:: mlx.core
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
* - Safetensors
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
Here's an example of saving a single array to a file:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> mx.save("array", a)
The array ``a`` will be saved in the file ``array.npy`` (notice the extension
is automatically added). Including the extension is optional; if it is missing
it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.savez("arrays", a, b=b)
For compatibility with :func:`numpy.savez` the MLX :func:`savez` takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:
.. code-block:: shell
>>> mx.load("arrays.npz")
{'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}
In this case :func:`load` returns a dictionary of names to arrays.
The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:
.. code-block:: shell
>>> a = mx.array([1.0])
>>> b = mx.array([2.0])
>>> mx.save_safetensors("arrays", {"a": a, "b": b})

View File

@@ -0,0 +1,78 @@
.. _unified_memory:
Unified Memory
==============
.. currentmodule:: mlx.core
Apple silicon has a unified memory architecture. The CPU and GPU have direct
access to the same memory pool. MLX is designed to take advantage of that.
Concretely, when you make an array in MLX you don't have to specify its location:
.. code-block:: python
a = mx.random.normal((100,))
b = mx.random.normal((100,))
Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
.. code-block:: python
mx.add(a, b, stream=mx.cpu)
mx.add(a, b, stream=mx.gpu)
In the above, both the CPU and the GPU will perform the same add
operation. The operations can (and likely will) be run in parallel since
there are no dependencies between them. See :ref:`using_streams` for more
information the semantics of streams in MLX.
In the above ``add`` example, there are no dependencies between operations, so
there is no possibility for race conditions. If there are dependencies, the
MLX scheduler will automatically manage them. For example:
.. code-block:: python
c = mx.add(a, b, stream=mx.cpu)
d = mx.add(a, c, stream=mx.gpu)
In the above case, the second ``add`` runs on the GPU but it depends on the
output of the first ``add`` which is running on the CPU. MLX will
automatically insert a dependency between the two streams so that the second
``add`` only starts executing after the first is complete and ``c`` is
available.
A Simple Example
~~~~~~~~~~~~~~~~
Here is a more interesting (albeit slightly contrived example) of how unified
memory can be helpful. Suppose we have the following computation:
.. code-block:: python
def fun(a, b, d1, d2):
x = mx.matmul(a, b, stream=d1)
for _ in range(500):
b = mx.exp(b, stream=d2)
return x, b
which we want to run with the following arguments:
.. code-block:: python
a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))
The first ``matmul`` operation is a good fit for the GPU since it's more
compute dense. The second sequence of operations are a better fit for the CPU,
since they are very small and would probably be overhead bound on the GPU.
If we time the computation fully on the GPU, we get 2.8 milliseconds. But if we
run the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only
about 1.4 milliseconds, about twice as fast. These times were measured on an M1
Max.

View File

@@ -1,3 +1,5 @@
.. _using_streams:
Using Streams
=============

View File

@@ -57,7 +57,7 @@ void array_basics() {
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
// To actually run the compuation you must evaluate `z`.
// To actually run the computation you must evaluate `z`.
// Under the hood, mlx records operations in a graph.
// The variable `z` is a node in the graph which points to its operation
// and inputs. When `eval` is called on an array (or arrays), the array and

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.27)
project(mlx_sample_extensions LANGUAGES CXX)
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
endif()

View File

@@ -26,7 +26,7 @@ namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
/**
* Scale and sum two vectors elementwise
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@@ -91,21 +91,24 @@ void axpby_impl(
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additonal mapping
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while contructing the out array)
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& out_arr) {
auto out = out_arr[0];
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -175,7 +178,10 @@ void axpby_impl_accelerate(
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -189,13 +195,15 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
}
// Fall back to common backend if specializations are not available
eval(inputs, out);
eval(inputs, outarr);
}
#else // Accelerate not avaliable
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
eval(inputs, out);
}
@@ -208,8 +216,11 @@ void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
#ifdef _METAL_
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -254,7 +265,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel decelaration at axpby.metal
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
@@ -287,7 +298,7 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divded among
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
@@ -295,7 +306,9 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
#else // Metal is not available
/** Fail evaluation on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -306,13 +319,13 @@ void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
array Axpby::jvp(
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the the primitive can built with ops
// that are scheduled on the same stream as the primtive
// The jvp transform on the primitive can built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
@@ -321,32 +334,33 @@ array Axpby::jvp(
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return multiply(scale_arr, tangents[0], stream());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitve along given axis */
std::pair<array, int> Axpby::vmap(
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
/**
* Scale and sum two vectors elementwise
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
@@ -39,14 +39,16 @@ class Axpby : public Primitive {
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unecessary allocations, the evaluation function
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
/** The Jacobian-vector product. */
array jvp(
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -54,16 +56,17 @@ class Axpby : public Primitive {
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) override;
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
/**
* The primitive must know how to vectorize itself accross
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<array, int> vmap(
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -80,7 +83,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
} // namespace mlx::core

View File

@@ -59,5 +59,5 @@ template <typename T>
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bflot16, bfloat16_t);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);

View File

@@ -23,7 +23,7 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``

View File

@@ -1,4 +1,5 @@
# Copyright © 2023 Apple Inc.
import mlx.core as mx
from .mlx_sample_extensions import *

View File

@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"

View File

@@ -1,8 +1,9 @@
# Copyright © 2023 Apple Inc.
from mlx import extension
from setuptools import setup
from mlx import extension
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
@@ -14,5 +15,5 @@ if __name__ == "__main__":
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.7",
python_requires=">=3.8",
)

View File

@@ -1,8 +1,9 @@
# Copyright © 2023 Apple Inc.
import mlx.core as mx
import time
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000
@@ -40,6 +41,6 @@ error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
throughput = num_iters / (toc - tic)
print(
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
f"Throughput {throughput:.5f} (it/s)"
)

View File

@@ -1,8 +1,9 @@
# Copyright © 2023 Apple Inc.
import mlx.core as mx
import time
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000

View File

@@ -3,23 +3,25 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
target_sources(

View File

@@ -9,7 +9,7 @@
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
auto buffer = allocator().malloc(size, /* allow_swap */ true);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,7 +22,7 @@ void free(Buffer buffer) {
return allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size) {
Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)};
}
@@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";

View File

@@ -37,9 +37,9 @@ void free(Buffer buffer);
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base clase for a memory allocator. */
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0;
Allocator() = default;
@@ -55,7 +55,7 @@ Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
private:

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <functional>
@@ -6,6 +6,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
@@ -21,6 +22,12 @@ std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -32,7 +39,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
@@ -40,6 +47,34 @@ array::array(
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
}
return outputs;
}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
@@ -47,6 +82,13 @@ array::array(std::initializer_list<float> data)
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
@@ -58,12 +100,26 @@ array::array(
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
void array::eval(bool retain_graph /* = false */) {
mlx::core::eval({*this}, retain_graph);
void array::eval() {
mlx::core::eval({*this});
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -108,6 +164,14 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(array other) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
@@ -116,21 +180,43 @@ array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
array::ArrayDesc::ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(shape);
for (auto& in : inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
// Needed because the Primitive type used in array.h is incomplete and the
// compiler needs to see the call to the desctructor after the type is complete.
array::ArrayDesc::~ArrayDesc() = default;
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
: arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
@@ -42,6 +41,9 @@ class array {
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
std::initializer_list<T> data,
@@ -116,11 +118,14 @@ class array {
};
/** Evaluate the array. */
void eval(bool retain_graph = false);
void eval();
/** Get the value from a scalar array. */
template <typename T>
T item(bool retain_graph = false);
T item();
template <typename T>
T item() const;
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
@@ -128,11 +133,7 @@ class array {
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
explicit ArrayIterator(const array& arr, int idx = 0);
reference operator*() const;
@@ -154,8 +155,8 @@ class array {
};
private:
int idx;
const array& arr;
int idx;
};
ArrayIterator begin() const {
@@ -174,7 +175,19 @@ class array {
array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -182,6 +195,11 @@ class array {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
}
/** A unique identifier for an arrays primitive. */
std::uintptr_t primitive_id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_->primitive.get());
}
struct Data {
allocator::Buffer buffer;
deleter_t d;
@@ -209,6 +227,11 @@ class array {
return *(array_desc_->primitive);
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
@@ -219,12 +242,42 @@ class array {
return array_desc_->inputs;
};
/** A non-const reference to the array's inputs so that they can be used to
* edit the graph. */
std::vector<array>& editable_inputs() {
std::vector<array>& inputs() {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
}
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
auto idx = array_desc_->position;
std::vector<array> outputs;
outputs.reserve(siblings().size() + 1);
outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx);
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@@ -245,6 +298,12 @@ class array {
return array_desc_->data->buffer;
};
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
@@ -265,9 +324,7 @@ class array {
array_desc_->is_tracer = is_tracer;
}
// Check if the array is a tracer array
bool is_tracer() const {
return array_desc_->is_tracer;
}
bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
@@ -287,6 +344,8 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
@@ -301,7 +360,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::unique_ptr<Primitive> primitive{nullptr};
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -323,22 +382,34 @@ class array {
Flags flags;
std::vector<array> inputs;
// An array to keep track of the siblings from a multi-output
// primitive.
std::vector<array> siblings;
// The arrays position in the output list
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
~ArrayDesc();
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and a the list of array's inputs for the primitive.
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
@@ -381,11 +452,23 @@ array::array(
}
template <typename T>
T array::item(bool retain_graph /* = false */) {
T array::item() {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
eval(retain_graph);
eval();
return *data<T>();
}
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
return *data<T>();
}

View File

@@ -4,6 +4,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@@ -29,12 +29,16 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
inline void matmul_cblas_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -42,6 +46,14 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -50,21 +62,34 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
M,
N,
K,
1.0f, // alpha
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
0.0f, // beta
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_cblas_general(a_pre, b_pre, out);
}
inline void matmul_bnns_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
// TODO: Update to utilize BNNS broadcasting
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -72,11 +97,19 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ 1.0,
/* float beta = */ 0.0,
/* float alpha = */ alpha,
/* float beta = */ beta,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
@@ -157,6 +190,12 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
BNNSFilterDestroy(bnns_filter);
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_bnns_general(a_pre, b_pre, out);
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -166,4 +205,16 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns(inputs[0], inputs[1], out);
}
} // namespace mlx::core
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
if (out.dtype() == float32) {
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
}
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
@@ -17,6 +17,12 @@
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
@@ -26,12 +32,18 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
@@ -39,16 +51,24 @@ DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
@@ -57,21 +77,11 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else if (in.dtype() == int32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
@@ -124,12 +134,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -140,12 +146,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -156,12 +158,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -172,12 +170,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -188,12 +182,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -204,12 +194,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -221,30 +207,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
allocfn();
set_unary_output_data(in, out);
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
allocfn();
set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
allocfn();
set_unary_output_data(in, out);
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
allocfn();
set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
@@ -256,12 +235,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -272,12 +247,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -326,12 +297,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
@@ -358,12 +325,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
@@ -387,12 +350,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
@@ -404,47 +363,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -474,13 +392,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
unary(in, out, [](auto x) { return -x; });
}
@@ -493,8 +406,14 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
}
@@ -535,12 +454,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -551,12 +466,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -567,12 +478,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
@@ -583,12 +490,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
@@ -643,12 +546,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -659,12 +558,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);

View File

@@ -0,0 +1,103 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <simd/vector.h>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void _qmm_t_4_64(
float* result,
const float* x,
const uint32_t* w,
const float* scales,
const float* biases,
int M,
int N,
int K) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
bool condition =
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
scales.flags().row_contiguous && biases.flags().row_contiguous &&
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int N = out.shape(-1);
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core

View File

@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -3,16 +3,20 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
)

View File

@@ -6,6 +6,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -75,6 +76,61 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void DivMod::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
}
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -82,6 +138,47 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
@@ -154,14 +251,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x < y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
@@ -40,29 +39,83 @@ void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) {
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}
}
@@ -73,6 +126,12 @@ struct UseDefaultBinaryOp {
// Should we throw? This should normally never be called.
assert(false);
}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op>
@@ -89,6 +148,18 @@ struct DefaultVectorScalar {
a++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *b;
while (size-- > 0) {
auto dst = op(*a, scalar);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
}
}
};
template <typename T, typename U, typename Op>
@@ -105,6 +176,18 @@ struct DefaultScalarVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *a;
while (size-- > 0) {
auto dst = op(scalar, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
b++;
}
}
};
template <typename T, typename U, typename Op>
@@ -121,6 +204,18 @@ struct DefaultVectorVector {
b++;
}
}
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
while (size-- > 0) {
auto dst = op(*a, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op>

View File

@@ -0,0 +1,536 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out_a.size(); ++i) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[i] = dst.first;
dst_b[i] = dst.second;
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst_a += stride;
dst_b += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst_a += stride;
dst_b += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
switch (out_a.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
dst_a += stride;
dst_b += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
out_a.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out_a.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out_a.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Ops... ops) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, ops...);
break;
case int8:
binary_op<int8_t>(a, b, outputs, ops...);
break;
case int16:
binary_op<int16_t>(a, b, outputs, ops...);
break;
case int32:
binary_op<int32_t>(a, b, outputs, ops...);
break;
case int64:
binary_op<int64_t>(a, b, outputs, ops...);
break;
case float16:
binary_op<float16_t>(a, b, outputs, ops...);
break;
case float32:
binary_op<float>(a, b, outputs, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, ops...);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.
#include <queue>
#include "mlx/primitives.h"
namespace mlx::core {
// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
const std::vector<array>& trace_tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
std::queue<array> tape;
for (auto& a : trace_tape) {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
tape.push(
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
trace_to_real.insert({a.id(), tape.back()});
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return {tape, outputs};
}
void Compiled::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the a real tape from the tracers
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
// Run the tape
while (!tape.empty()) {
auto a = std::move(tape.front());
tape.pop();
auto outputs = a.outputs();
a.primitive().eval_cpu(a.inputs(), outputs);
a.detach();
}
// Copy results into outputs
for (int o = 0; o < real_outputs.size(); ++o) {
outputs[o].copy_shared_buffer(real_outputs[o]);
}
}
} // namespace mlx::core

View File

@@ -3,7 +3,7 @@
#include <cassert>
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
@@ -357,7 +357,7 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Peform gemm
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
@@ -459,7 +459,7 @@ void explicit_gemm_conv_2D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Peform gemm
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A

View File

@@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
dst.set_data(
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
src.data_size(),
src.strides(),
src.flags());
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:

View File

@@ -1,6 +1,12 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
@@ -12,6 +18,12 @@
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
DEFAULT(Abs)
@@ -29,17 +41,24 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
@@ -50,6 +69,8 @@ DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
@@ -59,9 +80,12 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
@@ -71,6 +95,7 @@ DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
@@ -79,16 +104,14 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
namespace {
inline void matmul_common_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
@@ -106,9 +129,17 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -117,16 +148,41 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
M,
N,
K,
1.0f, // alpha
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
0.0f, // beta
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_common_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -5,7 +5,7 @@
#include <utility>
#include "mlx/allocator.h"
#include "mlx/load.h"
#include "mlx/io/load.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -13,7 +13,7 @@ namespace mlx::core {
namespace {
template <const uint8_t scalar_size>
void swap_endianess(uint8_t* data_bytes, size_t N) {
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
@@ -39,13 +39,13 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianess<2>(out.data<uint8_t>(), out.data_size());
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianess<4>(out.data<uint8_t>(), out.data_size());
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianess<8>(out.data<uint8_t>(), out.data_size());
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
}
}

View File

@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/erf.h"
#include "mlx/backend/common/threefry.h"
@@ -167,6 +168,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, [](auto x) { return std::ceil(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
@@ -220,22 +232,38 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
});
@@ -252,17 +280,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
});
@@ -287,6 +312,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, [](auto x) { return std::floor(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Full::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -342,6 +378,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
unary(in, out, [](auto x) { return !x; });
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -444,6 +494,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
}
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, RoundOp());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -540,6 +601,58 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

153
mlx/backend/common/qrf.cpp Normal file
View File

@@ -0,0 +1,153 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core

View File

@@ -0,0 +1,285 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, int bits, int group_size>
void _qmm(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
}
result += N;
}
}
template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
*result = sum;
result++;
}
x += K;
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
int bits,
bool transposed_w) {
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
<< "{2, 4, 8} and group_size in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
array out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.size() / K;
int N = out.shape(-1);
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
scales.data<float16_t>(),
biases.data<float16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
scales.data<bfloat16_t>(),
biases.data<bfloat16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
} // namespace mlx::core

View File

@@ -126,7 +126,7 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
(x.flags().row_contiguous || x.flags().col_contiguous)) {
x.flags().contiguous) {
return ContiguousAllReduce;
}

View File

@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides().back() == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -53,15 +53,35 @@ struct SignOp {
}
};
struct RoundOp {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
}
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
set_unary_output_data(a, out);
T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);

Some files were not shown because too many files have changed in this diff Show More