Daniel Strobusch
7365d142a3
random.uniform must respect dtype, even if lower precision than "low" ( #280 )
...
Fix an edge case where random uniform returns a float32 array, even if a lower precision dtype is wanted due to adding the float32 "low" array.
2023-12-24 07:04:43 -08:00
Vidit Agarwal
8c3da54c7d
Fix failing test for log cosh loss ( #275 )
...
* fix assert statement in log_cosh_loss
* reformatted by pre-commit black
2023-12-23 16:26:46 -08:00
Nicholas Santavas
d35fa1db41
Add Hinge, Huber and LogCosh losses ( #199 )
2023-12-22 10:28:10 -08:00
Angelos Katharopoulos
1d053e0d1d
Fix the alibi test that was left unchanged ( #252 )
2023-12-21 14:59:25 -08:00
Hazem Essam
0aa65c7a6b
Added ALiBi implementation ( #232 )
2023-12-21 14:36:38 -08:00
Angelos Katharopoulos
b3916cbf2b
Improve names of quantization arguments ( #235 )
...
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
2023-12-20 16:53:53 -08:00
Angelos Katharopoulos
57fe918cf8
Adds C++ and nn quantization utilities ( #230 )
...
* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
2023-12-20 14:17:38 -08:00
Awni Hannun
f40d17047d
Indexing bug ( #233 )
...
* fix
* test
2023-12-20 10:44:01 -08:00
Angelos Katharopoulos
2807c6aff0
Implements divide for integer types and adds floor_divide op ( #228 )
...
* Add floor_divide
* Add floor_divide to the tests
* Add floor_divide to the docs
2023-12-19 20:12:19 -08:00
Emircan Erol
e549f84532
Triplet Loss ( #211 )
...
* Triplet Loss
* Requested Changes
* Margin to alpha
2023-12-19 12:37:12 -08:00
Angelos Katharopoulos
dfa9f4bc58
An initial quantized matmul implementation ( #205 )
...
* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
2023-12-18 23:18:57 -08:00
Abe Leininger
e6872a4149
Added linspace ( #181 )
...
* linspace ops support
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 19:57:55 -08:00
Angelos Katharopoulos
4d4af12c6f
Adds round op and primitive ( #203 )
2023-12-18 11:32:48 -08:00
jojopuppet
18cca64c81
Add smoothed L1 loss and enhancements to cross entropy loss ( #166 )
...
* Add smooth_l1_loss
* Add labels moothing for cross entropy loss
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 07:26:21 -08:00
Cyril Zakka, MD
8eb56beb3a
Added clip function ( #159 )
...
* Added clip
* Added Python bindings
* Formatting
* Added cpp tests
* Added Python tests
* python bindings work
* rebase
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
90d04072b7
fix build w/ flatten ( #195 )
2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52
implemented Flatten Module ( #149 )
...
* implemented flatten op
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
Awni Hannun
104c34f906
setite negative indexing bug ( #189 )
2023-12-16 06:44:47 -08:00
Diogo
dc2edc762c
added tri / tril / triu ( #170 )
...
* added tri / tril / triu
* fixed tests
* ctest tests
* tri overload and simplified tests
* changes from comment
* more tests for m
* ensure assert if not 2-D
* remove broadcast_to
* minor tweaks
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-15 17:30:34 -08:00
Awni Hannun
2e02acdc83
add base kwarg to rope ( #186 )
2023-12-15 16:47:59 -08:00
Jason
e28b57e371
Added mx.stack c++ frontend impl ( #123 )
...
* stack C++ operation + python bindings
2023-12-14 13:21:19 -08:00
Awni Hannun
e5851e52b1
Add move and swap axis, and vmap for slice, concat, and gather ( #158 )
...
* add move and swap axis, and vmap for slice, concat, and gather
2023-12-14 12:59:12 -08:00
Luca Arnaboldi
b93c4cf378
Floor and Ceil ( #150 )
...
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Awni Hannun
25f70d4ca4
Fix divide types + floor divide (//) ( #138 )
...
* divide types
* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0
Activations LeakyReLU / PReLU / Softplus / Mish ( #109 )
...
* Leaky_relu / prelu / softplus / mish
* added tests
* updated bench
* remove torch refs, add init to PReLU
* added arvix reference to mish
* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e
Add Step, ELU, SELU, Swish activation functions ( #117 )
...
* Add Step, ELU, SELU, Swish activation functions
This commit adds the Step, ELU, SELU and Swish activations functions
* add to the docs
* review
2023-12-11 17:04:07 -08:00
Awni Hannun
b9226c367c
Fix CI format + build issue ( #137 )
...
* fix ci
* Fix python bindings build
---------
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2023-12-11 15:01:41 -08:00
__mo_san__
072044e28f
fix and update binary cross entropy loss tests ( #133 )
...
* fix conflicts
* updated tests
2023-12-11 12:42:17 -08:00
Cyril Zakka, MD
e080290ba4
Added eye/identity ops ( #119 )
...
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
69505b4e9b
fixes ( #131 )
2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44
Add Binary Cross Entropy loss ( #122 )
...
* update BCE added tests for it ...
* added binary cross entropy loss to docs
* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid ( #108 )
...
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays ( #125 )
...
* bug fix
* format
2023-12-10 16:12:31 -08:00
Awni Hannun
2d0130f80f
fix loss tests ( #118 )
...
* fix loss tests
* use none as default
2023-12-10 10:08:19 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 ( #116 )
...
* Fix build on Xcode 14
* Style fixes
2023-12-10 06:58:52 -08:00
Enoch Kan
0b28399638
added mse_loss, nll_loss and kl_div_loss ( #98 )
...
* added mse_loss, nll_loss and kl_div_loss
* fixed axis not defined error in nll_loss
* fixed axis not defined in kl_div_loss
* added tests for mse, nll and kl_div
* modified docstrings and added reduce helper func
* updated docstring in kl_div_loss and moved helper func
* added new kl divergence implementation
* added reduction to test
* updated docstring of kl_div_loss with correct spelling
* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Awni Hannun
89b90dcfec
Pr template ( #99 )
...
* pr template
* format fix
2023-12-09 09:36:56 -08:00
Angelos Katharopoulos
fd836d891b
Hashable dtype and mlx.core prefixed repr ( #89 )
...
* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
2023-12-09 09:35:28 -08:00
ShiJZ
08d51bf232
Make it easier to test new optimizers implemented: no need to change test file manually ( #90 )
...
* add helper function get_all_optimizers() in test_optimizers.py
* remove unused import
2023-12-08 21:39:08 -08:00
Kai Ma
cb9e585b8e
Style fix for loss functions ( #91 )
...
* MLE and L1 loss functions
* logsoftmax change and tests
* subtract max logit for numerical stability
* l1 name change
* cross entropy reduction + unit tests
* docstrings
* l1 test name change
* old loss impl + default none
* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484
MLE and L1 loss functions ( #88 )
...
* MLE and L1 loss functions
* logsoftmax change and tests
* subtract max logit for numerical stability
* l1 name change
* cross entropy reduction + unit tests
* docstrings
* l1 test name change
* old loss impl + default none
2023-12-08 20:21:37 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op ( #85 )
...
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Zach Schillaci
5b9be57ac3
Add isort pre-commit and run ( #68 )
2023-12-08 11:31:47 -08:00
Angelos Katharopoulos
209404239b
Fix the accelerate dispatch for the power op ( #70 )
...
- The exponent and base were swapped because accelerate is using
exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug ( #6 )
...
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Angelos Katharopoulos
7546fdb100
Add CircleCI configuration ( #4 )
...
* Add CircleCI configuration
2023-12-04 16:04:11 -08:00
Awni Hannun
46a39e5b1f
copyright + ack
2023-11-30 11:12:53 -08:00
Jagrit Digani
e6306cfee9
jagrit's commit files
2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2
angelos's commit files
2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9
awni's commit files
2023-11-29 10:30:41 -08:00