mlx/benchmarks/python
Brian Keene 1865299a30
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
| Name | Last commit | Date |
| --- | --- | --- |
| blas | Update GEMM (#424) | 2024-01-17 12:42:39 -08:00 |
| comparative | Reduce update (#783) | 2024-03-04 19:09:51 -08:00 |
| batch_matmul_bench.py | Add isort pre-commit and run (#68) | 2023-12-08 11:31:47 -08:00 |
| compile_bench.py | Shapeless compilation for some graphs (#687) | 2024-02-19 21:43:54 -08:00 |
| conv1d_bench.py | Add groups to Conv1d (#948) | 2024-04-27 06:24:57 -07:00 |
| conv_bench.py | Add groups to 2-D convolutions (#1129) | 2024-05-22 20:01:44 -07:00 |
| fft_bench.py | Metal FFT for powers of 2 up to 2048 (#915) | 2024-04-11 21:40:06 -07:00 |
| gather_bench.py | Scatter optimization : Eliminate 64b integer divide. (#662) | 2024-02-10 08:49:51 -08:00 |
| layer_norm_bench.py | Implement vjps for some primitives in the fast namespace (#883) | 2024-03-26 16:35:34 -07:00 |
| rms_norm_bench.py | Implement vjps for some primitives in the fast namespace (#883) | 2024-03-26 16:35:34 -07:00 |
| rope_bench.py | Fix copy donation and add partial rope (#881) | 2024-03-22 17:28:26 -07:00 |
| scatter_bench.py | Up to 10x faster scatter. (#709) | 2024-02-21 11:09:30 -08:00 |
| sdpa_bench.py | Metal shaders for memory efficient self attention on large sequences (#964) | 2024-06-03 09:16:19 -07:00 |
| single_ops.py | Propagate nans in binary ops (#579) | 2024-01-29 11:19:38 -08:00 |
| time_utils.py | Shapeless compilation for some graphs (#687) | 2024-02-19 21:43:54 -08:00 |