Commit Graph

1356 Commits

Author SHA1 Message Date
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef version bump (#1154) v0.14.0 2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032 Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67 Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6 More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336 Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863 list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2 add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2 Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7 JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36 Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1 Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be Treate 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf Update README.md (#1136) 2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380 Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
Awni Hannun
6a9b584f3d patch bump (#1131) v0.13.1 2024-05-16 20:51:33 -07:00
Awni Hannun
81dd33af66 allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Awni Hannun
8b76571896 Fix extensions (#1126)
* fix extensions

* title

* enable circle

* fix nanobind tag

* fix bug in doc

* try to fix config

* typo
2024-05-16 15:36:25 -07:00
Angelos Katharopoulos
e78a6518fa Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01 Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jacket
c417e42116 [Fix] minor typo in default argument for argpartition's "axis" parameter (#1125)
According to the document, argpartition's axis parameter can be None, but due to a previous typo it can't really accepts a None value.
2024-05-15 15:25:25 -07:00
Jagrit Digani
358e1fd6ab Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
631dfbe673 fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Cheng
56a4eaed72 Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7 Move args in conv_general (#1118)
Also fix a typo that padding_lo is passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6 Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
5be5daa6ef Use compiled function in Sigmoid module (#1116) 2024-05-14 06:25:57 -07:00
Cheng
60cb11764e Use correct module type in quantized.py (#1115) 2024-05-14 06:25:42 -07:00
Cheng
cbd5445ea7 The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Cheng
2c7e9b5158 Add missing docs for some ops (#1110) 2024-05-14 06:09:05 -07:00
Mike Drob
2263e4b279 Experiment with medium machines for CI (#1000) 2024-05-13 19:40:19 -07:00
Awni Hannun
863039da4c Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111 No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Ravindra R. Jaju
e7f9710499 Fix typo in a variable name in example code. (#1104)
* Fix typo in a variable name in example code.

* Rename df2dx2 to d2fdx2 - the appropriate naming for the second derivative

* Update CONTRIBUTING.md - add needed python packages, and a virtual-env hint

* Revert "Fix typo in a variable name in example code."

This reverts commit bc10a17534.

* Rename df2dx2 to d2fdx2
2024-05-13 06:04:23 -07:00
Max-Heinrich Laves
ff4223904d Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6 improve error messaging in eval (#1101) 2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0 Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
8bd6bfa4b5 version (#1099) v0.13.0 2024-05-09 17:52:39 -07:00
Awni Hannun
8b1906abd0 Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605 Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Awni Hannun
b21242faf1 Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4 Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66 Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
9814a2ae12 fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00