Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-02 09:18:11 +08:00)
Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

  Updated fast attention: GEMM-ified with Steel primitives

  Uses flash attention 1 for scale correction

* more compiler silencing
* Address rebase issues
* Templatize kernel instantiation, revise cpu bindings
* Safer writes to output
* Permit batch size > 1
* Numerical fixes for sdpa self attention
* Re-enable test, remove unused variable
* add benchmarking script
* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
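The "flash attention 1 for scale correction" item in the commit message refers to the online-softmax trick: partial outputs accumulated over earlier key/value blocks are rescaled whenever a later block raises the running row maximum. A minimal NumPy sketch of that idea follows; the function name, block size, and single-head 2-D shapes are illustrative assumptions and do not reflect the structure of the Metal kernel.

import numpy as np

def sdpa_blockwise(q, k, v, scale, block=64):
    """Blockwise attention with flash-attention-1 style scale correction.
    Conceptual sketch only (hypothetical helper), not the Metal kernel."""
    L_q, d = q.shape
    L_k = k.shape[0]
    out = np.zeros((L_q, d), dtype=q.dtype)
    row_max = np.full((L_q, 1), -np.inf, dtype=q.dtype)   # running max of scores per query row
    row_sum = np.zeros((L_q, 1), dtype=q.dtype)           # running softmax denominator
    for start in range(0, L_k, block):
        kb = k[start:start + block]
        vb = v[start:start + block]
        s = (q @ kb.T) * scale                            # score tile for this key block
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)            # rescale previously accumulated partials
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum                                  # final normalization

Because each block only touches a tile of the score matrix, the full L_q x L_k attention matrix is never materialized, which is where the memory saving on large sequences comes from.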
@@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) {
            v (array): Input values array.
            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
            mask (array, optional): An additive mask to apply to the query-key scores.

        Returns:
            array: The output array.
      )pbdoc");
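For reference, a minimal call to the binding documented above might look like the following. The [batch, heads, seq, head_dim] layout and the keyword scale argument are assumptions inferred from the docstring and the mlx fast-ops Python API, not something this hunk shows directly.

import math
import mlx.core as mx

B, H, L, D = 1, 8, 1024, 64                  # batch, heads, sequence length, head dim (assumed shapes)
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

scale = 1.0 / math.sqrt(D)                   # the docstring's suggested 1.0 / sqrt(q.shape(-1))
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)  # mask is optional and additive
print(out.shape)                             # (1, 8, 1024, 64)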