Awni Hannun
6b4f49fe1c
cleanup stats test
2023-12-25 07:27:14 -08:00
Awni Hannun
865e53fcab
doc nits
2023-12-25 07:07:24 -08:00
__mo_san__
15577cb727
Update __init__.py
2023-12-24 23:14:04 +01:00
m0saan
a1c06b7d46
updated BN implementation to handle input shape as NLC and NHWC^^
2023-12-24 23:10:37 +01:00
m0saan
9bf68814a4
updated BN implementation to handle input shape as NLC and NHWC^^
2023-12-24 23:10:37 +01:00
__mo_san__
28009c9cdb
Update python/mlx/nn/layers/normalization.py
...
Update BatchNorm to support NLC and NHWC input formats
In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over.
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
2023-12-24 23:10:37 +01:00
m0saan
cf5a5a4a01
updated the batch norm doc string ^^
2023-12-24 23:10:37 +01:00
__mo_san__
c68a472b83
Update python/mlx/nn/layers/normalization.py
...
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
2023-12-24 23:10:37 +01:00
m0saan
019a85511c
improve batch norm code ^^
2023-12-24 23:10:37 +01:00
__mo_san__
b444a6a693
Update normalization.py
2023-12-24 23:10:37 +01:00
m0saan
a43b853194
refactored and updated batch norm tests ^^
2023-12-24 23:10:37 +01:00
__mo_san__
8b08f440d9
Update __init__.py
2023-12-24 23:10:34 +01:00
m0saan
82ca771e69
updated BN implementation to be more generic ^^
2023-12-24 23:10:10 +01:00
m0saan
7b0f8bda9c
updated docs and added examples to doc string ^^
2023-12-24 23:10:10 +01:00
m0saan
7ec3cadf98
added test cases for batch norm on 3D input & refactored code ^^
2023-12-24 23:10:10 +01:00
__mo_san__
eca773b62c
Update normalization.py
2023-12-24 23:10:10 +01:00
m0saan
a0b2a34e98
rebasing ...
2023-12-24 23:09:52 +01:00
m0saan
d4bf9a2976
calc running mean and var only when training
2023-12-24 23:08:45 +01:00
m0saan
c3c2fcf41d
update batch norm implementation -> fixed some bugs and added support for 3D inputs
2023-12-24 23:08:45 +01:00
m0saan
e9fd1cf02d
update batch norm implementation
2023-12-24 23:08:45 +01:00
__mo_san__
ad53687ae7
Update normalization.py
2023-12-24 23:08:45 +01:00
m0saan
2b617b63bd
implemented batchnorm layer
2023-12-24 23:08:45 +01:00
Zach Schillaci
22fee5a383
Remove redundant assert in losses.py ( #281 )
2023-12-24 08:39:08 -08:00
Daniel Strobusch
7365d142a3
random.uniform must respect dtype, even if lower precision than "low" ( #280 )
...
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
2023-12-24 07:04:43 -08:00
Vidit Agarwal
8c3da54c7d
Fix failing test for log cosh loss ( #275 )
...
* fix assert statement in log_cosh_loss
* reformatted by pre-commit black
2023-12-23 16:26:46 -08:00
Vidit Agarwal
acf1721b98
Corrected the example of value_and_grad ( #274 )
...
* Corrected the example for mx.value_and_grad
* Reformat through pre-commit/black
2023-12-23 11:06:38 -08:00
Finn Voorhees
f91f450141
Fix argmax returns documentation ( #263 )
2023-12-22 20:33:17 -08:00
Nicholas Santavas
d35fa1db41
Add Hinge, Huber and LogCosh losses ( #199 )
2023-12-22 10:28:10 -08:00
Justin Deschenaux
e8deca84e0
Add dropout2d ( #250 )
2023-12-22 08:02:29 -08:00
Angelos Katharopoulos
1d053e0d1d
Fix the alibi test that was left unchanged ( #252 )
2023-12-21 14:59:25 -08:00
Hazem Essam
0aa65c7a6b
Added ALiBi implementation ( #232 )
2023-12-21 14:36:38 -08:00
Angelos Katharopoulos
2c7df6795e
Make sure that arrays are freed when saving ( #247 )
2023-12-21 14:08:24 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments ( #235 )
...
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8
Adds C++ and nn quantization utilities ( #230 )
...
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
Justin Deschenaux
4912ff3ec2
Add Lion optimizer ( #209 )
...
* Add Lion optimizer
* Update acknowledgements also with past contributions
2023-12-20 13:54:58 -08:00
Awni Hannun
f40d17047d
Indexing bug ( #233 )
...
* fix
* test
2023-12-20 10:44:01 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op ( #228 )
...
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
Diogo
137f55bf28
fail early if readinto does not exist ( #221 )
2023-12-19 13:27:17 -08:00
Emircan Erol
e549f84532
Triplet Loss ( #211 )
...
* Triplet Loss
* Requested Changes
* Margin to alpha
2023-12-19 12:37:12 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation ( #205 )
...
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Abe Leininger
e6872a4149
Added linspace ( #181 )
...
* linspace ops support
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Juarez Bochi
f4f6e17d45
Fix cross-attention ( #210 )
...
* Fix cross-attention
With the current code, ln2 is a no-op. Its output should be passed to the cross-attention layer
* Add name to contributors
2023-12-18 12:27:27 -08:00
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive ( #203 )
2023-12-18 11:32:48 -08:00
jojopuppet
18cca64c81
Add smoothed L1 loss and enhancements to cross entropy loss ( #166 )
...
* Add smooth_l1_loss
* Add label smoothing for cross entropy loss
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 07:26:21 -08:00
Cyril Zakka, MD
8eb56beb3a
Added clip function ( #159 )
...
* Added clip
* Added Python bindings
* Formatting
* Added cpp tests
* Added Python tests
* python bindings work
* rebase
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
ee0c2835c5
Docs updates ( #198 )
...
Reorganize NN docs + a few other tidbits.
2023-12-17 13:20:55 -08:00
Awni Hannun
90d04072b7
fix build w/ flatten ( #195 )
2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52
implemented Flatten Module ( #149 )
...
* implemented flatten op
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
YUN, Junwoo
eebd7c275d
Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes ( #142 )
...
* Add AdaMax, AdaDelta, RMSprop
2023-12-16 21:43:15 -08:00
Awni Hannun
104c34f906
setitem negative indexing bug ( #189 )
2023-12-16 06:44:47 -08:00