LeonEricsson
7dccd42133
updated calls to use loc &scale ( #643 )
2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b
Compile with capture ( #629 )
...
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com >
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef
fix sequential with empty modules at end ( #647 )
2024-02-07 13:22:27 -08:00
Aryan Gupta
ef73393a19
Feat: Add weights argument in BCE Loss and tests ( #620 )
2024-02-07 09:39:52 -08:00
AtomicVar
83f63f2184
Add Margin Ranking Loss ( #536 )
2024-02-02 10:57:31 -08:00
David Koski
601c6d6aa8
Fix for AdaDelta ( #603 )
...
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365
Change the transformer to norm_first by default ( #599 )
2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb
Add py.typed
to support PEP-561 (type-hinting) for mlx
( #588 )
...
* Add `py.typed` to support PEP-561 (type-hinting)
This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/ ).
* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
nathan
bad67fec37
Added TeX line breaks to mlx.optimizers.Lion docstring ( #595 )
...
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing ( #541 )
...
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5
Fix SGD implementation ( #473 )
2024-01-30 15:50:46 -08:00
Awni Hannun
09b9275027
Make shape a tuple ( #591 )
...
* shape tuple
* also remove simplify from docs
* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454
Softshrink mapping + op ( #552 )
...
* Added Softshrink mapping + op
* formatting
* docs + nits in docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-30 12:56:28 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations ( #563 )
...
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
David Koski
874b739f3c
Fix cache key in RoPE ( #561 )
2024-01-26 13:10:02 -08:00
Hazem Essam
37fc9db82c
Added Adafactor ( #415 )
...
* Added adafactor
* Added Adafactor and ran pre-commit
* modified operations
* Added docstrings
* Switched two ops to fix a bug
* added underscore for internal functions and removed the plus sign in the last return statment
* Removed parameter rms from the optimizer state because its not needed
* Added simple MNIST test for Adafactor and temporary training log
* remove test files
* nits in docs
* comment nit
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137
Enable cross_entropy loss to handle dense targets ( #517 )
...
* Enable cross_entropy loss to handle dense targets
Dense targets means probabilities or one-hot encodings.
* better shape check of weights
* nits in docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc
Common neural network initializers nn.initializers
( #456 )
...
* initial commit: constant, normal, uniform
* identity, glorot and he initializers
* docstrings
* rm file
* nits
* nits
* nits
* testing suite
* docs
* nits in docs
* more docs
* remove unused template
* rename packakge to nn.innit
* docs, receptive field
* more docs
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-23 06:47:20 -08:00
Awni Hannun
d52383367a
format ( #510 )
2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d
Add ValuError message for Adamax ( #508 )
...
* ValuError message added
* beta errors added
* some corrections and testing
* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Anchen
f6feb61f92
feat: add support for saving safetensors in the save_weights
( #497 )
...
* feat: add save safetensors support in module save_weights
* chore: checking missing changes
* Update python/mlx/nn/layers/base.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
* chore: update docstring for load_weights
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com >
2024-01-19 06:19:33 -08:00
AtomicVar
550d4bf7c0
Update binary_cross_entropy function to handle both logits and probabilities ( #492 )
2024-01-18 19:22:23 -08:00
AtomicVar
d1fef34138
Add Gaussian NLL loss function ( #477 )
...
* Add Gaussian NLL loss function
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-18 06:44:44 -08:00
Jagrit Digani
78102a47ad
Update GEMM ( #424 )
...
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Chunyang Wen
e3e933c6bc
Add type hint for Module ( #412 )
2024-01-10 11:23:42 -08:00
Awni Hannun
e9ca65c939
Fix BN stats to not expand shape ( #409 )
...
* fix BN stats to not expand shape
* nit
2024-01-09 11:54:51 -08:00
YUN, Junwoo
0b8aeddac6
Additoinal losses ( #336 )
...
* cosine similarity loss
---------
Co-authored-by: Awni Hannun <awni@apple.com >
* Docstring nits
2024-01-08 14:01:13 -08:00
Hazem Essam
022a944367
Added GLU activation function and Gated activation function ( #329 )
...
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* format / docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-08 06:13:16 -08:00
Angelos Katharopoulos
6ea6b4258d
Fix style check ( #395 )
2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a
Add theta cache for Rope and mask cache for ALiBi ( #375 )
2024-01-07 00:22:58 -08:00
Angelos Katharopoulos
75dc537e44
Fix the sigmoid module ( #371 )
2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5
revert copy ( #366 )
2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160
Remove useless pass ( #364 )
...
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com >
2024-01-04 06:34:01 -08:00
toji
d2467c320d
Added support for python copy ( #335 )
...
* Added support for python copy
* precommit changes
* removed `_compiled_call_impl` line
* added tests and suggested changes
* ACK changes
2024-01-03 20:59:40 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T ( #349 )
...
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Gabrijel Boduljak
c7edafb729
implemented InstanceNorm ( #244 )
...
* implemented instancenorm
* implemented vector_norm in cpp
added linalg to mlx
* implemented vector_norm python binding
* renamed vector_norm to norm, implemented norm without provided ord
* completed the implementation of the norm
* added tests
* removed unused import in linalg.cpp
* updated python bindings
* added some tests for python bindings
* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
* added better docs and examples
* refactored mlx.linalg.norm bindings
* reused existing util for implementation of linalg.norm
* more tests
* fixed a bug with no ord and axis provided
* removed unused imports
* some style and API consistency updates to linalg norm
* remove unused includes
* fix python tests
* fixed a bug with frobenius norm of a complex-valued matrix
* complex for vector too
* addressed PR review comments
* fixed import order in __init__
* expected values in instancenorm tests are simple lists
* minor return expression style change
* added InstanceNorm to docs
* doc string nits
* added myself to individual contributors
---------
Co-authored-by: Awni Hannun <awni@apple.com >
2024-01-03 12:21:15 -08:00
Awni Hannun
dff4a3833f
Module checks the weight on load_weights
( #337 )
...
* update module to check weights on load, also fix docs and reorganize tests
* nits + rebase
* a few more docs updates for Module
* use manual module file
* comment
2024-01-02 18:55:42 -08:00
Angelos Katharopoulos
436bec9fd9
Fix the implementation of the Bilinear layer ( #347 )
2024-01-02 16:46:18 -08:00
Asaf Zorea
295ce9db09
Feature expand nn linear ( #315 )
...
* Added an identity and bilinear layers
Added a reset_parameters option
Added normal init for bias
* pre-commit run
* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization
* Remove unnecessary reshape
* Added 'i' to bilinear formula
* Changed bilinear computation to two matrix multiplications
* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)
* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00
Josh Soref
44c1ce5e6a
Spelling ( #342 )
...
* spelling: accumulates
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: across
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: additional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: against
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: among
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: array
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: available
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: axes
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: basically
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: bfloat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: bounds
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: broadcast
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: buffer
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: class
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: coefficients
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: collision
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: combinations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: committing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: computation
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: consider
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: constructing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: conversions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: correctly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: corresponding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: declaration
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: default
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: dependency
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: destination
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: destructor
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: dimensions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: divided
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: element-wise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: elements
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: endianness
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: equivalent
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: explicitly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: github
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: indices
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: irregularly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: memory
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: metallib
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: negative
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: notable
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: optional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: otherwise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: overridden
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: partially
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: perform
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: perturbations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: positively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: primitive
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: repeat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: repeats
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: respect
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: respectively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: result
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: rounding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: separate
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: skipping
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: structure
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: the
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: transpose
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unnecessary
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unneeded
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
* spelling: unsupported
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
---------
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com >
2024-01-01 21:08:17 -08:00
Nripesh Niketan
e09bf35b28
feat: Add Dropout3d layer to nn.layers ( #313 )
...
* feat: Add Dropout3d layer to nn.layers
* acknowledgement
* Add dropout tests to test_nn.py
* run pre-commit
* Add activation functions and dropout3d ops
* Add dropout tests for bfloat16 and float16
2023-12-31 14:01:21 -08:00
Hazem Essam
e3b8da2a49
Added implementation for Scaled RoPE. ( #261 )
...
* Added scale for RoPE
* Ran pre-commit
* Added RoPE scaling test
* Added docstring for scale parameter
* Modified docstrings
2023-12-31 06:06:01 -08:00
Nripesh Niketan
5ad8fb7268
feat: add softsign, softmax, hardswish, logsoftmax activation function ( #309 )
...
* feat: add softsign activation function
* run pre-commit
* Add Softsign activation function
* Add Softsign activation function
* Add documentation for ReLU6, Softplus, and Softsign activations
* Update activation functions in neural network layers
* Add LogSoftmax and Hardswish activations
* run pre-commit
* Update activations.py
* Added acknowledgements
* Fix activation function comments
* Fix activation functions in neural network layers
2023-12-29 11:49:36 -08:00
Chunyang Wen
2aedf3e791
Minor refactor for tree_map and tree_unflatten ( #311 )
...
* Minor refact for tree_map and tree_unflatten
* Remove the if statement
---------
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com >
2023-12-28 20:55:10 -08:00
Chunyang Wen
473b6b43b4
Use defaultdict ( #307 )
...
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com >
2023-12-28 14:46:13 -08:00
Angelos Katharopoulos
d29770eeaa
Update batchnorm to have the running stats in parameters ( #305 )
2023-12-28 14:31:10 -08:00
Chunyang Wen
040c3bafab
Add missing f str ( #306 )
...
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com >
2023-12-28 06:09:34 -08:00
Chunyang Wen
05767b026f
Add information for dropout probability ( #304 )
...
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com >
2023-12-27 21:51:30 -08:00
YUN, Junwoo
4417e37ede
Transformer fix ( #167 )
...
* add transformer with dropout, fix transformer ffm, layernorm order
* precommit changes
* precommit changes
* add docstring, activation, norm_first
* run precommit
* run precommit
* add doctstring
* precommit
* style nits in docs
---------
Co-authored-by: junwoo-yun <junwoo.yun@bagelcode.com >
Co-authored-by: Awni Hannun <awni@apple.com >
2023-12-27 08:48:36 -08:00
__mo_san__
a123c3c7d2
implement-batch-norm-layer ( #217 )
...
- Add batch normalization layer
---------
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com >
Co-authored-by: Awni Hannun <awni@apple.com >
2023-12-25 07:32:53 -08:00