Commit Graph

223 Commits

Author SHA1 Message Date
Shiyu
061cf9a4ce Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
Awni Hannun
741eb28443 fix a couple bugs (#952) 2024-04-02 12:07:41 -07:00
AmirHossein_Razlighi
d611251502 Support Chaining for some of functionalities of nn.Module (#885) (#897)
* add chaining support for some of the functionalities of "nn.Module"

* reformat

* change the return types

* remove return types

* add return type with forward referencing

* add tests for chaining

* add name to contributors

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update docstring

* update docstrings

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
Awni Hannun
570f2bf29e pick up preivously set attributes (#905) 2024-03-26 11:19:59 -07:00
Daniel Strobusch
479051ce1c add numeric type hierarchy and issubdtype as well as a set_dtype meth… (#427)
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate

numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)).

Closes #285.

* nits in docs

* unify type category checking

* nits in docs

* nits in docs

* more docs nits

* fix callable type

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
1e16331d9c post nanobind docs fixes and some updates (#889)
* post nanobind docs fixes and some updates

* one more doc nit

* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b Reduce a little overhead (#871)
* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
2225374060 Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
Angelos Katharopoulos
53e6a9367c Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8 Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Awni Hannun
16546c70d8 No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
Awni Hannun
366478c560 fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
28301807c2 Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
cbcf44a4ca Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Piotr Rybiec
6a665ea6ed Dilation for convolutional layers (#766)
* add dilation parameter to Conv1d layer

* space here too

* add conv1d dilation test

* add dilation parameter for Conv2d layer

* conv2d dilation test
2024-03-04 06:43:00 -08:00
Awni Hannun
bc06cb9ff6 Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Awni Hannun
4494970f47 avoid nested closures in module (#759) 2024-02-29 09:39:52 -08:00
Awni Hannun
420ff2f331 Add back compiled function signatures and docstrings (#749)
* try to add back compiled function signatures and docstrings

* add indentation to docstring
2024-02-27 13:18:59 -08:00
Noah Kasmanoff
de3d2467a3 Update: Fast GeLU Approximation (#744)
* add: fast gelu approx

* fix docs

* Update gelu_fast_approx function documentation

* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix: test gelu

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-26 21:08:50 -08:00
Awni Hannun
fe1dabf272 Fix compile with non standard types (#745)
* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-26 19:28:53 -08:00
Chime Ogbuji
3b661b7394 Add linear warmup and schedule joining for use with existing schedules (#721)
* Add linear warmup to schedules for use with existing schedules

* Changed parameters for simplicity of most common case (0 initial value)

* Added ScheduleJoiner and updated documentation

* ScheduleJoiner -> join_schedules (ala optax #)

* black compliance

* Different evaluation of schedules

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-26 07:28:48 -08:00
Gabrijel Boduljak
22364c40b7 Upsample2d (#414)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-23 09:55:04 -08:00
Awni Hannun
5798256fcf Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Srimukh Sripada
818cda16bc Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
Diogo
35431a4ac8 Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995 Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Gabrijel Boduljak
e54cbb7ba6 Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
LeonEricsson
7dccd42133 updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Aryan Gupta
ef73393a19 Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
AtomicVar
83f63f2184 Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
David Koski
601c6d6aa8 Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365 Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
nathan
bad67fec37 Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
0de5988f92 Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5 Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Awni Hannun
09b9275027 Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454 Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jagrit Digani
bf17ab5002 Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
David Koski
874b739f3c Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
Hazem Essam
37fc9db82c Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statment

* Removed parameter rms from the optimizer state because its not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137 Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename packakge to nn.innit

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
Awni Hannun
d52383367a format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d Add ValuError message for Adamax (#508)
* ValuError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Anchen
f6feb61f92 feat: add support for saving safetensors in the save_weights (#497)
* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:19:33 -08:00