mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Double buffer keys for vector sdpa
This commit is contained in:
parent
fdadc4f22c
commit
11f73d6e89
@ -4,7 +4,7 @@ import math
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 16384
|
||||
L = 1024
|
||||
H = 32
|
||||
H_k = H // 4
|
||||
D = 128
|
||||
|
@ -45,7 +45,7 @@ template <typename T, int D, int V = D>
|
||||
typedef float U;
|
||||
|
||||
thread U q[qk_per_thread];
|
||||
thread U k[qk_per_thread];
|
||||
thread U k[2][qk_per_thread];
|
||||
thread U o[v_per_thread];
|
||||
|
||||
threadgroup U outputs[BN * BD];
|
||||
@ -86,8 +86,21 @@ template <typename T, int D, int V = D>
|
||||
U max_score = -INFINITY;
|
||||
U sum_exp_score = 0;
|
||||
|
||||
// Read the first key
|
||||
short a = 0, b = 1;
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[a][j] = keys[j];
|
||||
}
|
||||
keys += inner_k_stride;
|
||||
|
||||
// For each key
|
||||
for (int i = simd_gid; i < N; i += BN) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Read the next key
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[b][j] = keys[j];
|
||||
}
|
||||
|
||||
bool use_key = true;
|
||||
if (do_causal) {
|
||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||
@ -95,15 +108,10 @@ template <typename T, int D, int V = D>
|
||||
use_key = bmask[0];
|
||||
}
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[j] = keys[j];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
score += q[j] * k[j];
|
||||
score += q[j] * k[a][j];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (float_mask) {
|
||||
@ -133,6 +141,11 @@ template <typename T, int D, int V = D>
|
||||
if (float_mask) {
|
||||
fmask += BN * mask_kv_seq_stride;
|
||||
}
|
||||
|
||||
// Swap read and write locations
|
||||
short tmp = a;
|
||||
b = a;
|
||||
a = tmp;
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
@ -202,7 +215,7 @@ template <typename T, int D, int V = D>
|
||||
typedef float U;
|
||||
|
||||
thread U q[qk_per_thread];
|
||||
thread U k[qk_per_thread];
|
||||
thread U k[2][qk_per_thread];
|
||||
thread U o[v_per_thread];
|
||||
|
||||
threadgroup U outputs[BN * BD];
|
||||
@ -248,8 +261,21 @@ template <typename T, int D, int V = D>
|
||||
U max_score = -1e9;
|
||||
U sum_exp_score = 0;
|
||||
|
||||
// Read the first key
|
||||
short a = 0, b = 1;
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[a][j] = keys[j];
|
||||
}
|
||||
keys += inner_k_stride;
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Read the next key
|
||||
for (int j = 0; j < qk_per_thread; j++) {
|
||||
k[b][j] = keys[j];
|
||||
}
|
||||
|
||||
bool use_key = true;
|
||||
if (do_causal) {
|
||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||
@ -257,15 +283,10 @@ template <typename T, int D, int V = D>
|
||||
use_key = bmask[0];
|
||||
}
|
||||
if (use_key) {
|
||||
// Read the key
|
||||
for (int i = 0; i < qk_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < qk_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
score += q[i] * k[a][i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
if (float_mask) {
|
||||
@ -295,6 +316,11 @@ template <typename T, int D, int V = D>
|
||||
if (float_mask) {
|
||||
fmask += BN * blocks * mask_kv_seq_stride;
|
||||
}
|
||||
|
||||
// Swap read and write locations
|
||||
short tmp = a;
|
||||
b = a;
|
||||
a = tmp;
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
Loading…
Reference in New Issue
Block a user