Double buffer keys for vector sdpa

This commit is contained in:
Angelos Katharopoulos 2025-04-22 00:19:11 -07:00
parent fdadc4f22c
commit 11f73d6e89
2 changed files with 41 additions and 15 deletions

View File

@ -4,7 +4,7 @@ import math
import mlx.core as mx
from time_utils import time_fn
L = 16384
L = 1024
H = 32
H_k = H // 4
D = 128

View File

@ -45,7 +45,7 @@ template <typename T, int D, int V = D>
typedef float U;
thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U k[2][qk_per_thread];
thread U o[v_per_thread];
threadgroup U outputs[BN * BD];
@ -86,8 +86,21 @@ template <typename T, int D, int V = D>
U max_score = -INFINITY;
U sum_exp_score = 0;
// Read the first key
short a = 0, b = 1;
for (int j = 0; j < qk_per_thread; j++) {
k[a][j] = keys[j];
}
keys += inner_k_stride;
// For each key
for (int i = simd_gid; i < N; i += BN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read the next key
for (int j = 0; j < qk_per_thread; j++) {
k[b][j] = keys[j];
}
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
@ -95,15 +108,10 @@ template <typename T, int D, int V = D>
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
}
// Compute the i-th score
U score = 0;
for (int j = 0; j < qk_per_thread; j++) {
score += q[j] * k[j];
score += q[j] * k[a][j];
}
score = simd_sum(score);
if (float_mask) {
@ -133,6 +141,11 @@ template <typename T, int D, int V = D>
if (float_mask) {
fmask += BN * mask_kv_seq_stride;
}
// Swap read and write locations
short tmp = a;
b = a;
a = tmp;
}
// Each thread has a partial part of the output so we need to combine them.
@ -202,7 +215,7 @@ template <typename T, int D, int V = D>
typedef float U;
thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U k[2][qk_per_thread];
thread U o[v_per_thread];
threadgroup U outputs[BN * BD];
@ -248,8 +261,21 @@ template <typename T, int D, int V = D>
U max_score = -1e9;
U sum_exp_score = 0;
// Read the first key
short a = 0, b = 1;
for (int j = 0; j < qk_per_thread; j++) {
k[a][j] = keys[j];
}
keys += inner_k_stride;
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read the next key
for (int j = 0; j < qk_per_thread; j++) {
k[b][j] = keys[j];
}
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
@ -257,15 +283,10 @@ template <typename T, int D, int V = D>
use_key = bmask[0];
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < qk_per_thread; i++) {
score += q[i] * k[i];
score += q[i] * k[a][i];
}
score = simd_sum(score);
if (float_mask) {
@ -295,6 +316,11 @@ template <typename T, int D, int V = D>
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
}
// Swap read and write locations
short tmp = a;
b = a;
a = tmp;
}
// Each thread has a partial part of the output so we need to combine them.