Awni Hannun
c6739ba7f3
Faster RNN layers ( #1419 )
...
* faster rnn
* use admm
2024-09-17 06:04:19 -07:00
Angelos Katharopoulos
914409fef9
Data parallel helper ( #1407 )
2024-09-16 18:17:21 -07:00
Awni Hannun
d5ed4d7a71
override class function ( #1418 )
2024-09-16 13:21:04 -07:00
c0g
bd8396fad8
Fix typo in transformer docs ( #1414 )
2024-09-14 06:05:15 -07:00
Awni Hannun
8b30acd7eb
fix module attribute set, reset, set ( #1403 )
2024-09-11 16:30:42 -07:00
Max-Heinrich Laves
efeb9c0f02
Transposed Convolution ( #1245 )
...
* initial implementation for conv_transpose
ran pre-commit
implemented conv_transpose
updated conv_general docstring
updated conv_general docstring
updated code comments
removed commented run_conv_checks
updated acknowledgments
added missing entry to ops.rst
added op to nn.layers
resolved merge conflicts
* removed ConvolutionTranspose primitive as suggested by reviewer
removed ConvolutionTranspose primitive as suggested by reviewer
* remove transpose flag, add another test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-09-06 19:52:38 -07:00
Saanidhya
4e22a1dffe
In continuation to PR1243 to solve issue #1240 ( #1365 )
...
* Solves issue #1240
* Correction
* Update python/mlx/utils.py
* Update python/mlx/utils.py
---------
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-08-28 11:40:41 -07:00
Awni Hannun
291cf40aca
Some fixes to typing ( #1371 )
...
* some fixes to typing
* fix module reference
* comment
2024-08-28 11:16:19 -07:00
Awni Hannun
ae5b5cabfd
Fix optimizer reloading from checkpoint ( #1329 )
...
* fix optimizer reloading from checkpoint
* comment
2024-08-15 07:33:23 -07:00
Awni Hannun
63ae767232
fix transformer ( #1327 )
2024-08-13 16:04:26 -07:00
Bhargav Yagnik
a098bc92e0
Fix: Preserve input dtype in Dropout layer output ( #1323 )
...
* Fix: Preserve input dtype in Dropout layer output
- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321
* Update test cases in test_nn.py
- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,
* updated dropout.py
2024-08-13 11:54:21 -07:00
Alex Barron
635ccd9e25
Add "edge" mode to mx.pad ( #1309 )
...
* Add edge padding mode
* fix pad in pooling
* string arg instead of enum
2024-08-06 11:23:10 -07:00
Awni Hannun
6c8dd307eb
faster group norm ( #1304 )
2024-08-01 12:49:23 -07:00
Atakan Tekparmak
6e06e3a904
feat: Added "tanh" option to GELU approximation ( #1268 )
2024-07-28 09:07:56 +02:00
Paul Paczuski
ebd7135b50
Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 ( #1280 )
...
* Improve stability of BCE loss calculation
* Standardize comment
* Apply formatting with black via pre-commit
* Add usage recommendation to docstring
* Update python/mlx/nn/losses.py
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
toji
6768c6a54a
Adding missing type hints ( #1243 )
...
* added type hints for `run`, `tree_map` and `tree_map_with_path`
* fix lint
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Awni Hannun
8c01a7893b
minor fix in optimizer + docs ( #1264 )
2024-07-12 12:18:02 -07:00
Awni Hannun
20bb301195
CPU binary reduction + Nits ( #1242 )
...
* very minor nits
* reduce binary
* fix test
2024-06-28 13:50:42 -07:00
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh ( #1180 )
...
---------
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Dominik Schlösser
3576b547c5
Doc error for default for scale in SinusoidalPositionalEncoding ( #1174 )
2024-06-02 13:42:45 -07:00
Awni Hannun
e6fecbb3e1
Some fixes in docs ( #1141 )
...
* fixes in docs
* nit
2024-05-20 11:51:47 -07:00
jlwitthuhn
7e5674d8be
Treate 'minimum' differently in cosine decay ( #1138 )
2024-05-20 08:00:48 -07:00
Angelos Katharopoulos
e78a6518fa
Block sparse qmm ( #1124 )
2024-05-16 15:24:14 -07:00
Cheng
5be5daa6ef
Use compiled function in Sigmoid module ( #1116 )
2024-05-14 06:25:57 -07:00
Cheng
60cb11764e
Use correct module type in quantized.py ( #1115 )
2024-05-14 06:25:42 -07:00
Max-Heinrich Laves
ff4223904d
Conv3d ( #993 )
...
* added conv3d
added conv3d
implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D
* incorporated reviewer comments
* fixed test
* reduced tensor shapes in test for conv3d
* Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Reviewer suggestion
2024-05-11 06:15:02 -07:00
Nripesh Niketan
79c859e2e0
feat: implement clip_grad_norm
( #1043 )
...
* feat: implement `clip_grad_norm`
* pre-commit
* Add test for clip_grad_norm function in test_optimizers.py
* small fixes
* fix
* lint
* Update tree_reduce
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Refactor clip_grad_norm function to include documentation and improve readability
* format docstring
* Add acknowlegements
* text wrap
* pre-commit
* nits in docs
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Piotr Rybiec
581b699ac9
avgpool, not maxpool ( #1002 )
2024-04-17 08:26:22 -07:00
Shiyu
107ba2891a
gelu tanh approx ( #989 )
...
* gelu tanh approx
* gelu tanh approx
* replace gelu approx with tanh approach
* fix comments
* fix comment
2024-04-15 19:49:00 -07:00
Awni Hannun
cd9e184529
Quantize embedding ( #994 )
...
* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test
2024-04-15 16:42:10 -07:00
Shiyu
061cf9a4ce
Upsample with bicubic interpolation ( #967 )
2024-04-10 15:47:22 -07:00
Awni Hannun
741eb28443
fix a couple bugs ( #952 )
2024-04-02 12:07:41 -07:00
AmirHossein_Razlighi
d611251502
Support Chaining for some of functionalities of nn.Module
( #885 ) ( #897 )
...
* add chaining support for some of the functionalities of "nn.Module"
* reformat
* change the return types
* remove return types
* add return type with forward referencing
* add tests for chaining
* add name to contributors
* Update python/mlx/nn/layers/base.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/nn/layers/base.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* update docstring
* update docstrings
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-27 19:58:29 -07:00
Awni Hannun
570f2bf29e
pick up preivously set attributes ( #905 )
2024-03-26 11:19:59 -07:00
Daniel Strobusch
479051ce1c
add numeric type hierarchy and issubdtype as well as a set_dtype meth… ( #427 )
...
* add numeric type hierarchy and issubdtype as well as a set_dtype method to nn.Module with predicate
numeric type hierarchy and issubtype is compatible to the [numpy hierarchy](220f0ab2c5/numpy/_core/numerictypes.py (L42)
).
Closes #285 .
* nits in docs
* unify type category checking
* nits in docs
* nits in docs
* more docs nits
* fix callable type
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-25 12:32:59 -07:00
Awni Hannun
1e16331d9c
post nanobind docs fixes and some updates ( #889 )
...
* post nanobind docs fixes and some updates
* one more doc nit
* fix for stubs and latex
2024-03-24 15:03:27 -07:00
Awni Hannun
be98f4ab6b
Reduce a little overhead ( #871 )
...
* some small overhead improvements
* use result_type in rms_norm
* remove release force
* fix + use non-vector version
* revert compile change
* fix ops
* a little more overhead
* a little more cleanup and overhead
2024-03-22 17:29:36 -07:00
Angelos Katharopoulos
2225374060
Adds mx.fast.layer_norm ( #870 )
2024-03-21 13:55:51 -07:00
Angelos Katharopoulos
53e6a9367c
Use reshape and transpose for non-overlapping pooling windows ( #867 )
2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8
Add minimum for cosine decay function ( #859 )
...
* Add minimum for cosine decay function
* Update python/mlx/optimizers/schedulers.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f
Fast RMS Norm ( #862 )
...
* fast rmsnorm
* no rms gpu
* kernel
* fix shared mem
* looped rms and donation in softmax
* Make the squaring in float32 to avoid underflow
* Fix the default StreamOrDevice for rope and rms_norm in fast
* nits
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Awni Hannun
16546c70d8
No reshape rope ( #838 )
...
* no reshape rope
* no reshape rope
2024-03-18 17:03:07 -07:00
Awni Hannun
366478c560
fix modules with dict ( #819 )
2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a
Implement RNN, GRU, LSTM ( #268 )
...
* RNN base implementation
* Address comments+format
* nits in docs
* add tests for prb
* fix test
* add a couple tests
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
28301807c2
Version bump and os error ( #807 )
2024-03-07 13:57:58 -08:00
Awni Hannun
cbcf44a4ca
Some fixes in cache / thread safety ( #777 )
...
* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* otpimizer docs
* fix adafactor
* fix adafactor
2024-03-05 13:30:50 -08:00
Piotr Rybiec
6a665ea6ed
Dilation for convolutional layers ( #766 )
...
* add dilation parameter to Conv1d layer
* space here too
* add conv1d dilation test
* add dilation parameter for Conv2d layer
* conv2d dilation test
2024-03-04 06:43:00 -08:00
Awni Hannun
bc06cb9ff6
Pickle + dtype fix for numpy conversion ( #763 )
...
* pickle + dtype fix for numpy conversion
* fix getattribute on Module base
* remove unused function
* fix tests
* add topk to ops
* fix doc
2024-03-02 06:09:29 -08:00
Awni Hannun
4494970f47
avoid nested closures in module ( #759 )
2024-02-29 09:39:52 -08:00
Awni Hannun
420ff2f331
Add back compiled function signatures and docstrings ( #749 )
...
* try to add back compiled function signatures and docstrings
* add indentation to docstring
2024-02-27 13:18:59 -08:00
Noah Kasmanoff
de3d2467a3
Update: Fast GeLU Approximation ( #744 )
...
* add: fast gelu approx
* fix docs
* Update gelu_fast_approx function documentation
* Update python/mlx/nn/layers/activations.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* fix: test gelu
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-26 21:08:50 -08:00
Awni Hannun
fe1dabf272
Fix compile with non standard types ( #745 )
...
* refactor tree utils
* fix compile + tree code refactor
* Add an extra test
* add a few missing activations to docs
* hash structure
* Encode the full argument structure
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-26 19:28:53 -08:00
Chime Ogbuji
3b661b7394
Add linear warmup and schedule joining for use with existing schedules ( #721 )
...
* Add linear warmup to schedules for use with existing schedules
* Changed parameters for simplicity of most common case (0 initial value)
* Added ScheduleJoiner and updated documentation
* ScheduleJoiner -> join_schedules (ala optax #)
* black compliance
* Different evaluation of schedules
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-26 07:28:48 -08:00
Gabrijel Boduljak
22364c40b7
Upsample2d ( #414 )
...
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-23 09:55:04 -08:00
Awni Hannun
5798256fcf
Shapeless compilation for some graphs ( #687 )
...
* shapeless compilation for some graphs
* update compile benchmark
* default compile a few activations
* buffer donation
* bugfix
* shapeless fix
* update tests to work for cpu and gpu fusion
* test kwargs
* add kwargs to compile
* Recompile when python arguments change
* no compile for tanh
* some constant tests
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-19 21:43:54 -08:00
Srimukh Sripada
818cda16bc
Support LR schedulers ( #334 )
...
* Add a few LR schedulers
* Move parents's constructor call to the top
* Fix docstring
* refactor optimizers into two files
* add docs
* nit
* Fix Callable type annotation for python 3.8
---------
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
Diogo
35431a4ac8
Adds device context manager ( #679 )
2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995
Custom primitive + RoPE fat op ( #676 )
...
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Gabrijel Boduljak
e54cbb7ba6
Pooling layers ( #357 )
...
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
LeonEricsson
7dccd42133
updated calls to use loc &scale ( #643 )
2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b
Compile with capture ( #629 )
...
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef
fix sequential with empty modules at end ( #647 )
2024-02-07 13:22:27 -08:00
Aryan Gupta
ef73393a19
Feat: Add weights argument in BCE Loss and tests ( #620 )
2024-02-07 09:39:52 -08:00
AtomicVar
83f63f2184
Add Margin Ranking Loss ( #536 )
2024-02-02 10:57:31 -08:00
David Koski
601c6d6aa8
Fix for AdaDelta ( #603 )
...
- state was being read from parameter "s"
- but being stored in parameter "u"
2024-02-01 09:56:27 -08:00
Angelos Katharopoulos
ba8d6bf365
Change the transformer to norm_first by default ( #599 )
2024-01-31 12:55:30 -08:00
Sugato Ray
4a5f3b21bb
Add py.typed
to support PEP-561 (type-hinting) for mlx
( #588 )
...
* Add `py.typed` to support PEP-561 (type-hinting)
This adds support for type-hinting information as laid in [PEP-561](https://peps.python.org/pep-0561/ ).
* add py.typed to MANIFEST.in
2024-01-31 12:05:42 -08:00
nathan
bad67fec37
Added TeX line breaks to mlx.optimizers.Lion docstring ( #595 )
...
Fixes the "misplaced &" MathJax error in documentation.
2024-01-30 19:37:34 -08:00
Angelos Katharopoulos
0de5988f92
Custom VJP and checkpointing ( #541 )
...
* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
2024-01-30 16:04:45 -08:00
Jacket
143e2690d5
Fix SGD implementation ( #473 )
2024-01-30 15:50:46 -08:00
Awni Hannun
09b9275027
Make shape a tuple ( #591 )
...
* shape tuple
* also remove simplify from docs
* rebase
2024-01-30 13:11:01 -08:00
Andre Slavescu
d3a9005454
Softshrink mapping + op ( #552 )
...
* Added Softshrink mapping + op
* formatting
* docs + nits in docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-30 12:56:28 -08:00
Jagrit Digani
bf17ab5002
Add more checks and clearer error messages to conv operations ( #563 )
...
* Add more checks and clearer error messages to conv operations
2024-01-26 15:13:26 -08:00
David Koski
874b739f3c
Fix cache key in RoPE ( #561 )
2024-01-26 13:10:02 -08:00
Hazem Essam
37fc9db82c
Added Adafactor ( #415 )
...
* Added adafactor
* Added Adafactor and ran pre-commit
* modified operations
* Added docstrings
* Switched two ops to fix a bug
* added underscore for internal functions and removed the plus sign in the last return statment
* Removed parameter rms from the optimizer state because its not needed
* Added simple MNIST test for Adafactor and temporary training log
* remove test files
* nits in docs
* comment nit
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 15:11:27 -08:00
AtomicVar
755dcf6137
Enable cross_entropy loss to handle dense targets ( #517 )
...
* Enable cross_entropy loss to handle dense targets
Dense targets means probabilities or one-hot encodings.
* better shape check of weights
* nits in docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 12:17:22 -08:00
LeonEricsson
6b4b30e3fc
Common neural network initializers nn.initializers
( #456 )
...
* initial commit: constant, normal, uniform
* identity, glorot and he initializers
* docstrings
* rm file
* nits
* nits
* nits
* testing suite
* docs
* nits in docs
* more docs
* remove unused template
* rename packakge to nn.innit
* docs, receptive field
* more docs
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-23 06:47:20 -08:00
Awni Hannun
d52383367a
format ( #510 )
2024-01-20 10:33:46 -08:00
Arda Orçun
363d3add6d
Add ValuError message for Adamax ( #508 )
...
* ValuError message added
* beta errors added
* some corrections and testing
* Learning rate limitation deleted
2024-01-20 07:56:15 -08:00
Anchen
f6feb61f92
feat: add support for saving safetensors in the save_weights
( #497 )
...
* feat: add save safetensors support in module save_weights
* chore: checking missing changes
* Update python/mlx/nn/layers/base.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* chore: update docstring for load_weights
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-01-19 06:19:33 -08:00
AtomicVar
550d4bf7c0
Update binary_cross_entropy function to handle both logits and probabilities ( #492 )
2024-01-18 19:22:23 -08:00
AtomicVar
d1fef34138
Add Gaussian NLL loss function ( #477 )
...
* Add Gaussian NLL loss function
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-18 06:44:44 -08:00
Jagrit Digani
78102a47ad
Update GEMM ( #424 )
...
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
2024-01-17 12:42:39 -08:00
Chunyang Wen
e3e933c6bc
Add type hint for Module ( #412 )
2024-01-10 11:23:42 -08:00
Awni Hannun
e9ca65c939
Fix BN stats to not expand shape ( #409 )
...
* fix BN stats to not expand shape
* nit
2024-01-09 11:54:51 -08:00
YUN, Junwoo
0b8aeddac6
Additoinal losses ( #336 )
...
* cosine similarity loss
---------
Co-authored-by: Awni Hannun <awni@apple.com>
* Docstring nits
2024-01-08 14:01:13 -08:00
Hazem Essam
022a944367
Added GLU activation function and Gated activation function ( #329 )
...
* Added GLU activation function and gated activation function
* Ran pre-commit
* Ran pre commit
* Removed old sigmoid implementation to match with main
* Removed gated activation from __init__.py
* Removed unused test cases
* Removed unused imports
* format / docstring
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-08 06:13:16 -08:00
Angelos Katharopoulos
6ea6b4258d
Fix style check ( #395 )
2024-01-07 05:54:58 -08:00
Anchen
48f6ca8c3a
Add theta cache for Rope and mask cache for ALiBi ( #375 )
2024-01-07 00:22:58 -08:00
Angelos Katharopoulos
75dc537e44
Fix the sigmoid module ( #371 )
2024-01-04 13:16:36 -08:00
Awni Hannun
cf88db44b5
revert copy ( #366 )
2024-01-04 10:43:29 -08:00
Chunyang Wen
16856a0160
Remove useless pass ( #364 )
...
Co-authored-by: Chunyang Wen <chunyang_wen@apple.com>
2024-01-04 06:34:01 -08:00
toji
d2467c320d
Added support for python copy ( #335 )
...
* Added support for python copy
* precommit changes
* removed `_compiled_call_impl` line
* added tests and suggested changes
* ACK changes
2024-01-03 20:59:40 -08:00
Angelos Katharopoulos
e7f5059fe4
Support for quantized matmul with w and w^T ( #349 )
...
* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
2024-01-03 14:22:36 -08:00
Gabrijel Boduljak
c7edafb729
implemented InstanceNorm ( #244 )
...
* implemented instancenorm
* implemented vector_norm in cpp
added linalg to mlx
* implemented vector_norm python binding
* renamed vector_norm to norm, implemented norm without provided ord
* completed the implementation of the norm
* added tests
* removed unused import in linalg.cpp
* updated python bindings
* added some tests for python bindings
* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
* added better docs and examples
* refactored mlx.linalg.norm bindings
* reused existing util for implementation of linalg.norm
* more tests
* fixed a bug with no ord and axis provided
* removed unused imports
* some style and API consistency updates to linalg norm
* remove unused includes
* fix python tests
* fixed a bug with frobenius norm of a complex-valued matrix
* complex for vector too
* addressed PR review comments
* fixed import order in __init__
* expected values in instancenorm tests are simple lists
* minor return expression style change
* added InstanceNorm to docs
* doc string nits
* added myself to individual contributors
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-01-03 12:21:15 -08:00
Awni Hannun
dff4a3833f
Module checks the weight on load_weights
( #337 )
...
* update module to check weights on load, also fix docs and reorganize tests
* nits + rebase
* a few more docs updates for Module
* use manual module file
* comment
2024-01-02 18:55:42 -08:00
Angelos Katharopoulos
436bec9fd9
Fix the implementation of the Bilinear layer ( #347 )
2024-01-02 16:46:18 -08:00
Asaf Zorea
295ce9db09
Feature expand nn linear ( #315 )
...
* Added an identity and bilinear layers
Added a reset_parameters option
Added normal init for bias
* pre-commit run
* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization
* Remove unnecessary reshape
* Added 'i' to bilinear formula
* Changed bilinear computation to two matrix multiplications
* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)
* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
2024-01-02 06:08:53 -08:00
Josh Soref
44c1ce5e6a
Spelling ( #342 )
...
* spelling: accumulates
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: across
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: additional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: against
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: among
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: array
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: available
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: axes
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: basically
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: bfloat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: bounds
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: broadcast
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: buffer
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: class
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: coefficients
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: collision
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: combinations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: committing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: computation
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: consider
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: constructing
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: conversions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: correctly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: corresponding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: declaration
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: default
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: dependency
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: destination
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: destructor
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: dimensions
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: divided
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: element-wise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: elements
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: endianness
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: equivalent
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: explicitly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: github
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: indices
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: irregularly
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: memory
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: metallib
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: negative
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: notable
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: optional
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: otherwise
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: overridden
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: partially
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: perform
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: perturbations
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: positively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: primitive
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: repeat
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: repeats
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: respect
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: respectively
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: result
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: rounding
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: separate
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: skipping
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: structure
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: the
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: transpose
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unnecessary
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unneeded
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
* spelling: unsupported
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
---------
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
2024-01-01 21:08:17 -08:00
Nripesh Niketan
e09bf35b28
feat: Add Dropout3d layer to nn.layers ( #313 )
...
* feat: Add Dropout3d layer to nn.layers
* acknowledgement
* Add dropout tests to test_nn.py
* run pre-commit
* Add activation functions and dropout3d ops
* Add dropout tests for bfloat16 and float16
2023-12-31 14:01:21 -08:00