Commit Graph

201 Commits

Author SHA1 Message Date
Gabrijel Boduljak
e54cbb7ba6
Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
Angelos Katharopoulos
40c108766b
Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Nripesh Niketan
0dbc4c7547
feat: Update pre-commit-config.yaml (#667) 2024-02-11 06:08:20 -08:00
Awni Hannun
b96be943dc
bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185
Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d
Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Awni Hannun
5c03efaf29
Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
LeonEricsson
7dccd42133
updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef
fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Noah Farr
5fd11c347d
Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Aryan Gupta
ef73393a19
Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
Angelos Katharopoulos
ea406d5e33
CI change (#645)
* CI update

* Skip large binary test for now

* Upgrade pip

* Add proper env variable skipping

* Update the CI

* Fix workflow name

* Set the low memory flag for the tests

* Change build process

* Add pip upgrade

* Use a venv

* Add a missing env activate

* Add setuptools

* Add twine upload back

* Re-enable automatic release builds
2024-02-07 06:04:34 -08:00
Awni Hannun
d40a04f8dc
minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd
Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Awni Hannun
5c3ac52dd7
fix test (#627) 2024-02-04 16:18:03 -08:00
Avikant Srivastava
11a9fd40f0
fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
Daniel Strobusch
4fd2fb84a6
make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19
fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
AtomicVar
83f63f2184
Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
Awni Hannun
cb6156d35d
Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Awni Hannun
e88e474fd1
Reduce vmap + some fixes (#601) 2024-02-01 11:30:28 -08:00
David Koski
601c6d6aa8
Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365
Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb
Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
Vijay Krish
fcc5ac1c64
Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
nathan
bad67fec37
Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5
Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Awni Hannun
09b9275027
Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454
Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jacket
3f7aba8498
Implement diagonal operator (#562)
* Implement diagonal operator

This implements mx.diagonal in operator level, inspired by
@ManishAradwad.

* added `mx.diag` with tests

* corrected few things

* nits in bindings

* updates to diag

---------

Co-authored-by: ManishAradwad <manisharadwad@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 09:45:48 -08:00
Awni Hannun
3c2f192345
Propagate nans in binary ops (#579)
* propagate nans in binary ops

* handle empty matmul

* cpu minimum/maximum propagate nan

* benchmark maximum

* add min as well

* throw on negative indices with full

* verbose on linux

* fix matmul for zero K
2024-01-29 11:19:38 -08:00
Angelos Katharopoulos
37d98ba6ff
No gil eval (#565) 2024-01-26 22:03:52 -08:00
Awni Hannun
8993382aaa
Buffer Donation (#519)
* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
2024-01-26 16:30:33 -08:00
Awni Hannun
07f35c9d8a
Fix a few issues: docs for flatten, erf, dequantize validation (#560)
* doc flatten

* erf doc

* check values for dequantize

* format
2024-01-26 15:16:46 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
Awni Hannun
8fa6b322b9
Compile front-end (#476)
* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
2024-01-26 13:45:30 -08:00
David Koski
874b739f3c
Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
taher
077c1ee64a
QR factorization (#310)
* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-26 09:27:31 -08:00
Rifur13
2463496471
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-25 20:47:06 -08:00
Awni Hannun
f27ec5e097
More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
2024-01-24 09:58:33 -08:00
Awni Hannun
f30e63353a
Minor updates to address a few issues (#537)
* docs on arg indices return type

* arange with nan

* undo isort
2024-01-23 22:24:41 -08:00
Hazem Essam
37fc9db82c
Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statment

* Removed parameter rms from the optimizer state because its not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137
Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc
Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename packakge to nn.innit

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
Awni Hannun
98c37d3a22
use axes in tensordot (#525) 2024-01-22 21:17:00 -08:00
Awni Hannun
7a34e46677
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-01-21 06:19:05 -08:00
Awni Hannun
d52383367a
format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d
Add ValuError message for Adamax (#508)
* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00