Commit Graph

203 Commits

Noah Kasmanoff
de3d2467a3
Update: Fast GeLU Approximation (#744)
* add: fast gelu approx

* fix docs

* Update gelu_fast_approx function documentation

* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix: test gelu

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-26 21:08:50 -08:00
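
A quick sense of the API this adds: the fast approximation trades a little accuracy for speed by replacing the exact (erf-based) GELU with a sigmoid-based form. A minimal sketch, assuming the `nn.gelu_fast_approx` function and the `approx` keyword on `nn.GELU` that this PR documents:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-1.0, 0.0, 1.0])
y_exact = nn.gelu(x)             # exact GELU, computed with erf
y_fast = nn.gelu_fast_approx(x)  # fast approximation, roughly x * sigmoid(1.773 * x)

# The module form selects the variant with a keyword:
layer = nn.GELU(approx="fast")
y = layer(x)
```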
Awni Hannun
fe1dabf272
Fix compile with non standard types (#745)
* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-26 19:28:53 -08:00
Chime Ogbuji
3b661b7394
Add linear warmup and schedule joining for use with existing schedules (#721)
* Add linear warmup to schedules for use with existing schedules

* Changed parameters for simplicity of most common case (0 initial value)

* Added ScheduleJoiner and updated documentation

* ScheduleJoiner -> join_schedules (ala optax #)

* black compliance

* Different evaluation of schedules

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-26 07:28:48 -08:00
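
The warmup-plus-schedule pattern composes with `join_schedules`: each schedule runs until its boundary step, then the next takes over. A minimal sketch, assuming the `linear_schedule`, `cosine_decay`, and `join_schedules` helpers in `mlx.optimizers`:

```python
import mlx.optimizers as optim

warmup = optim.linear_schedule(0.0, 1e-3, 100)           # 0 -> 1e-3 over 100 steps
decay = optim.cosine_decay(1e-3, 1000)                   # cosine decay from 1e-3
schedule = optim.join_schedules([warmup, decay], [100])  # hand off at step 100

optimizer = optim.Adam(learning_rate=schedule)  # lr follows the joined schedule
```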
Gabrijel Boduljak
22364c40b7
Upsample2d (#414)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-23 09:55:04 -08:00
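
In current MLX this surfaced as a general `nn.Upsample` layer rather than a class literally named `Upsample2d`; a minimal nearest-neighbor sketch on channels-last input, assuming that name:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)  # NHWC
up = nn.Upsample(scale_factor=2, mode="nearest")
print(up(x).shape)  # (1, 8, 8, 1)
```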
Awni Hannun
5798256fcf
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs

* update compile benchmark

* default compile a few activations

* buffer donation

* bugfix

* shapeless fix

* update tests to work for cpu and gpu fusion

* test kwargs

* add kwargs to compile

* Recompile when python arguments change

* no compile for tanh

* some constant tests

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
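
Shapeless compilation lets one compiled graph serve inputs of different shapes instead of retracing for every shape. A minimal sketch, assuming the `shapeless` keyword on `mx.compile`:

```python
import mlx.core as mx

def fun(x):
    return mx.exp(-x) + x

cfun = mx.compile(fun, shapeless=True)  # trace without baking shapes into the graph

print(cfun(mx.ones((2,))))    # first call compiles
print(cfun(mx.ones((4, 4))))  # reuses the compiled graph for a new shape
```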
Srimukh Sripada
818cda16bc
Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
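
These schedulers are plain callables from step count to learning rate and are passed where a fixed float used to go. A minimal sketch, assuming the `step_decay(init, decay_rate, step_size)` signature:

```python
import mlx.optimizers as optim

lr = optim.step_decay(1e-1, 0.9, 10)  # multiply the rate by 0.9 every 10 steps
opt = optim.SGD(learning_rate=lr)
# Each opt.update(model, grads) call advances the schedule by one step.
```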
Diogo
35431a4ac8
Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
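
The context manager scopes the default device to a block instead of switching it globally. A minimal sketch, assuming the `mx.stream` context-manager form, which accepts a device:

```python
import mlx.core as mx

a = mx.random.uniform(shape=(256, 256))
with mx.stream(mx.cpu):  # ops inside this block run on the CPU
    b = a @ a
mx.eval(b)
```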
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
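
The "fat op" fuses the RoPE computation into a single fast kernel under `mx.fast` instead of a chain of small ops. A minimal sketch, assuming the `mx.fast.rope(a, dims, *, traditional, base, scale, offset)` signature:

```python
import mlx.core as mx

x = mx.random.normal((1, 8, 32))  # (batch, sequence, feature)
y = mx.fast.rope(x, 32, traditional=False, base=10000.0, scale=1.0, offset=0)
```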
Gabrijel Boduljak
e54cbb7ba6
Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
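
A minimal pooling sketch, assuming the `nn.MaxPool2d`/`nn.AvgPool2d` names and MLX's channels-last (NHWC) convention:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((1, 8, 8, 4))            # NHWC
pool = nn.MaxPool2d(kernel_size=2, stride=2)
print(pool(x).shape)                          # (1, 4, 4, 4)
```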
LeonEricsson
7dccd42133
updated calls to use loc & scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b
Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
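
Input/output capture is what makes compiled training steps work: the mutable state (parameters, optimizer moments, RNG) is declared to `mx.compile` so in-place updates persist across calls. A minimal training-step sketch following this pattern, assuming the `inputs=`/`outputs=` keywords; the tiny model and random data are placeholders:

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(learning_rate=1e-2)

def loss_fn(model, X, y):
    return nn.losses.mse_loss(model(X), y, reduction="mean")

# Declare the mutable state the compiled function reads and writes.
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(X, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
    optimizer.update(model, grads)
    return loss

X = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))
loss = step(X, y)
mx.eval(state)  # materialize the updated parameters and optimizer state
```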
Awni Hannun
e5e816a5ef
fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Aryan Gupta
ef73393a19
Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
AtomicVar
83f63f2184
Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
David Koski
601c6d6aa8
Fix for AdaDelta (#603)
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365
Change the transformer to norm_first by default (#599) 2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb
Add py.typed to support PEP-561 (type-hinting) for mlx (#588)
* Add `py.typed` to support PEP-561 (type-hinting)

This adds support for type-hinting information as laid out in [PEP-561](https://peps.python.org/pep-0561/).

* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
nathan
bad67fec37
Added TeX line breaks to mlx.optimizers.Lion docstring (#595)
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing (#541)
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5
Fix SGD implementation (#473) 2024-01-30 15:50:46 -08:00
Awni Hannun
09b9275027
Make shape a tuple (#591)
* shape tuple

* also remove simplify from docs

* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454
Softshrink mapping + op (#552)
* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations (#563)
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
David Koski
874b739f3c
Fix cache key in RoPE (#561) 2024-01-26 13:10:02 -08:00
Hazem Essam
37fc9db82c
Added Adafactor (#415)
* Added adafactor

* Added Adafactor and ran pre-commit

* modified operations

* Added docstrings

* Switched two ops to fix a bug

* added underscore for internal functions and removed the plus sign in the last return statement

* Removed parameter rms from the optimizer state because it's not needed

* Added simple MNIST test for Adafactor and temporary training log

* remove test files

* nits in docs

* comment nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137
Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
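
After this change `cross_entropy` accepts either class indices or dense per-class probabilities. A sketch of both call forms, assuming the `nn.losses.cross_entropy` signature:

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])

# Class-index targets:
l_sparse = nn.losses.cross_entropy(logits, mx.array([0, 1]))

# Dense targets: per-class probabilities (here one-hot), same shape as logits:
l_dense = nn.losses.cross_entropy(logits, mx.array([[1.0, 0.0], [0.0, 1.0]]))
```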
LeonEricsson
6b4b30e3fc
Common neural network initializers nn.initializers (#456)
* initial commit: constant, normal, uniform

* identity, glorot and he initializers

* docstrings

* rm file

* nits

* nits

* nits

* testing suite

* docs

* nits in docs

* more docs

* remove unused template

* rename package to nn.init

* docs, receptive field

* more docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
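
The initializers are callables that take an array and return a freshly sampled array of the same shape and dtype. A minimal sketch, assuming the `nn.init.he_normal` and `nn.init.constant` factories:

```python
import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.he_normal()
w = init_fn(mx.zeros((4, 4)))              # He-normal sample shaped like the input
b = nn.init.constant(0.5)(mx.zeros((4,)))  # all 0.5
```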
Awni Hannun
d52383367a
format (#510) 2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d
Add ValueError message for Adamax (#508)
* ValueError message added

* beta errors added

* some corrections and testing

* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Anchen
f6feb61f92
feat: add support for saving safetensors in the save_weights (#497)
* feat: add save safetensors support in module save_weights

* chore: checking missing changes

* Update python/mlx/nn/layers/base.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* chore: update docstring for load_weights

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:19:33 -08:00
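
The format is picked from the file extension, so the same call saves `.npz` or `.safetensors`. A minimal sketch under that assumption:

```python
import mlx.nn as nn

model = nn.Linear(4, 2)
model.save_weights("model.safetensors")  # safetensors, chosen by extension
model.load_weights("model.safetensors")
```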
AtomicVar
550d4bf7c0
Update binary_cross_entropy function to handle both logits and probabilities (#492) 2024-01-18 19:22:23 -08:00
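
A sketch of the two input conventions, assuming the `with_logits` flag (defaulting to logits) that this change introduces:

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([0.2, -0.3])
targets = mx.array([1.0, 0.0])

l_logits = nn.losses.binary_cross_entropy(logits, targets)  # logits by default
probs = mx.sigmoid(logits)
l_probs = nn.losses.binary_cross_entropy(probs, targets, with_logits=False)
```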
AtomicVar
d1fef34138
Add Gaussian NLL loss function (#477)
* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 06:44:44 -08:00
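
Gaussian NLL scores predicted means and variances against targets. A minimal sketch, assuming the `nn.losses.gaussian_nll_loss(inputs, targets, vars)` call form:

```python
import mlx.core as mx
import mlx.nn as nn

mu = mx.array([0.1, 0.2])      # predicted means
target = mx.array([0.0, 0.3])
var = mx.array([0.5, 0.5])     # predicted variances
loss = nn.losses.gaussian_nll_loss(mu, target, var)
```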
Jagrit Digani
78102a47ad
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Chunyang Wen
e3e933c6bc
Add type hint for Module (#412) 2024-01-10 11:23:42 -08:00
Awni Hannun
e9ca65c939
Fix BN stats to not expand shape (#409)
* fix BN stats to not expand shape

* nit
2024-01-09 11:54:51 -08:00
YUN, Junwoo
0b8aeddac6
Additional losses (#336)
* cosine similarity loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>

* Docstring nits
2024-01-08 14:01:13 -08:00
Hazem Essam
022a944367
Added GLU activation function and Gated activation function (#329)
* Added GLU activation function and gated activation function

* Ran pre-commit

* Ran pre-commit

* Removed old sigmoid implementation to match with main

* Removed gated activation from __init__.py

* Removed unused test cases

* Removed unused imports

* format / docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 06:13:16 -08:00
Angelos Katharopoulos
6ea6b4258d
Fix style check (#395) 2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a
Add theta cache for Rope and mask cache for ALiBi (#375) 2024-01-07 00:22:58 -08:00
Angelos Katharopoulos
75dc537e44
Fix the sigmoid module (#371) 2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5
revert copy (#366) 2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160
Remove useless pass (#364)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-04 06:34:01 -08:00
toji
d2467c320d
Added support for python copy (#335)
* Added support for python copy

* precommit changes

* removed `_compiled_call_impl` line

* added tests and suggested changes

* ACK changes
2024-01-03 20:59:40 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T (#349)
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Gabrijel Boduljak
c7edafb729
implemented InstanceNorm (#244)
* implemented instancenorm

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

* addressed PR review comments

* fixed import order in __init__

* expected values in instancenorm tests are simple lists

* minor return expression style change

* added InstanceNorm to docs

* doc string nits

* added myself to individual contributors

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-03 12:21:15 -08:00
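
This PR added both the layer and the `mlx.linalg` norm it builds on. A minimal sketch of each, assuming the `nn.InstanceNorm(dims)` constructor and the NumPy-like `mx.linalg.norm`:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 4))    # (batch, length, features), channels last
y = nn.InstanceNorm(dims=4)(x)     # normalize each instance over the spatial axis

a = mx.array([[1.0, 2.0], [3.0, 4.0]])
fro = mx.linalg.norm(a)            # Frobenius norm for a matrix
rows = mx.linalg.norm(a, axis=1)   # per-row vector norms
```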
Awni Hannun
dff4a3833f
Module checks the weight on load_weights (#337)
* update module to check weights on load, also fix docs and reorganize tests

* nits + rebase

* a few more docs updates for Module

* use manual module file

* comment
2024-01-02 18:55:42 -08:00
Angelos Katharopoulos
436bec9fd9
Fix the implementation of the Bilinear layer (#347) 2024-01-02 16:46:18 -08:00
Asaf Zorea
295ce9db09
Feature expand nn linear (#315)
* Added identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00
Josh Soref
44c1ce5e6a
Spelling (#342)
* spelling: accumulates

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: across

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: additional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: against

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: among

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: array

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: available

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: axes

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: basically

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bfloat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: bounds

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: broadcast

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: buffer

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: class

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: coefficients

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: collision

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: combinations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: committing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: computation

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: consider

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: constructing

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: conversions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: correctly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: corresponding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: declaration

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: default

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dependency

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destination

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: destructor

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: dimensions

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: divided

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: element-wise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: elements

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: endianness

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: equivalent

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: explicitly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: github

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: indices

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: irregularly

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: memory

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: metallib

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: negative

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: notable

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: optional

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: otherwise

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: overridden

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partially

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: partition

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perform

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: perturbations

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: positively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: primitive

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeat

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: repeats

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respect

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: respectively

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: result

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: rounding

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: separate

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: skipping

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: structure

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: the

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: transpose

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unnecessary

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unneeded

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

* spelling: unsupported

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>

---------

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Nripesh Niketan
e09bf35b28
feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
2023-12-31 14:01:21 -08:00
Hazem Essam
e3b8da2a49
Added implementation for Scaled RoPE. (#261)
* Added scale for RoPE

* Ran pre-commit

* Added RoPE scaling test

* Added docstring for scale parameter

* Modified docstrings
2023-12-31 06:06:01 -08:00
Nripesh Niketan
5ad8fb7268
feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function

* run pre-commit

* Add Softsign activation function

* Add Softsign activation function

* Add documentation for ReLU6, Softplus, and Softsign activations

* Update activation functions in neural network layers

* Add LogSoftmax and Hardswish activations

* run pre-commit

* Update activations.py

* Added acknowledgements

* Fix activation function comments

* Fix activation functions in neural network layers
2023-12-29 11:49:36 -08:00
Chunyang Wen
2aedf3e791
Minor refactor for tree_map and tree_unflatten (#311)
* Minor refactor for tree_map and tree_unflatten

* Remove the if statement

---------

Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 20:55:10 -08:00
Chunyang Wen
473b6b43b4
Use defaultdict (#307)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 14:46:13 -08:00
Angelos Katharopoulos
d29770eeaa
Update batchnorm to have the running stats in parameters (#305) 2023-12-28 14:31:10 -08:00
Chunyang Wen
040c3bafab
Add missing f str (#306)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-28 06:09:34 -08:00
Chunyang Wen
05767b026f
Add information for dropout probability (#304)
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2023-12-27 21:51:30 -08:00
YUN, Junwoo
4417e37ede
Transformer fix (#167)
* add transformer with dropout, fix transformer ffn, layernorm order

* precommit changes

* precommit changes

* add docstring, activation, norm_first

* run precommit

* run precommit

* add docstring

* precommit

* style nits in docs

---------

Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-27 08:48:36 -08:00
__mo_san__
a123c3c7d2
implement-batch-norm-layer (#217)
- Add batch normalization layer

---------

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-25 07:32:53 -08:00
Zach Schillaci
22fee5a383
Remove redundant assert in losses.py (#281) 2023-12-24 08:39:08 -08:00
Vidit Agarwal
acf1721b98
Corrected the example of value_and_grad (#274)
* Corrected the example for mx.value_and_grad

* Reformat through pre-commit/black
2023-12-23 11:06:38 -08:00
Nicholas Santavas
d35fa1db41
Add Hinge, Huber and LogCosh losses (#199) 2023-12-22 10:28:10 -08:00
Justin Deschenaux
e8deca84e0
Add dropout2d (#250) 2023-12-22 08:02:29 -08:00
Hazem Essam
0aa65c7a6b
Added ALiBi implementation (#232) 2023-12-21 14:36:38 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
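
The underlying ops round-trip a weight matrix through its packed representation; `QuantizedLinear` wraps the same machinery as a drop-in layer. A minimal sketch, assuming the `mx.quantize`/`mx.dequantize` signatures:

```python
import mlx.core as mx

w = mx.random.normal((512, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)  # approximates w
```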
Justin Deschenaux
4912ff3ec2
Add Lion optimizer (#209)
* Add Lion optimizer
* Update acknowledgements also with past contributions
2023-12-20 13:54:58 -08:00
Emircan Erol
e549f84532
Triplet Loss (#211)
* Triplet Loss

* Requested Changes

* Margin to alpha
2023-12-19 12:37:12 -08:00
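
A minimal call sketch with default margin and distance settings (the PR notes the margin argument was renamed, so none is passed here), assuming the `nn.losses.triplet_loss(anchors, positives, negatives)` call form:

```python
import mlx.core as mx
import mlx.nn as nn

anchors = mx.random.normal((4, 8))
positives = mx.random.normal((4, 8))
negatives = mx.random.normal((4, 8))
loss = nn.losses.triplet_loss(anchors, positives, negatives)
```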
Juarez Bochi
f4f6e17d45
Fix cross-attention (#210)
* Fix cross-attention

With the current code, ln2 is a no-op. Its output should be passed to the cross-attention layer

* Add name to contributors
2023-12-18 12:27:27 -08:00
jojopuppet
18cca64c81
Add smoothed L1 loss and enhancements to cross entropy loss (#166)
* Add smooth_l1_loss
* Add label smoothing for cross entropy loss

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 07:26:21 -08:00
Awni Hannun
ee0c2835c5
Docs updates (#198)
Reorganize NN docs + a few other tidbits.
2023-12-17 13:20:55 -08:00
YUN, Junwoo
eebd7c275d
Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes (#142)
* Add AdaMax, AdaDelta, RMSprop
2023-12-16 21:43:15 -08:00
Awni Hannun
2e02acdc83
add base kwarg to rope (#186) 2023-12-15 16:47:59 -08:00
Víctor Aguilar
f24200db2c
accross -> across (#183) 2023-12-15 13:46:50 -08:00
Awni Hannun
25f70d4ca4
Fix divide types + floor divide (//) (#138)
* divide types

* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arvix reference to mish

* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activation functions

* add to the docs

* review
2023-12-11 17:04:07 -08:00
Awni Hannun
69505b4e9b
fixes (#131) 2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44
Add Binary Cross Entropy loss (#122)
* update BCE, added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Awni Hannun
2d0130f80f
fix loss tests (#118)
* fix loss tests

* use None as default
2023-12-10 10:08:19 -08:00
__mo_san__
c1e1c1443f
Added Adagrad optimizer (#102) 2023-12-10 09:22:39 -08:00
Henry Ansah
68bf1d7867
add nn module for sigmoid activation (#111)
* add nn module for sigmoid activation

* update .gitignore with .cache folder generated by jetbrains fleet ide

* remove .cache folder
2023-12-10 07:00:39 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
__mo_san__
ef7b8756c0
Add tanh activation function (#115)
* added Adagrad optimizer ...

* added Tanh activation function ...

* reformatted file ...

* remove unrelated stuff ...

* Update activations.py
2023-12-09 19:25:38 -08:00
Enoch Kan
0b28399638
added mse_loss, nll_loss and kl_div_loss (#98)
* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Joe Barrow
ac6dc5d3eb
Adding optional bias param to MultiHeadAttention (#104)
* Adding optional bias param to MultiHeadAttention

* Run style-checker
2023-12-09 11:04:28 -08:00
Awni Hannun
2520dbcf0a
add losses to the docs, fix black failure (#92) 2023-12-09 06:06:52 -08:00
Abe Leininger
430bfb4944
Adds Nesterov momentum to SGD (#87) 2023-12-08 23:23:36 -08:00
Kai Ma
cb9e585b8e
Style fix for loss functions (#91)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none

* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484
MLE and L1 loss functions (#88)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
2023-12-08 20:21:37 -08:00
Joe Barrow
69a24e6a1e
AdamW implementation (#72)
* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
2023-12-08 14:45:34 -08:00
Zach Schillaci
5b9be57ac3
Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
Zach Schillaci
d11d77e581
Spelling fixes in transformer.py (#59) 2023-12-07 13:32:09 -08:00
rushyam
2e126aeb7e
Feature Addition: Encoder-Decoder Transformer Architecture (#50)
* Implemented decoder-transformer-layer, decoder-transformer, and introduced encoder-decoder transformer

* added relu layer

* add src, tgt, memory mask

---------

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
2023-12-07 07:37:36 -08:00
Jagrit Digani
2440fe0124
NPY loading segfault bug (#34)
* Fixed GIL semantics in loading and saving from Python file streams
2023-12-06 12:03:47 -08:00
Markus Enzweiler
2ffaee0c0d
Updated default argument for stride to 1 in Conv2d() in the docstring (#22) 2023-12-06 07:17:58 -08:00
Awni Hannun
db487e6b1a format 2023-11-30 11:50:50 -08:00
Awni Hannun
46a39e5b1f copyright + ack 2023-11-30 11:12:53 -08:00
Jagrit Digani
e6306cfee9 jagrit's commit files 2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2 angelos's commit files 2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9 awni's commit files 2023-11-29 10:30:41 -08:00