m0saan
a1c06b7d46
updated BN implementation to handle input shape as NLC and NHWC^^
2023-12-24 23:10:37 +01:00
m0saan
9bf68814a4
updated BN implementation to handle input shape as NLC and NHWC^^
2023-12-24 23:10:37 +01:00
__mo_san__
28009c9cdb
Update python/mlx/nn/layers/normalization.py
...
Update BatchNorm to support NLC and NHWC input formats
In our convolution operations, we follow the convention that the channels are the last dimension. This commit updates the BatchNorm implementation to support inputs where the channels are the last dimension (NLC or NHWC). This involves changing the dimensions we normalize over and the dimensions we expand our parameters over.
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
2023-12-24 23:10:37 +01:00
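The commit body above describes normalizing over every dimension except the trailing channel axis. A minimal numpy sketch of that channels-last logic (the function name is illustrative, not MLX's actual `BatchNorm` API):

```python
import numpy as np

def batch_norm_channels_last(x, eps=1e-5):
    # Channels are the last dimension (NC, NLC, or NHWC), so we
    # normalize over every axis except the last one.
    reduce_axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(4, 8, 8, 3)  # NHWC input
y = batch_norm_channels_last(x)
```

Because the reduction axes are computed from `x.ndim`, the same code covers 2-D (NC), 3-D (NLC), and 4-D (NHWC) inputs, which is the genericity the surrounding commits aim for.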
m0saan
cf5a5a4a01
updated the batch norm doc string ^^
2023-12-24 23:10:37 +01:00
__mo_san__
c68a472b83
Update python/mlx/nn/layers/normalization.py
...
Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
2023-12-24 23:10:37 +01:00
m0saan
019a85511c
improve batch norm code ^^
2023-12-24 23:10:37 +01:00
__mo_san__
b444a6a693
Update normalization.py
2023-12-24 23:10:37 +01:00
m0saan
a43b853194
refactored and updated batch norm tests ^^
2023-12-24 23:10:37 +01:00
__mo_san__
8b08f440d9
Update __init__.py
2023-12-24 23:10:34 +01:00
__mo_san__
02ce72d4cd
Update layers.rst
2023-12-24 23:10:10 +01:00
m0saan
82ca771e69
updated BN implementation to be more generic ^^
2023-12-24 23:10:10 +01:00
m0saan
7b0f8bda9c
updated docs and added examples to doc string ^^
2023-12-24 23:10:10 +01:00
m0saan
7ec3cadf98
added test cases for batch norm on 3D input & refactored code ^^
2023-12-24 23:10:10 +01:00
__mo_san__
eca773b62c
Update normalization.py
2023-12-24 23:10:10 +01:00
m0saan
a0b2a34e98
rebasing ...
2023-12-24 23:09:52 +01:00
m0saan
d4bf9a2976
calc running mean and var only when training
2023-12-24 23:08:45 +01:00
m0saan
c3c2fcf41d
update batch norm implementation -> fixed some bug and added support for 3D inputs
2023-12-24 23:08:45 +01:00
m0saan
e9fd1cf02d
update batch norm implementation
2023-12-24 23:08:45 +01:00
__mo_san__
ad53687ae7
Update normalization.py
2023-12-24 23:08:45 +01:00
m0saan
2b617b63bd
implemented batchnorm layer
2023-12-24 23:08:45 +01:00
Zach Schillaci
22fee5a383
Remove redundant assert in losses.py ( #281 )
2023-12-24 08:39:08 -08:00
Daniel Strobusch
7365d142a3
random.uniform must respect dtype, even if lower precision than "low" ( #280 )
...
Fix an edge case where random.uniform returned a float32 array even when a lower-precision dtype was requested, because adding the float32 "low" array promoted the result.
2023-12-24 07:04:43 -08:00
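The promotion pitfall this commit fixes can be sketched in numpy (the `uniform` helper is hypothetical, not MLX's signature): arithmetic with a higher-precision `low`/`high` array silently widens the output, so the result must be cast back to the requested dtype.

```python
import numpy as np

def uniform(low, high, shape, dtype=np.float16):
    # Scaling by float32 `low`/`high` arrays would promote the
    # result, so cast back to the dtype the caller asked for.
    low = np.asarray(low)
    high = np.asarray(high)
    out = low + (high - low) * np.random.random(shape)
    return out.astype(dtype)

u = uniform(np.float32(0.0), np.float32(1.0), (3,), dtype=np.float16)
```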
Awni Hannun
8b227fa9af
fix no metal build ( #276 )
2023-12-23 19:18:10 -08:00
Vidit Agarwal
8c3da54c7d
Fix failing test for log cosh loss ( #275 )
...
* fix assert statement in log_cosh_loss
* reformatted by pre-commit black
2023-12-23 16:26:46 -08:00
Vidit Agarwal
acf1721b98
Corrected the example of value_and_grad ( #274 )
...
* Corrected the example for mx.value_and_grad
* Reformat through pre-commit/black
2023-12-23 11:06:38 -08:00
Finn Voorhees
f91f450141
Fix argmax returns documentation ( #263 )
2023-12-22 20:33:17 -08:00
Ronan Collobert
cd3616a463
Revisit autorelease memory pools ( #260 )
...
* make general autorelease pool part of metal device
* make things simpler
* no metal backend support
* new_memory_pool -> new_scoped_memory_pool
2023-12-22 11:01:26 -08:00
Nicholas Santavas
d35fa1db41
Add Hinge, Huber and LogCosh losses ( #199 )
2023-12-22 10:28:10 -08:00
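Of the three losses added here, log-cosh is the one that needs care numerically: `cosh` overflows for large inputs. A hedged numpy sketch of the standard stable form (not necessarily the exact formulation in `losses.py`):

```python
import numpy as np

def log_cosh(x):
    # log(cosh(x)) rewritten as |x| + log1p(exp(-2|x|)) - log(2),
    # which avoids overflow in cosh for large |x|.
    a = np.abs(x)
    return a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0)
```

For large `x` this degrades gracefully to `|x| - log(2)`, the linear tail that makes log-cosh behave like a smoothed L1 loss.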
Justin Deschenaux
e8deca84e0
Add dropout2d ( #250 )
2023-12-22 08:02:29 -08:00
Angelos Katharopoulos
8385f93cea
Bumping the version ( #256 )
2023-12-21 18:33:14 -08:00
Awni Hannun
2118c3dbfa
fix ( #255 )
2023-12-21 18:18:41 -08:00
Awni Hannun
a002797d52
A temporary fix ( #254 )
2023-12-21 17:59:15 -08:00
Angelos Katharopoulos
1d053e0d1d
Fix the alibi test that was left unchanged ( #252 )
2023-12-21 14:59:25 -08:00
Hazem Essam
0aa65c7a6b
Added ALiBi implementation ( #232 )
2023-12-21 14:36:38 -08:00
Daniel Strobusch
794feb83df
support arange for bfloat16 ( #245 )
2023-12-21 14:33:43 -08:00
Angelos Katharopoulos
2c7df6795e
Make sure that arrays are freed when saving ( #247 )
2023-12-21 14:08:24 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments ( #235 )
...
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8
Adds C++ and nn quantization utilities ( #230 )
...
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
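The two quantization commits above introduce `group_size` (default 64) and `bits` as the knobs. A numpy sketch of per-group affine quantization under those names (the idea, not MLX's actual kernels or exact packing):

```python
import numpy as np

def quantize(w, group_size=64, bits=4):
    # Each group of `group_size` consecutive weights shares one
    # scale and one bias (the group minimum).
    w = w.reshape(-1, group_size)
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / (2**bits - 1)
    q = np.round((w - w_min) / scale).astype(np.uint8)
    return q, scale, w_min

def dequantize(q, scale, w_min):
    return q * scale + w_min

w = np.random.randn(4, 128).astype(np.float32)
q, scale, bias = quantize(w)
w_hat = dequantize(q, scale, bias).reshape(w.shape)
```

With 4 bits the per-weight error is bounded by half a quantization step per group, which is why a larger `group_size` trades accuracy for memory.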
Justin Deschenaux
4912ff3ec2
Add Lion optimizer ( #209 )
...
* Add Lion optimizer
* Update acknowledgements also with past contributions
2023-12-20 13:54:58 -08:00
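Lion's defining trait is that the update direction is the sign of an interpolated momentum, decoupled from the momentum that tracks the raw gradient. A single-step numpy sketch of the published update rule (hyperparameter names are illustrative; see the optimizer module for MLX's actual interface):

```python
import numpy as np

def lion_step(w, g, m, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    # Update direction: sign of a beta1-interpolation of momentum
    # and the current gradient; weight decay is applied decoupled.
    update = np.sign(beta1 * m + (1 - beta1) * g)
    w = w - lr * (update + wd * w)
    # Momentum itself tracks the raw gradient with beta2.
    m = beta2 * m + (1 - beta2) * g
    return w, m

w2, m2 = lion_step(np.ones(3), np.ones(3), np.zeros(3))
```

Because the step magnitude is always `lr` per coordinate, Lion typically wants a smaller learning rate than Adam.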
Awni Hannun
f40d17047d
Indexing bug ( #233 )
...
* fix
* test
2023-12-20 10:44:01 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op ( #228 )
...
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
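The distinction floor_divide introduces matters for negative operands: flooring rounds toward negative infinity, unlike C-style truncation. A quick numpy illustration of the semantics (numpy here stands in for the new op):

```python
import numpy as np

# floor_divide rounds toward negative infinity, so -7 / 2
# floors to -4 rather than truncating to -3.
assert np.floor_divide(7, 2) == 3
assert np.floor_divide(-7, 2) == -4
# Python's integer // has the same flooring behavior.
assert 7 // 2 == 3 and -7 // 2 == -4
```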
davidkoski
de892cb66c
fix for non-macos build issue on cblas.h ( #227 )
2023-12-19 17:01:59 -08:00
davidkoski
37024d899c
fixes for building with swiftpm ( #225 )
...
- cblas is part of vecLib (compile failure)
- add SWIFTPM_BUNDLE #define to allow loading the metallib from a swiftpm resource bundle
2023-12-19 16:22:10 -08:00
Diogo
137f55bf28
fail early if readinto does not exist ( #221 )
2023-12-19 13:27:17 -08:00
Emircan Erol
e549f84532
Triplet Loss ( #211 )
...
* Triplet Loss
* Requested Changes
* Margin to alpha
2023-12-19 12:37:12 -08:00
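Per the "Margin to alpha" note above, the margin parameter was renamed `alpha`. A numpy sketch of the standard triplet loss under that name (signature illustrative, not MLX's exact `losses` API):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, alpha=0.2):
    # Pull the positive at least `alpha` closer to the anchor
    # than the negative; clamp at zero once the margin is met.
    d_pos = np.linalg.norm(anchor - positive, axis=-1)
    d_neg = np.linalg.norm(anchor - negative, axis=-1)
    return np.maximum(d_pos - d_neg + alpha, 0.0)
```

When the negative is already far enough away the loss is exactly zero, so only margin-violating triplets contribute gradient.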
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation ( #205 )
...
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00

Abe Leininger
e6872a4149
Added linspace ( #181 )
...
* linspace ops support
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Juarez Bochi
f4f6e17d45
Fix cross-attention ( #210 )
...
* Fix cross-attention
With the current code, ln2 is a no-op. Its output should be passed to the cross-attention layer.
* Add name to contributors
2023-12-18 12:27:27 -08:00
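The bug described above, where a layer norm's output was computed but never used, can be sketched with a hypothetical pre-norm decoder block (all layer names here are stand-ins, not the actual MLX code):

```python
def decoder_block(x, memory, ln1, ln2, ln3, self_attn, cross_attn, mlp):
    x = x + self_attn(ln1(x))
    # Bug: x = x + cross_attn(x, memory)  # ln2's output discarded
    x = x + cross_attn(ln2(x), memory)    # fix: feed ln2's output
    x = x + mlp(ln3(x))
    return x

# Stub check: with identity/zero stubs, ln2's +1 must show up in
# the result, proving it is no longer a no-op.
ident = lambda t: t
out = decoder_block(
    1.0, None,
    ident, lambda t: t + 1.0, ident,
    lambda t: 0.0, lambda t, m: t, lambda t: 0.0,
)
```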
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive ( #203 )
2023-12-18 11:32:48 -08:00
Awni Hannun
477397bc98
Citation + Contributor acknowledgment section ( #207 )
...
* cite
* nits
* nits
* comment
2023-12-18 10:07:00 -08:00