Jagrit Digani
eab2685c67
Float mask update ( #1152 )
...
* Float mask update
* Update CPU impl
2024-05-23 17:20:44 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions ( #1129 )
...
* Added groups to 2-D convolutions. Only implemented for **some** specializations.
Also fixed 1D grouped convs with different kernel strides and added more tests.
* fix channels condition
2024-05-22 20:01:44 -07:00
Abe Leininger
79ef49b2c2
add mx.trace ( #1143 ) ( #1147 )
...
* working c++ trace implementation
* updated throw + added overloads
* added python binding for trace function
* pre-commit reformatting
* add trace to docs
* resolve comments
* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
d568c7ee36
Rename block sparse ( #1149 )
...
* block_sparse_mm to gather_mm
* rename
* nit
* nit
2024-05-22 07:48:34 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm ( #1124 )
2024-05-16 15:24:14 -07:00
Awni Hannun
863039da4c
Allow scatter type exception to be caught by checking in op ( #1077 )
...
* allow exception to be caught in main thread
* only for gpu
* more detailed scatter error
2024-05-13 17:43:53 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d ( #993 )
...
* added conv3d
added conv3d
implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator ( #1100 )
...
* cpu and gpu impl
* add mx.conj and array.conj()
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation ( #1079 )
...
* Added ArcTan2 operation
* Cleanup, bug fixes from code review
* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
f390957685
Block sparse mm ( #1058 )
2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel ( #1061 )
2024-05-01 18:19:11 -07:00
Angelos Katharopoulos
8db7161c94
Bug fix in quantize ( #1054 )
2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896
fix slice update indexing ( #1053 )
2024-04-29 12:17:40 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d ( #948 )
...
* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* clenaup
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops ( #1037 )
...
* bitwise ops
* fix tests
2024-04-26 22:03:42 -07:00
Angelos Katharopoulos
ec8578d41a
Fix quantization of all 0s ( #1028 )
2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees ( #1011 )
2024-04-22 11:17:49 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization ( #1016 )
2024-04-19 19:47:26 -07:00
Jagrit Digani
b18468bf81
Masked mm ( #978 )
...
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Nripesh Niketan
ffff671273
Update pre-commit hooks ( #984 )
2024-04-11 07:27:53 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch ( #977 )
2024-04-10 21:45:31 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid ( #961 )
2024-04-09 11:43:08 -07:00
Awni Hannun
42afe27e12
std and expm1 ( #973 )
...
* std and expm1
* actually add expm1
* fix linux
* fix vjp
* relax tol for linux test
* Add it to the compilable primitives
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
039da779d1
No quant reshape ( #957 )
...
* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax ( #953 )
...
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
Cheng
bab5386306
Make ops aware of rvalues: astype/as_strided/copy/full ( #895 )
...
When compositing transforms lots of temporary of arrays will be created
and passed to next primitive, and by making ops accepting args by value
we can avoid lots of copies of temporary arrays.
2024-03-27 22:35:55 -07:00
Jagrit Digani
240d10699c
Implement negative padding in conv with slicing ( #907 )
...
* Implement negative padding with slicing
* Update mlx/ops.cpp
Co-authored-by: Awni Hannun <awni@apple.com>
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-26 14:59:19 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… ( #427 )
...
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285 .
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead ( #871 )
...
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Cheng
f0ae00da12
Reduce implicit copies in make_array ( #874 )
...
1. Move shapes into outputs instead of copying them.
2. Pass primitive by const ref as it is always copied into outputs, which
removes a copy when calling make_array.
2024-03-22 06:29:16 -07:00
Jagrit Digani
a5681ebc52
Update set item ( #861 )
...
* Update mlx_set_item to handle regular slices without expanding
* Refactor ellipsis handling
* Route mlx_set_item to slice_update where possible
* Update mlx_scatter_args_slice
* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Jagrit Digani
cec8661113
Add a SliceUpdate op and primitive ( #850 )
...
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes ( #802 )
2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5
Reshape improvement ( #818 )
2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb
No copy gems ( #801 )
...
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
Angelos Katharopoulos
e39bebe13e
Fix reshaping of empty arrays ( #791 )
2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c
Improved quantized matrix vector product ( #786 )
2024-03-05 17:32:19 -08:00
Awni Hannun
5121f028d9
nice tensordot for mlx c ( #782 )
2024-03-04 09:51:02 -08:00
Angelos Katharopoulos
8e281c76c3
Fix the top-k op ( #768 )
2024-03-01 22:08:43 -08:00
Awni Hannun
d5964a2710
bindings for memory info ( #761 )
...
* bindings for memory info
* update api
* keep cache low if requested
* fix default
* nit in ops error
2024-03-01 19:51:58 -08:00
Jagrit Digani
776c3d226d
Convolution update ( #651 )
...
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation winograd conv routing
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-28 20:11:16 -08:00
Hinrik Snær Guðmundsson
08226ab491
added atleast *args input support ( #710 )
...
* added atleast list(array) input support
* function overloading implemented
* Refactoring
* fixed formatting
* removed pos_only
2024-02-26 11:17:59 -08:00
Awni Hannun
e6418781ab
Fix logsumexp edge case ( #740 )
...
* fix logsumexp
* fix inf constant
* also fix power grad
* fix ternary dispatch
2024-02-25 08:39:55 -08:00
Noah Farr
d729a1991b
Fix arange with inf step ( #686 )
...
* Fix case for step=inf in arange and add inf check for start/stop
* Add test cases for arange
* Update ops.cpp to include climits header
* Fix arange
* Fix formatting
* Refactor
* Add missing include
2024-02-23 06:18:15 -08:00
Rifur13
126c9869c8
Implement the 'where' primitive for conditional selection ( #664 )
2024-02-22 15:10:48 -08:00
Hinrik Snær Guðmundsson
f883fcede0
Added support for atleast_1d, atleast_2d, atleast_3d ( #694 )
2024-02-19 09:40:52 -08:00
toji
85143fecdd
improved error msg for invalid axis(mx.split
) ( #685 )
...
* improved error msg for invalid axis(`mx.split`)
* Apply suggestions from code review
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* fixed formatting issue
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8
Adds device context manager ( #679 )
2024-02-14 14:14:58 -08:00
Awni Hannun
1eb04aa23f
Fix empty array construction in cpp ( #684 )
2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91
Return empty array when repeats is 0 in mx.repeat ( #681 )
...
* Return empty array when repeats is 0
* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00