Awni Hannun
1873ffda01
Detect metal version and propagate correctly for JIT ( #1109 )
...
* detect metal version and propagate correctly for JIT
* remove softmax
* fix versions
2024-05-15 17:42:09 -07:00
Jagrit Digani
358e1fd6ab
Fused GEMM ( #1123 )
...
* Basic gemm working
* Update addmm
* Clear out steel_gemm and steel_addmm kernels
* Fuse and clear out gather gemm
* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
863039da4c
Allow scatter type exception to be caught by checking in op ( #1077 )
...
* allow exception to be caught in main thread
* only for gpu
* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111
No CPU option for binary minimization ( #1105 )
...
* no cpu build option
* docs
* fix
2024-05-13 16:08:11 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d ( #993 )
...
* added conv3d
added conv3d
implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6
improve error messaging in eval ( #1101 )
2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator ( #1100 )
...
* cpu and gpu impl
* add mx.conj and array.conj()
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
8b1906abd0
Add compiler flags to disable safetensors and gguf ( #1098 )
...
* with docs
* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605
Split encoders in non-concurrent context with a max ops per encoder ( #1085 )
...
* split encoders
* fix race
2024-05-09 16:21:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation ( #1079 )
...
* Added ArcTan2 operation
* Cleanup, bug fixes from code review
* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66
Update block offset adjustment to be in size_t ( #1087 )
2024-05-08 08:10:23 -07:00
Awni Hannun
21623156a3
Reset peak memory ( #1074 )
...
* reset peak memory
* fix linux
* nits in docs
2024-05-03 17:12:51 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info ( #1064 )
2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm ( #1058 )
2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel ( #1061 )
2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached ( #1059 )
...
* fix multi output leak
* ignore arrays that will be detached
* add some comments
* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info
( #1060 )
...
* device inof
* add variant
* fix linux
* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da
feat: metal formatting and pre-commit bump ( #1038 )
...
* feat: metal formatting and pre-commit bump
* add guards
* update
* more guards
* more guards
* smakk fix
* Refactor instantiation of ternary types in ternary.metal
* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94
Bug fix in quantize ( #1054 )
2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896
fix slice update indexing ( #1053 )
2024-04-29 12:17:40 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
...
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* clenaup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops ( #1037 )
...
* bitwise ops
* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759
fix order device -> scheduler ( #1039 )
2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs ( #1036 )
...
* start of C++ docs
* fix stream doc
* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache ( #1032 )
...
* expose function to clear memory cache
* fix linux build
* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
20a01bbd9f
Simplifying and improving qmm ( #1030 )
2024-04-24 13:07:45 -07:00
Angelos Katharopoulos
ec8578d41a
Fix quantization of all 0s ( #1028 )
2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees ( #1011 )
2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function ( #1006 )
...
* add synchronize function
* fix linux
* fix linux
* fix and fix docs
* fix test
* try synchronize in stream destroy
* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization ( #1016 )
2024-04-19 19:47:26 -07:00
Awni Hannun
ed83908931
fix gguf loading quants ( #1014 )
...
* fix gguf loading quants
* fix nanobind install
* actual fix
2024-04-19 12:24:07 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test ( #1003 )
2024-04-17 17:33:48 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval ( #998 )
...
* more async eval
* fix rebase
* try correct async eval
* fix async
* more tests for async eval
* use shared events for synchronization
* comment + cleanup
* with autorelease pool
* fix no metal build
* fix compile
* fix patch
* don't eval if asyn evale'd
* don't use is_evaled
* comments
* more multi stream tests
* try and cleanup use of is_evaled
* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm ( #978 )
...
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 ( #915 )
...
* add Metal FFT for powers of 2
* skip GPU test on linux
* fix contiguity bug
* address comments
* Update mlx/backend/metal/fft.cpp
* Update mlx/backend/metal/fft.cpp
* fix bug in synch
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder ( #986 )
...
* no copy command encoder
* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Angelos Katharopoulos
dce4bd74a4
Add ArrayDesc destructor to avoid possible stack overflow ( #982 )
2024-04-11 11:37:02 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks ( #984 )
2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch ( #977 )
2024-04-10 21:45:31 -07:00
Awni Hannun
8580d997ff
Try a stack-based DFS for eval ( #980 )
...
* rebase
* nit
* fix eval in vmap
2024-04-10 17:05:13 -07:00
Awni Hannun
99abb9eff4
Async eval ( #972 )
2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal ( #502 ) ( #877 )
...
* Implementation of mlx.random.multivariate_normal (#502 )
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Updated typo in docstring
* Restricted multivariate_normal to float32
* Generic mean and variance shapes
* Review edits
* Update mlx/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/random.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Test for ndim of mean and cov
* nits
* smaller size for test
* fix broadcasted sampling
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid ( #961 )
2024-04-09 11:43:08 -07:00
Awni Hannun
ae812350f9
use string ( #976 )
2024-04-09 11:22:00 -07:00
Awni Hannun
42afe27e12
std and expm1 ( #973 )
...
* std and expm1
* actually add expm1
* fix linux
* fix vjp
* relax tol for linux test
* Add it to the compilable primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan ( #974 )
...
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing ( #969 )
...
* improve profiling with gpu tracing
* fix for linux
* nit
* doc fix
* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
039da779d1
No quant reshape ( #957 )
...
* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad ( #955 )
2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax ( #953 )
...
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Awni Hannun
741eb28443
fix a couple bugs ( #952 )
2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8
Fix compile fusion for multi-output edge cases ( #950 )
...
* Fix compile fusion for multi-output edge cases
* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e
Fix cpu compile ( #934 )
...
* fix one cpu bug, test for another
* format hooks
* simplify contiguity check for cpu compile
* fix
* add back donation
* comment
2024-04-01 17:37:12 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug ( #941 )
...
* add layer norm grad test
* Fix donation bug in layernorm vjp
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
Angelos Katharopoulos
9cbff5ec1d
Fix typo in qmm check ( #940 )
2024-03-31 19:15:44 -07:00
Awni Hannun
8915901966
Donation bug ( #933 )
...
* donation
* buf
* fix bug in softmax
* comment
* remove print
2024-03-30 10:08:54 -07:00
Cheng
913b19329c
Add missing && when forwarding args ( #925 )
...
Without the && args would be copied and perfect forwarding won't work.
2024-03-29 06:48:29 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases ( #923 )
2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0
Remove unnecessary string copies ( #891 )
...
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Jack Mousseau
45f636e759
Add Metal debug option and capture functions ( #707 )
...
* Add Metal debug option and capture functions
* Add brief Metal debugger documentation
* doc nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-28 09:40:31 -07:00
Cheng
a7b404ff53
Use uintptr_t instead of size_t to store funtion id ( #916 )
...
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
Cheng
bab5386306
Make ops aware of rvalues: astype/as_strided/copy/full ( #895 )
...
When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
2024-03-27 22:35:55 -07:00
Angelos Katharopoulos
aca7584635
Fix OOB read in qmv when non-divisible by blocksize ( #917 )
2024-03-27 22:18:35 -07:00
Cheng
90dfa43ff1
Don't use make_unique to create shared_ptr ( #902 )
...
The code compiled because shared_ptr's constructor actually accepts
unique_ptr.
2024-03-27 06:13:29 -07:00
Awni Hannun
dc175f08d3
Fix race in multi-stream eval ( #911 )
...
* maybe fix race
* comment
2024-03-26 16:36:36 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace ( #883 )
...
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Cheng
a789685c63
Remove duplicate defines of StreamOrDevice and is_big_endian ( #892 )
2024-03-26 15:15:11 -07:00
Jagrit Digani
240d10699c
Implement negative padding in conv with slicing ( #907 )
...
* Implement negative padding with slicing
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni@apple.com>
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits ( #906 )
...
* Fix multiblock sort limits
* Fix metal validation error
2024-03-26 14:00:00 -07:00
Angelos Katharopoulos
9948eddf11
Fix nan and improve speed for qvm ( #903 )
2024-03-26 10:41:45 -07:00
Luca Arnaboldi
a3ee03da01
Fixing random.normal for half-precision dtype #642 ( #904 )
...
* Fixing random.normal for half-precision dtype #642
* Update python/tests/test_random.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Cheng
28fcd2b519
Add missing && when forwarding args ( #894 )
...
Without the && args would be copied and perfect forwarding won't work.
Also add template utils to make sure the function only forwards array
and not vector<array>.
2024-03-25 14:55:54 -07:00
Jack Mousseau
8e686764ac
Ensure shape dimensions are within supported integer range ( #566 ) ( #704 )
...
* Ensure shape dimensions are within supported integer range (#566 )
* fix build
* fix rebase bug
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… ( #427 )
...
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285 .
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead ( #871 )
...
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
6ee1112f30
Fix copy donation and add partial rope ( #881 )
2024-03-22 17:28:26 -07:00
Cheng
9663c22fe9
Do not store iostream in shared_ptr ( #872 )
...
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Cheng
f0ae00da12
Reduce implicit copies in make_array ( #874 )
...
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm ( #870 )
2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse ( #849 )
2024-03-21 13:18:27 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm ( #862 )
...
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98
Add missing && in eval ( #864 )
...
Without the && args would be copied and perfect forwarding won't work.
To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52
Update set item ( #861 )
...
* Update mlx_set_item to handle regular slices without expanding
* Refactor ellipsis handling
* Route mlx_set_item to slice_update where possible
* Update mlx_scatter_args_slice
* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a
Do not use static constexpr in header ( #863 )
...
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel ( #858 )
2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive ( #850 )
...
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0
Pass shape and inputs by value in array's constructor ( #853 )
...
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.
There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
then moved into ArrayDesc's member. So no copy happens.
So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind ( #839 )
...
* mostly builds
* most tests pass
* fix circle build
* add back buffer protocol
* includes
* fix for py38
* limit to cpu device
* include
* fix stubs
* move signatures for docs
* stubgen + docs fix
* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e
Some C++ code are not needed ( #841 )
...
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8
No reshape rope ( #838 )
...
* no reshape rope
* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive ( #822 )
2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256
vmap matmul and admm ( #836 )
2024-03-14 14:38:22 -07:00
Jagrit Digani
8dfc376c00
Strided reduce specialization for small reductions ( #826 )
...
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09
Add types and order in kernel name ( #831 )
2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8
route to fallback ( #828 )
2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4
Make reshape faster for row_contiguous cases ( #829 )
2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes ( #802 )
2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement ( #818 )
2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems ( #801 )
...
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive ( #809 )
...
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra
Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00