Add memory_efficient_threshold kwarg to sdpa kernel (#1319)

Allows opt-in to memory efficient GPU shader at proscribed sequence
length.  Otherwise, utilizes aggregate MLX primitives for best latency.
This commit is contained in:
Brian Keene
2024-08-12 15:57:09 -04:00
committed by GitHub
parent 9231617eb3
commit 19fb69e2ed
4 changed files with 13 additions and 4 deletions

View File

@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),