Commit Graph

42 Commits

Author SHA1 Message Date
Luca Arnaboldi
b93c4cf378
Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Stv.X
1e0c78b970
Fixed typo in some proprietary terms. (#161) 2023-12-13 19:48:00 -08:00
Awni Hannun
25f70d4ca4
Fix divide types + floor divide (//) (#138)
* divide types

* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arvix reference to mish

* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
2023-12-11 17:04:07 -08:00
Awni Hannun
b9226c367c
Fix CI format + build issue (#137)
* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2023-12-11 15:01:41 -08:00
Angelos Katharopoulos
3214629601
Mlx array accessor (#128)
* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
2023-12-11 13:42:55 -08:00
__mo_san__
072044e28f
fix and update binary cross entropy loss tests (#133)
* fix conflicts

* updated tests
2023-12-11 12:42:17 -08:00
Cyril Zakka, MD
e080290ba4
Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
69505b4e9b
fixes (#131) 2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44
Add Binary Cross Entropy loss (#122)
* update BCE added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Awni Hannun
2d0130f80f
fix loss tests (#118)
* fix loss tests

* use none as default
2023-12-10 10:08:19 -08:00
__mo_san__
c1e1c1443f
Added Adagrad optimizer (#102) 2023-12-10 09:22:39 -08:00
Henry Ansah
68bf1d7867
add nn module for sigmoid activation (#111)
* add nn module for sigmoid activation

* update .gitignore with .cache folder generated by jetbrains fleet ide

* remove .cache folder
2023-12-10 07:00:39 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
__mo_san__
ef7b8756c0
Add tanh activation function (#115)
* added Adagrad optimizer ...

* added Tanh activation function ...

* reformatted file ...

* remove unrelated stuff ...

* Update activations.py
2023-12-09 19:25:38 -08:00
Enoch Kan
0b28399638
added mse_loss, nll_loss and kl_div_loss (#98)
* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Joe Barrow
ac6dc5d3eb
Adding optional bias param to MultiHeadAttention (#104)
* Adding optional  param to

* Run style-checker
2023-12-09 11:04:28 -08:00
Awni Hannun
89b90dcfec
Pr template (#99)
* pr template
* format fix
2023-12-09 09:36:56 -08:00
Angelos Katharopoulos
fd836d891b
Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
2023-12-09 09:35:28 -08:00
Awni Hannun
2520dbcf0a
add losses to the docs, fix black failur (#92) 2023-12-09 06:06:52 -08:00
Abe Leininger
430bfb4944
Adds Nesterov momentum to SGD (#87) 2023-12-08 23:23:36 -08:00
ShiJZ
08d51bf232
Make it easier to test new optimizers implemented: no need to change test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
2023-12-08 21:39:08 -08:00
Kai Ma
cb9e585b8e
Style fix for loss functions (#91)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none

* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484
MLE and L1 loss functions (#88)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
2023-12-08 20:21:37 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00
Joe Barrow
69a24e6a1e
AdamW implementation (#72)
* AdamW implementation without bias correction
* Makes use of the underlying Adam implementation
2023-12-08 14:45:34 -08:00
Zach Schillaci
5b9be57ac3
Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
Angelos Katharopoulos
209404239b
Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
2023-12-08 10:58:03 -08:00
Zach Schillaci
d11d77e581
Spelling fixes in transformer.py (#59) 2023-12-07 13:32:09 -08:00
rushyam
2e126aeb7e
Feature Addition: Encoder-Decoder Transformer Architecture (#50)
* Implemented decoder-transformer-layer, decoder-transformer  and introduce encoder-decoder transformer

* added relu layer

* add src, tgt, memory mask

---------

Co-authored-by: rushyam <rushyam@rushyams-MacBook-Air.local>
2023-12-07 07:37:36 -08:00
Jagrit Digani
2440fe0124
NPY loading segfault bug (#34)
* Fixed Gil semantics in loading and saving from python file streams
2023-12-06 12:03:47 -08:00
Markus Enzweiler
2ffaee0c0d
Updated default argument for stride to 1 in Conv2d() in the docstring (#22) 2023-12-06 07:17:58 -08:00
Jagrit Digani
d518b3b6a5
Fix gemv broadcasting bug (#6)
* Fix broadcasting bug in gemv
* Add relevant tests in test_blas.py
2023-12-05 14:15:43 -08:00
Angelos Katharopoulos
7546fdb100
Add CircleCI configuration (#4)
* Add CircleCI configuration
2023-12-04 16:04:11 -08:00
Awni Hannun
db487e6b1a format 2023-11-30 11:50:50 -08:00
Awni Hannun
46a39e5b1f copyright + ack 2023-11-30 11:12:53 -08:00
Jagrit Digani
e6306cfee9 jagrit's commit files 2023-11-29 10:52:08 -08:00
Angelos Katharopoulos
d1f86272a2 angelos's commit files 2023-11-29 10:42:59 -08:00
Awni Hannun
8ca7f9e8e9 awni's commit files 2023-11-29 10:30:41 -08:00