Commit Graph

555 Commits

Author SHA1 Message Date
Angelos Katharopoulos
53e6a9367c
Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8
Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98
Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a
Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Md. Rasel Mandol
db6796ac61
simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256
vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Awni Hannun
63ab0ab580
version (#835) 2024-03-14 12:20:40 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8
route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0
Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
8b7532b9ab
fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
366478c560
fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a
Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
0e95b64942
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b
Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe
Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9
Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Awni Hannun
28301807c2
Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
74ed0974b3
Support 13.0+ with xcode 14.3 (#806)
* Support 13.0+ with xcode 14.3

* revert revert
2024-03-07 13:27:57 -08:00
Jagrit Digani
ec8a4864fa
Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
b7588fd5d7
fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Awni Hannun
f512b905c7
Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049
route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32
Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
AlexCheema
7762e07fde
Update function_transforms.rst (#796)
Fix typo in function_transforms.rst
2024-03-06 12:03:37 -08:00
Luca Arnaboldi
cbefd9129e
Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)
* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-06 08:02:41 -08:00
Angelos Katharopoulos
e39bebe13e
Fix reshaping of empty arrays (#791) 2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Awni Hannun
859ae15a54
Fix test (#785) 2024-03-04 23:02:27 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07
Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4
Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
c096a77b9b
revision bump (#778) 2024-03-04 13:41:53 -08:00
Awni Hannun
5121f028d9
nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00