From f5cc1eea7264e493f67806bc2416ffe7e61baba4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 31 Jan 2025 20:58:59 -0800 Subject: [PATCH] Allow different value dimensions in sdpa_vector (#1811) --- benchmarks/python/sdpa_vector_bench.py | 40 +++++++--- .../scaled_dot_product_attention.metal | 35 ++++++-- mlx/backend/metal/kernels/sdpa_vector.h | 80 ++++++++++--------- .../metal/scaled_dot_product_attention.cpp | 8 +- mlx/fast.cpp | 19 ++--- python/tests/test_fast_sdpa.py | 17 ++++ 6 files changed, 127 insertions(+), 72 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 74843a2f4..546bff84c 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -8,14 +8,23 @@ L = 16384 H = 32 H_k = H // 4 D = 128 +V = 128 dtype = mx.float16 loops = 10 -def attention(q, k, v, mask=None): +def upproject(x, w): + if w is None: + return x + else: + return x @ w.T + + +def attention(q, k, v, mask=None, w=None): def _sdpa(q, k, v): B, Hq, L, D = q.shape _, Hk, S, _ = k.shape + _, _, _, V = v.shape q = q.reshape(B, Hk, Hq // Hk, L, D) k = k[:, :, None, :, :] v = v[:, :, None, :, :] @@ -25,16 +34,18 @@ def attention(q, k, v, mask=None): s = mx.where(m, s, mx.finfo(s.dtype).min) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) o = p @ v - return o.reshape(B, Hq, L, D) + return o.reshape(B, Hq, L, V) for i in range(loops): q = _sdpa(q, k, v) + q = upproject(q, w) return q -def sdpa(q, k, v, mask=None): +def sdpa(q, k, v, mask=None, w=None): for i in range(loops): q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q = upproject(q, w) return q @@ -42,34 +53,37 @@ def time_self_attention_primitives(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - mx.eval(q, k, v) - time_fn(attention, q, k, v) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None + mx.eval(q, k, v, w) + time_fn(attention, q, k, v, w=w) def time_self_attention_sdpa(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - mx.eval(q, k, v) - time_fn(sdpa, q, k, v) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None + mx.eval(q, k, v, w) + time_fn(sdpa, q, k, v, w=w) def time_self_attention_sdpa_with_mask(): mx.random.seed(3) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) - v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype) + v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype) + w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None mask = mx.full((L,), True) mask[L // 2 :] = False - mx.eval(q, k, v, mask) + mx.eval(q, k, v, mask, w) def sdpa_mask(*args): - return sdpa(*args, mask=mask) + return sdpa(*args, mask=mask, w=w) def attention_mask(*args): - return attention(*args, mask=mask) + return attention(*args, mask=mask, w=w) time_fn(attention_mask, q, k, v) time_fn(sdpa_mask, q, k, v) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index b5bc9607e..ea80396df 100644 --- 
a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -7,15 +7,34 @@ using namespace metal;
 
 // clang-format off
 // SDPA vector instantiations
-#define instantiate_sdpa_vector(type, head_dim) \
-  instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
-  instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
-  instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
+#define instantiate_sdpa_vector_aggregation(type, value_dim) \
+  instantiate_kernel(                                        \
+      "sdpa_vector_2pass_2_" #type "_" #value_dim,           \
+      sdpa_vector_2pass_2,                                   \
+      type,                                                  \
+      value_dim)
 
-#define instantiate_sdpa_vector_heads(type) \
-  instantiate_sdpa_vector(type, 64) \
-  instantiate_sdpa_vector(type, 96) \
-  instantiate_sdpa_vector(type, 128)
+#define instantiate_sdpa_vector(type, qk_dim, value_dim)         \
+  instantiate_kernel(                                            \
+      "sdpa_vector_" #type "_" #qk_dim "_" #value_dim,           \
+      sdpa_vector,                                               \
+      type,                                                      \
+      qk_dim,                                                    \
+      value_dim)                                                 \
+  instantiate_kernel(                                            \
+      "sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim,   \
+      sdpa_vector_2pass_1,                                       \
+      type,                                                      \
+      qk_dim,                                                    \
+      value_dim)
+
+#define instantiate_sdpa_vector_heads(type)        \
+  instantiate_sdpa_vector(type, 64, 64)            \
+  instantiate_sdpa_vector(type, 96, 96)            \
+  instantiate_sdpa_vector(type, 128, 128)          \
+  instantiate_sdpa_vector_aggregation(type, 64)    \
+  instantiate_sdpa_vector_aggregation(type, 96)    \
+  instantiate_sdpa_vector_aggregation(type, 128)
 
 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index 8a7351b7c..f5fe88f4a 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -6,7 +6,7 @@ using namespace metal;
 
 constant bool has_mask [[function_constant(20)]];
 
-template <typename T, int D>
+template <typename T, int D, int V>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
     const device T* keys [[buffer(1)]],
@@ -25,14 +25,16 @@ template <typename T, int D>
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
   constexpr int BD = 32;
-  constexpr int elem_per_thread = D / BD;
-  constexpr int stride = BN * D;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  constexpr int inner_k_stride = BN * D;
+  constexpr int inner_v_stride = BN * V;
 
   typedef float U;
 
-  thread U q[elem_per_thread];
-  thread U k[elem_per_thread];
-  thread U o[elem_per_thread];
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];
 
   threadgroup U outputs[BN * BD];
   threadgroup U max_scores[BN];
@@ -41,19 +43,19 @@ template <typename T, int D>
   // Adjust positions
   const int head_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * elem_per_thread;
-  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
-  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  queries += head_idx * D + simd_lid * qk_per_thread;
+  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
+  values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
   }
-  out += head_idx * D + simd_gid * elem_per_thread;
+  out += head_idx * V + simd_gid * v_per_thread;
 
   // Read the query and 0 the output accumulator
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < qk_per_thread; i++) {
     q[i] = static_cast<U>(scale) * queries[i];
   }
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
   }
 
@@ -64,13 +66,13 @@ template <typename T, int D>
   for (int i = simd_gid; i < N; i += BN) {
     if (!has_mask || mask[0]) {
       // Read the key
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);
@@ -84,14 +86,14 @@ template <typename T, int D>
       sum_exp_score = sum_exp_score * factor + exp_score;
 
       // Update the output accumulator
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < v_per_thread; j++) {
         o[j] = o[j] * factor + exp_score * values[j];
       }
     }
 
     // Move the pointers to the next kv
-    keys += stride;
-    values += stride;
+    keys += inner_k_stride;
+    values += inner_v_stride;
     if (has_mask) {
       mask += BN * mask_seq_stride;
     }
@@ -111,7 +113,7 @@ template <typename T, int D>
   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
 
   // Now we need to aggregate all the outputs
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
     o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
@@ -120,13 +122,13 @@ template <typename T, int D>
 
   // And write the output
   if (simd_lid == 0) {
-    for (int i = 0; i < elem_per_thread; i++) {
+    for (int i = 0; i < v_per_thread; i++) {
       out[i] = static_cast<T>(o[i]);
     }
   }
 }
 
-template <typename T, int D>
+template <typename T, int D, int V>
 [[kernel]] void sdpa_vector_2pass_1(
     const device T* queries [[buffer(0)]],
     const device T* keys [[buffer(1)]],
@@ -147,15 +149,17 @@ template <typename T, int D>
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 8;
   constexpr int BD = 32;
-  constexpr int elem_per_thread = D / BD;
-  constexpr int stride = BN * D;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  constexpr int inner_k_stride = BN * D;
+  constexpr int inner_v_stride = BN * V;
   constexpr int blocks = 32;
 
   typedef float U;
 
-  thread U q[elem_per_thread];
-  thread U k[elem_per_thread];
-  thread U o[elem_per_thread];
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];
 
   threadgroup U outputs[BN * BD];
   threadgroup U max_scores[BN];
@@ -165,12 +169,12 @@ template <typename T, int D>
   const int block_idx = tid.z;
   const int head_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * elem_per_thread;
+  queries += head_idx * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
-      simd_lid * elem_per_thread;
-  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
-      simd_lid * elem_per_thread;
-  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+      simd_lid * qk_per_thread;
+  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
+      simd_lid * v_per_thread;
+  out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_seq_stride;
@@ -179,10 +183,10 @@ template <typename T, int D>
   maxs += head_idx * blocks + block_idx;
 
   // Read the query and 0 the output accumulator
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < qk_per_thread; i++) {
     q[i] = static_cast<U>(scale) * queries[i];
   }
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
   }
 
@@ -193,13 +197,13 @@ template <typename T, int D>
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
     if (!has_mask || mask[0]) {
       // Read the key
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < qk_per_thread; i++) {
         k[i] = keys[i];
       }
 
       // Compute the i-th score
       U score = 0;
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < qk_per_thread; i++) {
         score += q[i] * k[i];
       }
       score = simd_sum(score);
@@ -213,14 +217,14 @@ template <typename T, int D>
       sum_exp_score = sum_exp_score * factor + exp_score;
 
       // Update the output accumulator
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < v_per_thread; i++) {
         o[i] = o[i] * factor + exp_score * values[i];
       }
     }
 
     // Move the pointers to the next kv
-    keys += blocks * stride;
-    values += blocks * stride;
+    keys += blocks * inner_k_stride;
+    values += blocks * inner_v_stride;
     if (has_mask) {
       mask += BN * blocks * mask_seq_stride;
     }
@@ -247,7 +251,7 @@ template <typename T, int D>
   }
 
   // Now we need to aggregate all the outputs
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BN + simd_gid] =
         o[i] * fast::exp(max_scores[simd_gid] - new_max);
     threadgroup_barrier(mem_flags::mem_threadgroup);
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 105a1e87a..47bf7f22a 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -124,6 +124,8 @@ void sdpa_vector(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(v.shape(-1));
 
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
@@ -185,6 +187,8 @@ void sdpa_vector_2pass(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(v.shape(-1));
 
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
@@ -256,7 +260,7 @@ void sdpa_vector_2pass(
     kname += "sdpa_vector_2pass_2_";
     kname += get_type_string(q.dtype());
     kname += "_";
-    kname += std::to_string(q.shape(-1));
+    kname += std::to_string(v.shape(-1));
 
     // Get the kernel
     kernel = d.get_kernel(kname);
@@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
   const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
 
   // Donate the query if possible
-  if (q.is_donatable()) {
+  if (q.is_donatable() && q.size() == o.size()) {
     o.move_shared_buffer(q);
   } else {
     o.set_data(allocator::malloc_or_wait(o.nbytes()));
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 14195b9e4..8a60322db 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -684,23 +684,20 @@ array scaled_dot_product_attention(
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);
 
-  bool implementation_supports_use_case = query_head_dim == value_head_dim;
-
   const bool sdpa_vector_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
-  const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80;
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80);
 
-  const bool supports_sdpa_full = query_sequence_length >= threshold &&
-      !mask.has_value() && sdpa_full_supported_head_dim &&
-      stream.device == Device::gpu;
+  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
+      sdpa_full_supported_head_dim && stream.device == Device::gpu;
 
-  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      supported_mask && sdpa_vector_supported_head_dim &&
+      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
-  implementation_supports_use_case &=
+  const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
 
   std::vector<array> inputs = {q, k, v};
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index f1298cb35..348ba4c88 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -262,6 +262,23 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
+    @unittest.skip("Different head and value dims is not enabled")
+    def test_fast_sdpa_vector_value_dims(self):
+        D = 192
+        V = 128
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+
+        for L in [43, 128, 237, 8192]:
+            q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+            k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+            v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, V))
+            ref = mlx_primitives_sdpa(q, k, v, scale)
+            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+
 
 if __name__ == "__main__":
     unittest.main(failfast=True)
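Shape contract, illustrated (not part of the patch): the kernels above now carry separate template dimensions D (query/key head dim) and V (value head dim), so the attention output has head dim V. The sketch below mirrors the reference attention used in the benchmark and test; the helper name `reference_sdpa` is hypothetical. As the skipped test indicates, the fused path in `mlx/fast.cpp` still requires `query_head_dim == value_head_dim`, so this sketch uses plain MLX ops rather than `mx.fast.scaled_dot_product_attention`.

    import mlx.core as mx


    def reference_sdpa(q, k, v, scale):
        # q: (B, Hq, L, D), k: (B, Hk, S, D), v: (B, Hk, S, V); Hq must be a multiple of Hk.
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        V = v.shape[-1]
        q = (scale * q).reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]  # (B, Hk, 1, S, D)
        v = v[:, :, None, :, :]  # (B, Hk, 1, S, V)
        scores = q @ k.transpose(0, 1, 2, 4, 3)  # (B, Hk, Hq // Hk, L, S)
        probs = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        return (probs @ v).reshape(B, Hq, L, V)  # output head dim is V, not D


    # Example with D = 192 for queries/keys and V = 128 for values (the skipped test's shapes).
    q = mx.random.normal(shape=(1, 4, 1, 192))
    k = mx.random.normal(shape=(1, 1, 128, 192))
    v = mx.random.normal(shape=(1, 1, 128, 128))
    print(reference_sdpa(q, k, v, scale=1.0).shape)  # (1, 4, 1, 128)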