Add memory_efficient_threshold kwarg to sdpa kernel (#1319)
Allows opting in to the memory-efficient GPU shader at a caller-specified sequence length threshold. Otherwise, the aggregate MLX primitives are used for best latency.
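A minimal usage sketch of the new keyword argument (the shapes, dtype, and threshold value below are illustrative, not taken from this commit):

import mlx.core as mx

# Illustrative shapes: (batch, num_heads, sequence_length, head_dim).
q = mx.random.normal(shape=(1, 8, 1024, 64)).astype(mx.float16)
k = mx.random.normal(shape=(1, 8, 1024, 64)).astype(mx.float16)
v = mx.random.normal(shape=(1, 8, 1024, 64)).astype(mx.float16)
scale = 1.0 / (64 ** 0.5)

# Opt in to the memory-efficient GPU shader once the query sequence
# length reaches 512; shorter sequences keep using the fused MLX
# primitives fallback.
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=scale, memory_efficient_threshold=512
)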
commit 19fb69e2ed
parent 9231617eb3
13  mlx/fast.cpp
@@ -465,6 +465,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
+    const std::optional<int>& memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -535,6 +536,11 @@ array scaled_dot_product_attention(
   *  * dtype is not fp32 or fp16
   */

+  int threshold = 1e6;
+  if (memory_efficient_threshold.has_value()) {
+    threshold = std::max(1, memory_efficient_threshold.value());
+  }
+
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
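For reference, the threshold resolution added above behaves like this Python sketch (illustrative only; the actual logic is the C++ in the hunk): a caller-supplied value is clamped to at least 1, and the default of 1e6 effectively keeps the shader opt-in only.

def resolve_threshold(memory_efficient_threshold=None):
    # Default of one million tokens effectively disables the shader
    # unless the caller passes an explicit threshold.
    threshold = 1_000_000
    if memory_efficient_threshold is not None:
        threshold = max(1, memory_efficient_threshold)
    return threshold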
@@ -581,9 +587,10 @@ array scaled_dot_product_attention(
   bool implementation_supports_use_case =
       supports_sdpa || supports_full_self_attention;

-  // disabling full self attention until perf is tuned;
-  // likewise for sdpa
-  implementation_supports_use_case &= false;
+  // sdpa gpu shader is disabled except for memory efficient opt-in
+  const int seq_for_threshold = queries.shape(2);
+  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
+  implementation_supports_use_case &= use_memory_efficient_impl;

   if (implementation_supports_use_case) {
     auto out_shape =
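The dispatch decision introduced in the hunk above can be summarized by this sketch (a hypothetical helper, not part of the library): the memory-efficient shader is taken only when the query sequence length, queries.shape(2), meets the resolved threshold and the existing shape/dtype checks already allow the fused path.

def should_use_memory_efficient_shader(query_seq_len, threshold, constraints_ok):
    # constraints_ok stands in for the existing supports_sdpa /
    # supports_full_self_attention checks in scaled_dot_product_attention.
    return constraints_ok and query_seq_len >= threshold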
@@ -37,6 +37,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
+    const std::optional<int>& memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(
@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
       nb::kw_only(),
       "scale"_a,
       "mask"_a = nb::none(),
+      "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
           "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
@@ -86,7 +86,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):

         reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
         o_mlx = mx.fast.scaled_dot_product_attention(
-            q_mlx, k_mlx, v_mlx, scale=scale
+            q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
         )

         self.assertListEqual(list(reference.shape), list(o_mlx.shape))
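With memory_efficient_threshold=2 the test opts in for effectively any sequence length, so the comparison against the mlx_primitives_sdpa_with_gqa reference can exercise the memory-efficient GPU path (provided the other shape and dtype constraints in scaled_dot_product_attention are satisfied) rather than always taking the fallback.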