Awni Hannun
13b26775f1
use minimum deployment target ( #2016 )
2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561
causal vector sdpa ( #2018 )
...
* causal vector sdpa
* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66
enable complex gemm ( #2017 )
2025-03-28 10:45:13 -07:00
Awni Hannun
bc62932984
sdpa specialization for head dim 256 ( #2007 )
2025-03-27 19:31:25 -07:00
Jagrit Digani
6a40e1c176
Fix looping limit in causal attention ( #1999 )
2025-03-24 12:28:00 -07:00
Jagrit Digani
9adcd1a650
Support fused masking in Attention ( #1924 )
...
* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
2025-03-20 11:01:32 -07:00
Awni Hannun
c6ea2ba329
Use same accumulation precision in gemv as gemm ( #1962 )
...
* use same accumulation precision in gemv as gemm
* faster
* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
3c3e558c60
Support transposed head/seq for kv ( #1950 )
...
* support transposed head/seq for kv
* fix flaky test
* nit
2025-03-10 10:53:45 -07:00
Alex Barron
fd0d63ba5b
Affine quant always in fp32 ( #1925 )
...
* do affine quant in fp32
* static cast
2025-03-04 17:50:19 -08:00
Awni Hannun
e613d0eaf0
SDPA support for small batch (over sequence) queries ( #1922 )
...
* batch query sdpa
* batch sdpa for query
2025-03-04 10:59:04 -08:00
Angelos Katharopoulos
5e6c130d93
RMS norm without scaling ( #1915 )
2025-02-28 20:26:57 -08:00
Jagrit Digani
89d327075f
Enabling fused attention for head dim 128 ( #1899 )
...
* Share KV smem
* Fix bfloat error
* Unroll O = S @ V loop
* Perf upgrade
* Remove commented out function
* Add -Wno-c++17-extensions flag to metal flags
* Add -Wno-c++17-extensions flag to metal extension flags
2025-02-26 10:02:06 -08:00
Awni Hannun
2d0f384b6f
fix simd erf_inv ( #1896 )
2025-02-24 13:57:47 -08:00
Angelos Katharopoulos
71de73a668
Fix convs by reverting #1803 ( #1882 )
2025-02-18 14:36:34 -08:00
Angelos Katharopoulos
1762793989
Remove unused uniform ( #1867 )
2025-02-14 15:51:41 -08:00
Jagrit Digani
2dc307f2e6
Winograd Update for Small batches ( #1803 )
...
* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
2025-02-14 13:08:13 -08:00
Alex Barron
5cd97f7ffe
Bitwise Inverse ( #1862 )
...
* add bitwise inverse
* add vmap + fix nojit
* inverse -> invert
* add to compile + remove unused
2025-02-13 08:44:14 -08:00
Awni Hannun
e425dc00c0
Faster small batch qmv ( #1861 )
...
* faster small batch qmv
* swap batch and block dims for qvm and qmv regular
2025-02-12 22:02:36 -08:00
Awni Hannun
af1b725fda
Fix a couple of slicing bugs ( #1827 )
...
* fix a few bugs
* fix conv grad
* speedup test
* comment
2025-02-05 19:50:08 -08:00
Awni Hannun
fe5987b81d
faster sort ( #1831 )
2025-02-05 06:10:22 -08:00
Angelos Katharopoulos
f5cc1eea72
Allow different value dimensions in sdpa_vector ( #1811 )
2025-01-31 20:58:59 -08:00
Awni Hannun
b7c9f1d38f
scatter axis + gather axis primitives ( #1813 )
...
* scatter axis + gather axis primitives
* add transforms
* comment
2025-01-31 20:48:08 -08:00
Awni Hannun
a4667da1eb
Faster synchronization Fence
primitive ( #1773 )
...
* try faster synchronization
move event
fixes
update bench
fix
fix
* non-functioning kernel
* try alternative fence
* cleanup barrier
* get rid of event_fence
* update benchmarks
* doc string in metal fence
2025-01-17 18:42:19 -08:00
Alex Barron
c7b0300af5
Fix batched qmv bug ( #1758 )
2025-01-09 11:45:57 -08:00
Awni Hannun
d1766f2c70
Add boolean mask support in vector SDPA ( #1757 )
2025-01-07 20:24:53 -08:00
Awni Hannun
516ded618b
Dynamic slicing ( #1741 )
...
* dynamic slice and slice update
* python bindings + tests + fix set item
* fix compile issue
* comment
* fix jit
2025-01-07 14:02:16 -08:00
Awni Hannun
259025100e
Fix nd ternary on GPU ( #1746 )
2025-01-03 11:52:17 -08:00
Awni Hannun
6fa0501387
Fix concatenate/slice_update vjp + reduce binary size ( #1735 )
...
* fix concatenate vjp + reduce binary size
* also cast in slice update
2025-01-02 16:36:33 -08:00
Awni Hannun
6bd28d246e
Allow no copy negative strides in as_strided and slice ( #1688 )
...
* allow no copy negative strides in as_strided and slice
* fix jit
* fix jit
2024-12-12 08:59:45 -08:00
Awni Hannun
40c62c1321
Use int64 stride everywhere ( #1671 )
...
* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
2024-12-09 11:09:02 -08:00
Alex Barron
95c4a2e3af
add back conditionaltype ( #1655 )
2024-12-06 11:12:01 -08:00
Awni Hannun
211411faf2
fix large ops ( #1620 )
2024-11-24 09:17:10 -08:00
Alex Barron
6f7986d592
Cleaner qmv
/qvm
( #1616 )
2024-11-22 11:14:08 -08:00
Jagrit Digani
02bec0bb6d
Matrix Attention kernel ( #1610 )
...
* Rough INIT
* [WIP]: Loading and Matmuls added
* [WIP]: Reductions and min working aligned kernel at headdim = 64
* [WIP] Added headdim 80 for testing
* [WIP] Update dispatch params for testing
* [WIP] Add support for unaligned seq lengths - still looks messy
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Enable gqa support
* Update benchmark and switch off 128 headdim
* Update headdim 128 tuning
* Remove older fast attention code. Write out O strided
* Disable hd=128 until further optimizations
* Enable bf16
* Fix data size bug
* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
Alex Barron
c79f6a4a8c
3 and 6 bit quantization ( #1613 )
...
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
Awni Hannun
0c5eea226b
Reduce specializations ( #1607 )
...
* start of reduce specializations
* fix all reduce
* fix many dims
* fix
* non-jit tests clear
* cleanup instantiations
* cpu merges
* change dim specializations
* optimize
* fix jit
* fix jit
* use higher precision for integer sum+prod
* fixes
2024-11-21 19:53:00 -08:00
Awni Hannun
2419edd5b2
Faster indexing math in a few kernels ( #1589 )
...
* wip: faster compiled kernels
* faster general unary with uint specialization
* index type in compiled, unary, binary, ternary, copy
* fix jit
* jit fix
* specialize gather + scatter
* nit in docs
2024-11-18 19:52:00 -08:00
Angelos Katharopoulos
073076ac7d
2-Pass Sdpa Inference Kernel ( #1597 )
2024-11-18 17:31:53 -08:00
Awni Hannun
610af352d4
Dispatch bf16 at run time when using the JIT ( #1584 )
...
* Dispatch bf16 at run time when using the JIT
* fix extension
* fix extension build
* fix extension build
* Update utils.h
2024-11-15 16:54:36 -08:00
Alex Barron
a4c47b0276
OOB QMV fix ( #1579 )
...
* fix oob access in qmv
* skip more
* fix small case
2024-11-08 17:59:45 -08:00
Alex Barron
111fefd5e9
Fix OOB access in qmv ( #1577 )
...
* fix oob access in qmv
* skip more
2024-11-08 15:41:30 -08:00
Awni Hannun
9f0d5c12fc
Fully wrap the command encoder ( #1572 )
...
* fully wrap the command encoder
* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Alex Barron
26be608470
Add split_k qvm
for long context ( #1564 )
...
* Add splitk qvm
* configurable splitk
* tuning
* remove extra instantiation
* remove refactor
* separate test
* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c
Reductions update ( #1351 )
2024-11-04 22:25:16 -08:00
Angelos Katharopoulos
62f297b51d
Sdpa fix ( #1558 )
2024-11-02 21:25:46 -07:00
Awni Hannun
4f72c66911
improvements to scatter / gather ( #1541 )
2024-10-30 19:30:54 -07:00
Jagrit Digani
960e3f0f05
Gemm update ( #1518 )
2024-10-30 19:30:28 -07:00
Awni Hannun
d3cd26820e
Faster bits and bernoulli ( #1535 )
...
* faster bits and bernoulli
* fix bernoulli
2024-10-28 11:11:00 -07:00
Angelos Katharopoulos
c9b41d460f
Working 64-bit scans ( #1506 )
2024-10-24 11:05:46 -07:00
Alex Barron
d15fa13daf
Batched Quantized Matmul + Fast Small QMV ( #1503 )
...
* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00