mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 04:08:13 +08:00
Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opt-in to memory efficient GPU shader at proscribed sequence length. Otherwise, utilizes aggregate MLX primitives for best latency.
This commit is contained in:
@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
nb::kw_only(),
|
||||
"scale"_a,
|
||||
"mask"_a = nb::none(),
|
||||
"memory_efficient_threshold"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
|
Reference in New Issue
Block a user