Commit Graph

296 Commits

Author SHA1 Message Date
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization (#1016) 2024-04-19 19:47:26 -07:00
Awni Hannun
ed83908931
fix gguf loading quants (#1014)
* fix gguf loading quants

* fix nanobind install

* actual fix
2024-04-19 12:24:07 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Angelos Katharopoulos
dce4bd74a4
Add ArrayDesc destructor to avoid possible stack overflow (#982) 2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks (#984) 2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff
Try a stack-based DFS for eval (#980)
* rebase

* nit

* fix eval in vmap
2024-04-10 17:05:13 -07:00
Awni Hannun
99abb9eff4
Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
ae812350f9
use string (#976) 2024-04-09 11:22:00 -07:00
Awni Hannun
42afe27e12
std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
039da779d1
No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Awni Hannun
741eb28443
fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8
Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e
Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d
Fix typo in qmm check (#940) 2024-03-31 19:15:44 -07:00
Awni Hannun
8915901966
Donation bug (#933)
* donation

* buf

* fix bug in softmax

* comment

* remove print
2024-03-30 10:08:54 -07:00
Cheng
913b19329c
Add missing && when forwarding args (#925)
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0
Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Jack Mousseau
45f636e759
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Cheng
bab5386306
Make ops aware of rvalues: astype/as_strided/copy/full (#895)
When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635
Fix OOB read in qmv when non-divisible by blocksize (#917) 2024-03-27 22:18:35 -07:00
Cheng
90dfa43ff1
Don't use make_unique to create shared_ptr (#902)
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
Awni Hannun
dc175f08d3
Fix race in multi-stream eval (#911)
* maybe fix race

* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63
Remove duplicate defines of StreamOrDevice and is_big_endian (#892) 2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c
Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Angelos Katharopoulos
9948eddf11
Fix nan and improve speed for qvm (#903) 2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01
Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Cheng
28fcd2b519
Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work.

Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac
Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope (#881) 2024-03-22 17:28:26 -07:00
Cheng
9663c22fe9
Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Cheng
f0ae00da12
Reduce implicit copies in make_array (#874)
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
   removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98
Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a
Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0
Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256
vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8
route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Awni Hannun
8b7532b9ab
fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
0e95b64942
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b
Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe
Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9
Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Jagrit Digani
ec8a4864fa
Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
f512b905c7
Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049
route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32
Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
Angelos Katharopoulos
e39bebe13e
Fix reshaping of empty arrays (#791) 2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Brian Keene
0787724c44
Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07
Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4
Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
5121f028d9
nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00
Awni Hannun
bc06cb9ff6
Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Angelos Katharopoulos
8e281c76c3
Fix the top-k op (#768) 2024-03-01 22:08:43 -08:00
Awni Hannun
d5964a2710
bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Jagrit Digani
776c3d226d
Convolution update (#651)
* Init steel conv and update Conv primitive

* Update slow CPU implementation to support flipping and input dilation winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Awni Hannun
f5f18b704f
fix temporary bug (#752) 2024-02-27 17:44:39 -08:00
Awni Hannun
56ba3ec40e
fix cpu compile on older OS (#747) 2024-02-26 22:20:53 -08:00
Hinrik Snær Guðmundsson
08226ab491
added atleast *args input support (#710)
* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
2024-02-26 11:17:59 -08:00
Awni Hannun
e6418781ab
Fix logsumexp edge case (#740)
* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Awni Hannun
ac02cf33bd
Fix some issues using MLX in C++ (#739)
* fix preamble build

* fix some issues with using MLX as a dep in C++
2024-02-24 22:20:57 -08:00
Noah Farr
d729a1991b
Fix arange with inf step (#686)
* Fix case for step=inf in arange and add inf check for start/stop

* Add test cases for arange

* Update ops.cpp to include climits header

* Fix arange

* Fix formatting

* Refactor

* Add missing include
2024-02-23 06:18:15 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection (#664) 2024-02-22 15:10:48 -08:00