Fangjun Kuang
f20e97b092
minor fixes ( #1194 )
...
* minor fixes
* fix build errors
2024-06-12 22:06:49 -07:00
Alex Barron
934683088e
Refactor JIT for unary/binary/ternary ops ( #1206 )
...
* refactor unary/binary/ternary ops
* get_primitive_string util
---------
2024-06-12 14:22:12 -07:00
Awni Hannun
de2b9e7d0a
Fix kernel deps to reduce build times ( #1205 )
2024-06-12 11:17:39 -07:00
Alex Barron
dd7d8e5e29
Add Quantized Ops to the JIT ( #1204 )
...
* JIT for quantized ops
* remove unused imports
* address comments
* fix imports
* second attempt to fix imports
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb
fix scatter + test ( #1202 )
...
* fix scatter + test
* fix test warnings
* fix metal validation
2024-06-11 14:35:12 -07:00
Awni Hannun
709ccc6800
install mpi for release build ( #1199 )
2024-06-10 10:09:32 -07:00
Awni Hannun
cf236fc390
version ( #1191 )
2024-06-06 17:16:40 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT ( #1102 )
...
* feature complete metal fft
* fix contiguity bug
* jit fft
* simplify rader/bluestein constant computation
* remove kernel/utils.h dep
* remove bf16.h dep
* format
---------
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409
Add docstring for scatter ( #1189 )
...
* Add docstring for scatter
* docs nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a
Add docs for the distributed namespace ( #1184 )
2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c
fix jit scan when output doesn't have primitive ( #1190 )
2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d
Fix scan ( #1188 )
...
* fix scan
* improve grid size
* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893
Fix the hard-shrink test ( #1185 )
2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f
Add softmin, hardshrink, hardtanh ( #1180 )
...
---------
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d
Fix Metal API validation for empty concat ( #1183 )
2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc
Add some internal GPU apis ( #1177 )
...
* Add unary/binary/ternay/slice/concat internal GPU ops
* add pad internal op
* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4
Add view op ( #1179 )
...
* add view primitive
* nit
* fix view
2024-06-04 08:05:27 -07:00
nicolov
81def6ac76
Fix benchmark ( #1175 )
2024-06-04 07:50:46 -07:00
Angelos Katharopoulos
3de8ce3f3c
In place all-reduce and forgiving init ( #1178 )
2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24
Add defines include ( #1176 )
...
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30
Metal shaders for memory efficient self attention on large sequences ( #964 )
...
* Metal shaders for efficient self attention on large sequences
Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction
* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Dominik Schlösser
3576b547c5
Doc error for default for scale in SinusoidalPositionalEncoding ( #1174 )
2024-06-02 13:42:45 -07:00
Awni Hannun
079882495d
version bump ( #1172 )
2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db
feat: Added dlpack device ( #1165 )
...
* feat: Added dlpack device
* feat: Added device_id to dlpack device
* feat: Added device_id to dlpack device
* doc: updated conversion docs
* doc: updated numpy.rst dlpack information
* doc: updated numpy.rst dlpack information
* Update docs/src/usage/numpy.rst
* Update docs/src/usage/numpy.rst
---------
Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b
stable cumprod grad at 0 ( #1167 )
2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46
Fix multi-block sort stride management ( #1169 )
...
* Fix multi-block sort stride management
* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d
Fix matvec vector stride bug ( #1168 )
2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1
Fix a couple bugs ( #1161 )
...
* fix jit reduce for RMS norm
* make strides a single buffer
* better eval error message
* fix compiling with inf and bf16
* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1
fix broadcast bug in bitwise ops ( #1157 )
2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef
version bump ( #1154 )
2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv ( #1139 )
2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67
Float mask update ( #1152 )
...
* Float mask update
* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db
Comms ( #1097 )
...
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6
More jitting ( #1132 )
...
* docs + circle min size build
* jit scan, arange, softmax
* add sort
* jit reductions
* remove print
* fix deps
* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336
Add groups to 2-D convolutions ( #1129 )
...
* Added groups to 2-D convolutions. Only implemented for **some** specializations.
Also fixed 1D grouped convs with different kernel strides and added more tests.
* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863
list based indexing ( #1150 )
2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2
add mx.trace ( #1143 ) ( #1147 )
...
* working c++ trace implementation
* updated throw + added overloads
* added python binding for trace function
* pre-commit reformatting
* add trace to docs
* resolve comments
* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2
Fix offset bug for device buffers ( #1151 )
...
* fix bug with large offsets for buffers
* add a test
* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization ( #1091 )
...
* try cpp 20 for compile
* unary, binary, ternary in jit
* nits
* fix gather/scatter
* fix rebase
* reorg compile
* add ternary to compile
* jit copy
* jit compile flag
* fix build
* use linked function for ternary
* some nits
* docs + circle min size build
* docs + circle min size build
* fix extension
* fix no cpu build
* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36
Rename block sparse ( #1149 )
...
* block_sparse_mm to gather_mm
* rename
* nit
* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1
Some fixes in docs ( #1141 )
...
* fixes in docs
* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb
Improve qvm speed ( #1140 )
2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be
Treate 'minimum' differently in cosine decay ( #1138 )
2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf
Update README.md ( #1136 )
2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada
Fix copy bug with many dims ( #1137 )
2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e
Choose the right MLX bf16 for extensions ( #1135 )
...
* default to custom bf
* choose right bf
* fix extensions
* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380
Implemented Cholesky on CPU ( #1119 )
2024-05-17 12:31:59 -07:00
Awni Hannun
6a9b584f3d
patch bump ( #1131 )
2024-05-16 20:51:33 -07:00
Awni Hannun
81dd33af66
allow conversion to dlpack ( #1120 )
2024-05-16 16:11:37 -07:00
Awni Hannun
8b76571896
Fix extensions ( #1126 )
...
* fix extensions
* title
* enable circle
* fix nanobind tag
* fix bug in doc
* try to fix config
* typo
2024-05-16 15:36:25 -07:00