Commit Graph

588 Commits

Author SHA1 Message Date
Awni Hannun
0e5807bbcb
include optional (#202) 2023-12-17 22:01:35 -08:00
Cyril Zakka, MD
8eb56beb3a
Added clip function (#159)
* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-17 20:00:29 -08:00
Awni Hannun
ee0c2835c5
Docs updates (#198)
Reorganize NN docs + a few other tidbits.
2023-12-17 13:20:55 -08:00
Awni Hannun
90d04072b7
fix build w/ flatten (#195) 2023-12-17 11:58:45 -08:00
__mo_san__
52e1589a52
implemented Flatten Module (#149)
* implemented flatten op

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 21:54:37 -08:00
YUN, Junwoo
eebd7c275d
Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes (#142)
* Add AdaMax, AdaDelta, RMSprop
2023-12-16 21:43:15 -08:00
Austin Liu
a67bbfe745
Update docs (#177) (#190)
* update docs (fix #177)

* reorder

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-16 06:52:18 -08:00
Awni Hannun
104c34f906
setite negative indexing bug (#189) 2023-12-16 06:44:47 -08:00
Diogo
dc2edc762c
added tri / tril / triu (#170)
* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-15 17:30:34 -08:00
Awni Hannun
2e02acdc83
add base kwarg to rope (#186) 2023-12-15 16:47:59 -08:00
Ronan Collobert
83f266c44c
Lazy metal_device_ initialization (#185)
This ensures it is defined when the Scheduler needs it.
2023-12-15 16:06:46 -08:00
Víctor Aguilar
f24200db2c
accross -> across (#183) 2023-12-15 13:46:50 -08:00
Jason
e28b57e371
Added mx.stack c++ frontend impl (#123)
* stack C++ operation + python bindings
2023-12-14 13:21:19 -08:00
Awni Hannun
e5851e52b1
Add move and swap axis, and vmap for slice, concat, and gather (#158)
* add move and swap axis, and vmap for slice, concat, and gather
2023-12-14 12:59:12 -08:00
Diogo
f55908bc48
Added stubs for python files generated from C++ (#136)
* added pybind11-stubgen

* docs for generating stubs

* added line to readme
2023-12-14 12:58:45 -08:00
Luca Arnaboldi
b93c4cf378
Floor and Ceil (#150)
* Implements Floor and Ceil Ops
2023-12-14 10:00:23 -08:00
Stv.X
1e0c78b970
Fixed typo in some proprietary terms. (#161) 2023-12-13 19:48:00 -08:00
Awni Hannun
76e1af0e02
bump version (#157) 2023-12-13 14:28:26 -08:00
Ikko Eltociear Ashimine
c3272d4917
Update conv.cpp (#145)
Peform -> Perform
2023-12-12 11:27:49 -08:00
SputNikPlop
50f5d14b11
fix: tidy pull request template (#143)
* fix: tidy pull request template

* fix: feedback from awni
2023-12-12 08:14:39 -08:00
noahsmartin
d14a0e4ff9
Docs update (#144) 2023-12-12 07:53:42 -08:00
Diogo
fb675de30d
Run lint check for prs (#139) 2023-12-12 00:23:33 -08:00
Awni Hannun
25f70d4ca4
Fix divide types + floor divide (//) (#138)
* divide types

* fix black + test
2023-12-11 20:20:58 -08:00
Diogo
02de234ef0
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arvix reference to mish

* added missing docs
2023-12-11 19:40:57 -08:00
Nicholas Santavas
f5df47ec6e
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
2023-12-11 17:04:07 -08:00
Awni Hannun
b9226c367c
Fix CI format + build issue (#137)
* fix ci

* Fix python bindings build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2023-12-11 15:01:41 -08:00
Angelos Katharopoulos
3214629601
Mlx array accessor (#128)
* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
2023-12-11 13:42:55 -08:00
__mo_san__
072044e28f
fix and update binary cross entropy loss tests (#133)
* fix conflicts

* updated tests
2023-12-11 12:42:17 -08:00
Cyril Zakka, MD
e080290ba4
Added eye/identity ops (#119)
`eye` and `identity` C++ and Python ops
2023-12-11 12:38:17 -08:00
Awni Hannun
69505b4e9b
fixes (#131) 2023-12-11 09:26:49 -08:00
__mo_san__
f4ddd7dc44
Add Binary Cross Entropy loss (#122)
* update BCE added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
2023-12-11 07:55:18 -08:00
Jason
b0cd092b7f
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
2023-12-10 16:31:38 -08:00
Awni Hannun
71d1fff90a
Bug fix in metal binary kernel dispatch for large arrays (#125)
* bug fix

* format
2023-12-10 16:12:31 -08:00
Yiyang(Steven) Yu
0cfbfc9904
Update README.md (#121) 2023-12-10 14:47:37 -08:00
Awni Hannun
2d0130f80f
fix loss tests (#118)
* fix loss tests

* use none as default
2023-12-10 10:08:19 -08:00
__mo_san__
c1e1c1443f
Added Adagrad optimizer (#102) 2023-12-10 09:22:39 -08:00
Henry Ansah
68bf1d7867
add nn module for sigmoid activation (#111)
* add nn module for sigmoid activation

* update .gitignore with .cache folder generated by jetbrains fleet ide

* remove .cache folder
2023-12-10 07:00:39 -08:00
Angelos Katharopoulos
600db7d754
Fix build on Xcode 14 (#116)
* Fix build on Xcode 14

* Style fixes
2023-12-10 06:58:52 -08:00
__mo_san__
ef7b8756c0
Add tanh activation function (#115)
* added Adagrad optimizer ...

* added Tanh activation function ...

* reformatted file ...

* remove unrelated stuff ...

* Update activations.py
2023-12-09 19:25:38 -08:00
Enoch Kan
0b28399638
added mse_loss, nll_loss and kl_div_loss (#98)
* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
2023-12-09 14:25:03 -08:00
Joe Barrow
ac6dc5d3eb
Adding optional bias param to MultiHeadAttention (#104)
* Adding optional  param to

* Run style-checker
2023-12-09 11:04:28 -08:00
Awni Hannun
89b90dcfec
Pr template (#99)
* pr template
* format fix
2023-12-09 09:36:56 -08:00
Angelos Katharopoulos
fd836d891b
Hashable dtype and mlx.core prefixed repr (#89)
* Make dtype hashable
* Add mlx.core prefix to our dtypes' repr
* Update the dtype test
2023-12-09 09:35:28 -08:00
AtomicVar
976e8babbe
Use compiled black as the pre-commit formatter (#94) 2023-12-09 07:06:46 -08:00
Awni Hannun
2520dbcf0a
add losses to the docs, fix black failur (#92) 2023-12-09 06:06:52 -08:00
Abe Leininger
430bfb4944
Adds Nesterov momentum to SGD (#87) 2023-12-08 23:23:36 -08:00
ShiJZ
08d51bf232
Make it easier to test new optimizers implemented: no need to change test file manually (#90)
* add helper function get_all_optimizers() in test_optimizers.py

* remove unused import
2023-12-08 21:39:08 -08:00
Kai Ma
cb9e585b8e
Style fix for loss functions (#91)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none

* style
2023-12-08 21:11:56 -08:00
Kai Ma
641d316484
MLE and L1 loss functions (#88)
* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
2023-12-08 20:21:37 -08:00
Angelos Katharopoulos
2b714714e1
Add the remainder op (#85)
* Add remainder in the C++ backend
* Add the python binding and test
2023-12-08 15:08:52 -08:00