Commit Graph

378 Commits

Author SHA1 Message Date
Cheng
56a4eaed72
Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7
Move args in conv_general (#1118)
Also fix a typo that padding_lo is passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6
Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
5be5daa6ef
Use compiled function in Sigmoid module (#1116) 2024-05-14 06:25:57 -07:00
Cheng
60cb11764e
Use correct module type in quantized.py (#1115) 2024-05-14 06:25:42 -07:00
Cheng
cbd5445ea7
The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Alex Barron
2e158cf6d0
Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
b21242faf1
Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4
Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Awni Hannun
9814a2ae12
fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a
add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00
Awni Hannun
21623156a3
Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0
feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Awni Hannun
b00ac960b4
change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Jagrit Digani
f390957685
Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797
Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c
Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Angelos Katharopoulos
8db7161c94
Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896
fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc
[Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b
Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
5bfe89bdb1
Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
Awni Hannun
771575d27b
Expose function to clear memory cache (#1032)
* expose function to clear memory cache

* fix linux build

* fix metal tests
2024-04-24 16:48:51 -07:00
Angelos Katharopoulos
ec8578d41a
Fix quantization of all 0s (#1028) 2024-04-24 00:40:42 -07:00
Aneesh Shetty
d0dbfe0b97
Adds radians and degrees (#1011) 2024-04-22 11:17:49 -07:00
Awni Hannun
3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00
Angelos Katharopoulos
84d61d27aa
Make sure 0 is represented in the quantization (#1016) 2024-04-19 19:47:26 -07:00
Angelos Katharopoulos
ef5f7d1aea
Fix buffer protocol buffer size designation (#1010) 2024-04-19 06:06:13 -07:00
Jagrit Digani
85c8a91a27
Fix mask broadcasting bug and add relevant test (#1003) 2024-04-17 17:33:48 -07:00
Piotr Rybiec
581b699ac9
avgpool, not maxpool (#1002) 2024-04-17 08:26:22 -07:00
Awni Hannun
8a0677d56d
Shared events for synchronization + async eval (#998)
* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if asyn evale'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
2024-04-17 06:16:02 -07:00
Jagrit Digani
b18468bf81
Masked mm (#978)
* Add block masked matmul op and primitive
2024-04-16 14:45:39 -07:00
Shiyu
107ba2891a
gelu tanh approx (#989)
* gelu tanh approx

* gelu tanh approx

* replace gelu approx with tanh approach

* fix comments

* fix comment
2024-04-15 19:49:00 -07:00
Awni Hannun
cd9e184529
Quantize embedding (#994)
* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
2024-04-15 16:42:10 -07:00
Alex Barron
2e7c02d5cd
Metal FFT for powers of 2 up to 2048 (#915)
* add Metal FFT for powers of 2

* skip GPU test on linux

* fix contiguity bug

* address comments

* Update mlx/backend/metal/fft.cpp

* Update mlx/backend/metal/fft.cpp

* fix bug in synch

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-11 21:40:06 -07:00
Awni Hannun
ae18326533
No copy command encoder (#986)
* no copy command encoder

* up layer norm test tolerances
2024-04-11 21:15:36 -07:00
Awni Hannun
12d4507ee3
Explicit barriers with concurrent dispatch (#977) 2024-04-10 21:45:31 -07:00
Shiyu
061cf9a4ce
Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
Awni Hannun
99abb9eff4
Async eval (#972) 2024-04-09 18:34:00 -07:00
Luca Arnaboldi
fffe072028
Implementation of mlx.random.multivariate_normal (#502) (#877)
* Implementation of mlx.random.multivariate_normal (#502)

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Updated typo in docstring

* Restricted multivariate_normal to  float32

* Generic mean and variance shapes

* Review edits

* Update mlx/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/random.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Test for ndim of mean and cov

* nits

* smaller size for test

* fix broadcasted sampling

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-09 13:50:12 -07:00
Abe Leininger
a1a31eed27
Add mx.meshgrid (#961) 2024-04-09 11:43:08 -07:00
Awni Hannun
42afe27e12
std and expm1 (#973)
* std and expm1

* actually add expm1

* fix linux

* fix vjp

* relax tol for linux test

* Add it to the compilable primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-08 14:26:01 -07:00
Awni Hannun
76e63212ff
Enable bfloat scan (#974)
* enable bfloat scan
* fix tests
2024-04-08 12:29:19 -07:00
Awni Hannun
aac2f9fb61
Improve profiling with gpu tracing (#969)
* improve profiling with gpu tracing

* fix for linux

* nit

* doc fix

* fix example
2024-04-07 21:47:43 -07:00
Awni Hannun
039da779d1
No quant reshape (#957)
* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
2024-04-04 11:52:12 -07:00
Awni Hannun
d88d2124b5
segfaut layer norm grad (#955) 2024-04-04 10:59:15 -07:00
Awni Hannun
e142aaf8a1
Option for precise softmax (#953)
* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-04-04 08:32:35 -07:00
AmirHossein_Razlighi
0caf35f4b8
Better exceptions in case of invalid operations on mlx.core.array (#910) (#926)
* Nicer exceptions for ops on non-arrays
2024-04-02 21:11:24 -07:00
Angelos Katharopoulos
3fc993f82d
Properly handle negative axes in python vmap (#944) 2024-04-02 18:07:23 -07:00
Awni Hannun
741eb28443
fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
Angelos Katharopoulos
1a87dc5ea8
Fix compile fusion for multi-output edge cases (#950)
* Fix compile fusion for multi-output edge cases

* Add a test for multi-output compile
2024-04-02 08:42:31 -07:00
Awni Hannun
2427fa171e
Fix cpu compile (#934)
* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
2024-04-01 17:37:12 -07:00
Jagrit Digani
639e06e1f3
Indexing bug fix (#947)
* Fix axes accounting

* Add tests
2024-04-01 12:18:50 -07:00
Angelos Katharopoulos
02fedbf1da
Fix array initialization from list (#942)
* Fix array initialization from list

* Change the error message in the test
2024-04-01 06:27:52 -07:00
Angelos Katharopoulos
110d9b149d
Layer norm grad fix donation bug (#941)
* add layer norm grad test

* Fix donation bug in layernorm vjp

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-01 06:15:50 -07:00
AmirHossein_Razlighi
f48bc496c7
Comparing python objects (such as list/tuple) with mlx.core.array (#920)
* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-29 06:52:30 -07:00
Angelos Katharopoulos
5f9ba3019f
Fix qmm_t for unaligned cases (#923) 2024-03-28 15:34:57 -07:00
Cheng
46caf0bef0
Remove unnecessary string copies (#891)
1. Use string_view instead of string when there is no need for copy.
2. Otherwise move string when possible.
2024-03-28 13:14:59 -07:00
Cheng
a7b404ff53
Use uintptr_t instead of size_t to store funtion id (#916)
Also does some small cleanup of the compile cache code.
2024-03-28 06:37:59 -07:00
AmirHossein_Razlighi
d611251502
Support Chaining for some of functionalities of nn.Module (#885) (#897)
* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
Cheng
f30b659291
Make MLX build on x64 macOS (#901)
The arm64 macbook pros are heavy and I usually care my intel one for
mobile, it would be nice if I can play with MLX on it.

To build with x64, user must pass `MLX_ENABLE_X64_MAC` to cmake:
CMAKE_ARGS='-DMLX_ENABLE_X64_MAC=ON' python setup.py
2024-03-27 06:14:29 -07:00
Angelos Katharopoulos
29221fa238
Implement vjps for some primitives in the fast namespace (#883)
* Implement rope vjp in terms of rope
* RMSNormVJP primitive and kernel
* Add LayerNormVJP primitive and kernel
2024-03-26 16:35:34 -07:00
Jagrit Digani
925014b661
Fix multiblock sort limits (#906)
* Fix multiblock sort limits

* Fix metal validation error
2024-03-26 14:00:00 -07:00
Abdussamet Türker
5611e1a95e
Fix unsqueeze with None (#899)
* Fix unsqueeze with None

* Clean unnecessary files
2024-03-26 13:59:44 -07:00
Awni Hannun
570f2bf29e
pick up preivously set attributes (#905) 2024-03-26 11:19:59 -07:00
Luca Arnaboldi
a3ee03da01
Fixing random.normal for half-precision dtype #642 (#904)
* Fixing random.normal for half-precision dtype #642

* Update python/tests/test_random.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-26 09:58:27 -07:00
Jack Mousseau
8e686764ac
Ensure shape dimensions are within supported integer range (#566) (#704)
* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 13:29:45 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
1e16331d9c
post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Jagrit Digani
8e5a5a1ccd
Set item bug fix (#879)
* set item shaping bug fix

* Add extra tests
2024-03-22 12:11:17 -07:00
Angelos Katharopoulos
fcda3a0e66
Increase test tolerance for fast.layer_norm (#880) 2024-03-22 12:10:27 -07:00
Cheng
9663c22fe9
Do not store iostream in shared_ptr (#872)
There is no need to store iostream in shared_ptr, doing so adds the cost
of a heap allocation.
2024-03-22 06:54:45 -07:00
Awni Hannun
44390bd3d0
Bump (#869)
* bump

* fix none in a few ops
2024-03-21 13:56:56 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889
Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Angelos Katharopoulos
53e6a9367c
Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8
Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Jagrit Digani
a5681ebc52
Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Jagrit Digani
b219d12a6b
Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Md. Rasel Mandol
db6796ac61
simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246
Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Awni Hannun
16546c70d8
No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf
Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256
vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Angelos Katharopoulos
76c919b4ec
NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Jagrit Digani
5ad133f8bb
No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0
Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
366478c560
fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a
Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
0e95b64942
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
Awni Hannun
7c441600fe
Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
28301807c2
Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
b7588fd5d7
fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Luca Arnaboldi
cbefd9129e
Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)
* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-06 08:02:41 -08:00