Metal shaders for memory efficient self attention on large sequences (#964)

* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses the Flash Attention 1 scheme for scale correction (see the sketch after this list)

* Silence more compiler warnings

* Address rebase issues

* Templatize kernel instantiation, revise CPU bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for SDPA self-attention

* Re-enable test, remove unused variable

* Add benchmarking script (a hedged sketch follows this list)

* Disable SDPA prior to performance tuning, and simplify tests for per-patch CI
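
The "scale correction" above refers to the Flash Attention 1 online-softmax trick: scores are processed one key/value block at a time, and previously accumulated partial outputs are rescaled whenever the running row maximum grows, so the full score matrix is never materialized. A minimal NumPy sketch of the idea (shapes, names, and block size are illustrative assumptions, not the Metal kernel):

```python
import numpy as np

def attention_blocked(q, k, v, scale, block=64):
    """Online-softmax attention over key/value blocks (Flash Attention 1 style sketch)."""
    L, d = q.shape
    out = np.zeros((L, d))
    row_max = np.full(L, -np.inf)   # running max of scores per query row
    denom = np.zeros(L)             # running softmax denominator per row
    for s in range(0, k.shape[0], block):
        kb, vb = k[s:s + block], v[s:s + block]
        scores = (q @ kb.T) * scale                 # (L, block) score tile
        new_max = np.maximum(row_max, scores.max(axis=-1))
        p = np.exp(scores - new_max[:, None])       # tile-local numerators
        corr = np.exp(row_max - new_max)            # rescale earlier partials
        denom = denom * corr + p.sum(axis=-1)
        out = out * corr[:, None] + p @ vb
        row_max = new_max
    return out / denom[:, None]
```

Because only one score tile is live at a time, memory scales with the block size rather than the sequence length, which is what makes the fused kernel viable on large sequences.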
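
The benchmarking script itself is not reproduced here; a hedged sketch of what such a micro-benchmark could look like against the public `mx.fast.scaled_dot_product_attention` API (shapes, head counts, and iteration counts are illustrative assumptions):

```python
import math
import time
import mlx.core as mx

def bench(B=1, H=32, L=2048, D=128, iters=20):
    # Illustrative (batch, heads, sequence, head_dim) inputs.
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))
    scale = 1.0 / math.sqrt(D)
    # Warm-up so compilation and allocations are excluded from the timing.
    mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=scale))
    tic = time.perf_counter()
    for _ in range(iters):
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
        mx.eval(out)  # force lazy evaluation before reading the clock
    print(f"{(time.perf_counter() - tic) / iters * 1e3:.3f} ms/iter")

if __name__ == "__main__":
    bench()
```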
Author: Brian Keene
Date: 2024-06-03 12:16:19 -04:00
Committed by: GitHub
Parent: 3576b547c5
Commit: 1865299a30

7 changed files with 1244 additions and 9 deletions

@@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) {
v (array): Input values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
mask (array, optional): An additive mask to apply to the query-key scores.
Returns:
array: The output array.
)pbdoc");
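
For reference, a minimal usage sketch of the binding documented above (shapes follow the conventional (batch, heads, sequence, head_dim) layout and are illustrative):

```python
import math
import mlx.core as mx

q = mx.random.normal((1, 8, 1024, 64))
k = mx.random.normal((1, 8, 1024, 64))
v = mx.random.normal((1, 8, 1024, 64))
# Scale follows the documented 1.0 / sqrt(head_dim) convention.
o = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(q.shape[-1])
)
print(o.shape)  # (1, 8, 1024, 64)
```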