From 19fb69e2ed3344554b50d9435c95083906703e69 Mon Sep 17 00:00:00 2001
From: Brian Keene
Date: Mon, 12 Aug 2024 15:57:09 -0400
Subject: [PATCH] Add memory_efficient_threshold kwarg to sdpa kernel (#1319)

Allows opt-in to the memory-efficient GPU shader at a prescribed sequence
length. Otherwise, uses aggregate MLX primitives for best latency.
---
 mlx/fast.cpp                   | 13 ++++++++++---
 mlx/fast.h                     |  1 +
 python/src/fast.cpp            |  1 +
 python/tests/test_fast_sdpa.py |  2 +-
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 4a4819b5a..e067c9b4b 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -465,6 +465,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
+    const std::optional<int>& memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -535,6 +536,11 @@ array scaled_dot_product_attention(
    * * dtype is not fp32 or fp16
    */

+  int threshold = 1e6;
+  if (memory_efficient_threshold.has_value()) {
+    threshold = std::max(1, memory_efficient_threshold.value());
+  }
+
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
@@ -581,9 +587,10 @@ array scaled_dot_product_attention(
   bool implementation_supports_use_case =
       supports_sdpa || supports_full_self_attention;

-  // disabling full self attention until perf is tuned;
-  // likewise for sdpa
-  implementation_supports_use_case &= false;
+  // sdpa gpu shader is disabled except for memory efficient opt-in
+  const int seq_for_threshold = queries.shape(2);
+  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
+  implementation_supports_use_case &= use_memory_efficient_impl;

   if (implementation_supports_use_case) {
     auto out_shape =
diff --git a/mlx/fast.h b/mlx/fast.h
index 4c63df8fe..48e95a768 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -37,6 +37,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
+    const std::optional<int>& memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index be14c1d85..349618c23 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
       nb::kw_only(),
       "scale"_a,
       "mask"_a = nb::none(),
+      "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
           "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 4baed56d7..13b316bd1 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -86,7 +86,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
         reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)

         o_mlx = mx.fast.scaled_dot_product_attention(
-            q_mlx, k_mlx, v_mlx, scale=scale
+            q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
        )

         self.assertListEqual(list(reference.shape), list(o_mlx.shape))
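
Usage sketch (not part of the patch): a minimal Python example of the new
kwarg, assuming MLX is built with this change applied. The tensor shapes and
the threshold value of 128 are illustrative, not values taken from the patch.

    import mlx.core as mx

    # Attention inputs with layout (batch, n_heads, seq_len, head_dim).
    q = mx.random.normal(shape=(1, 8, 128, 64))
    k = mx.random.normal(shape=(1, 8, 128, 64))
    v = mx.random.normal(shape=(1, 8, 128, 64))
    scale = 1.0 / (64 ** 0.5)

    # memory_efficient_threshold opts in to the memory-efficient GPU shader
    # once the query sequence length (128 here) reaches the threshold; leaving
    # the kwarg unset keeps the default fallback to aggregate MLX primitives.
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=scale, memory_efficient_threshold=128
    )
    mx.eval(out)

The test change above exercises the same path by setting the threshold to 2,
so the opt-in branch is taken for any sequence length the shader supports.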