Commit Graph

20 Commits

Author SHA1 Message Date
Awni Hannun
de5f38fd48
Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Awni Hannun
b7c9f1d38f
scatter axis + gather axis primitives (#1813)
* scatter axis + gather axis primitives

* add transforms

* comment
2025-01-31 20:48:08 -08:00
Angelos Katharopoulos
1f4c127fb9
Move some kernels to get_template_definition (#1782) 2025-01-21 08:59:44 -08:00
Awni Hannun
40c62c1321
Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
2419edd5b2
Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
2024-11-18 19:52:00 -08:00
Awni Hannun
4f72c66911
improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Angelos Katharopoulos
c9b41d460f
Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
Awni Hannun
4f46e9c997
More fixes for arrays with large sizes (#1405)
* compile works for big arrays when contiguous

* style

* nits in docs

* a bunch more stuff

* update jit

* update jit

* use constant for shapes and strides and remove elem_to_loc overload

* use kernel instantiation

* docs nits

* update binary and ternary

* comments
2024-09-17 12:46:31 -07:00
Alex Barron
28be4de7c2
Fix JIT reductions (#1373) 2024-08-28 16:39:11 -07:00
Awni Hannun
df3233454d
2d gather specialization (#1339) 2024-08-22 10:48:24 -07:00
Awni Hannun
30bbea2f08
Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
a3c287354f
Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Alex Barron
2615660e62
Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Alex Barron
934683088e
Refactor JIT for unary/binary/ternary ops (#1206)
* refactor unary/binary/ternary ops

* get_primitive_string util

---------
2024-06-12 14:22:12 -07:00
Alex Barron
dd7d8e5e29
Add Quantized Ops to the JIT (#1204)
* JIT for quantized ops

* remove unused imports

* address comments

* fix imports

* second attempt to fix imports

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-12 09:47:12 -07:00
Awni Hannun
df964132fb
fix scatter + test (#1202)
* fix scatter + test

* fix test warnings

* fix metal validation
2024-06-11 14:35:12 -07:00
Alex Barron
27d70c7d9d
Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
Awni Hannun
7e26fd8032
Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Awni Hannun
0189ab6ab6
More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Awni Hannun
226748b3e7
JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00