From 11f73d6e8910bbd6425143f61878f40501b94040 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 22 Apr 2025 00:19:11 -0700
Subject: [PATCH] Double buffer keys for vector sdpa

Keep two per-thread key buffers, k[a] and k[b], in both vector sdpa
kernels. The first key is loaded before the loop; every iteration then
loads the next key into k[b] while the score is computed from k[a], and
the two indices are exchanged at the end of the iteration so the freshly
loaded key becomes the one that is consumed next.

---
Review notes (not part of the commit message):

* The index exchange must be a real swap (a = b; b = tmp). Writing
  "b = a; a = tmp" leaves a unchanged and makes b equal to a, so the
  prefetch would overwrite the buffer the score loop is about to read.
* NOTE(review): the prefetch in the final loop iteration (and the initial
  load) reads one stride past the last key that is actually used; confirm
  the keys buffer makes that read safe or guard it.
* NOTE(review): the benchmark change L = 16384 -> 1024 looks like a local
  experiment; confirm it is intended to be committed.

 benchmarks/python/sdpa_vector_bench.py  |  2 +-
 mlx/backend/metal/kernels/sdpa_vector.h | 54 ++++++++++++++++++-------
 2 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index 546bff84c..6aa65a2e2 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -4,7 +4,7 @@ import math
 import mlx.core as mx
 from time_utils import time_fn
 
-L = 16384
+L = 1024
 H = 32
 H_k = H // 4
 D = 128
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index c4c0f6456..c57828634 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -45,7 +45,7 @@ template
   typedef float U;
 
   thread U q[qk_per_thread];
-  thread U k[qk_per_thread];
+  thread U k[2][qk_per_thread];
   thread U o[v_per_thread];
 
   threadgroup U outputs[BN * BD];
@@ -86,8 +86,21 @@ template
   U max_score = -INFINITY;
   U sum_exp_score = 0;
 
+  // Read the first key
+  short a = 0, b = 1;
+  for (int j = 0; j < qk_per_thread; j++) {
+    k[a][j] = keys[j];
+  }
+  keys += inner_k_stride;
+
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    // Read the next key
+    for (int j = 0; j < qk_per_thread; j++) {
+      k[b][j] = keys[j];
+    }
+
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
@@ -95,15 +108,10 @@ template
       use_key = bmask[0];
     }
     if (use_key) {
-      // Read the key
-      for (int j = 0; j < qk_per_thread; j++) {
-        k[j] = keys[j];
-      }
-
       // Compute the i-th score
       U score = 0;
       for (int j = 0; j < qk_per_thread; j++) {
-        score += q[j] * k[j];
+        score += q[j] * k[a][j];
       }
       score = simd_sum(score);
       if (float_mask) {
@@ -133,6 +141,11 @@ template
     if (float_mask) {
       fmask += BN * mask_kv_seq_stride;
     }
+
+    // Swap read and write locations
+    short tmp = a;
+    a = b;
+    b = tmp;
   }
 
   // Each thread has a partial part of the output so we need to combine them.
@@ -202,7 +215,7 @@ template
   typedef float U;
 
   thread U q[qk_per_thread];
-  thread U k[qk_per_thread];
+  thread U k[2][qk_per_thread];
   thread U o[v_per_thread];
 
   threadgroup U outputs[BN * BD];
@@ -248,8 +261,21 @@ template
   U max_score = -1e9;
   U sum_exp_score = 0;
 
+  // Read the first key
+  short a = 0, b = 1;
+  for (int j = 0; j < qk_per_thread; j++) {
+    k[a][j] = keys[j];
+  }
+  keys += inner_k_stride;
+
   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    // Read the next key
+    for (int j = 0; j < qk_per_thread; j++) {
+      k[b][j] = keys[j];
+    }
+
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
@@ -257,15 +283,10 @@ template
       use_key = bmask[0];
     }
     if (use_key) {
-      // Read the key
-      for (int i = 0; i < qk_per_thread; i++) {
-        k[i] = keys[i];
-      }
-
       // Compute the i-th score
       U score = 0;
       for (int i = 0; i < qk_per_thread; i++) {
-        score += q[i] * k[i];
+        score += q[i] * k[a][i];
       }
       score = simd_sum(score);
       if (float_mask) {
@@ -295,6 +316,11 @@ template
     if (float_mask) {
       fmask += BN * blocks * mask_kv_seq_stride;
     }
+
+    // Swap read and write locations
+    short tmp = a;
+    a = b;
+    b = tmp;
   }
 
   // Each thread has a partial part of the output so we need to combine them.