mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Double buffer keys for vector sdpa
This commit is contained in:
parent
fdadc4f22c
commit
11f73d6e89
@ -4,7 +4,7 @@ import math
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
L = 16384
|
L = 1024
|
||||||
H = 32
|
H = 32
|
||||||
H_k = H // 4
|
H_k = H // 4
|
||||||
D = 128
|
D = 128
|
||||||
|
@ -45,7 +45,7 @@ template <typename T, int D, int V = D>
|
|||||||
typedef float U;
|
typedef float U;
|
||||||
|
|
||||||
thread U q[qk_per_thread];
|
thread U q[qk_per_thread];
|
||||||
thread U k[qk_per_thread];
|
thread U k[2][qk_per_thread];
|
||||||
thread U o[v_per_thread];
|
thread U o[v_per_thread];
|
||||||
|
|
||||||
threadgroup U outputs[BN * BD];
|
threadgroup U outputs[BN * BD];
|
||||||
@ -86,8 +86,21 @@ template <typename T, int D, int V = D>
|
|||||||
U max_score = -INFINITY;
|
U max_score = -INFINITY;
|
||||||
U sum_exp_score = 0;
|
U sum_exp_score = 0;
|
||||||
|
|
||||||
|
// Read the first key
|
||||||
|
short a = 0, b = 1;
|
||||||
|
for (int j = 0; j < qk_per_thread; j++) {
|
||||||
|
k[a][j] = keys[j];
|
||||||
|
}
|
||||||
|
keys += inner_k_stride;
|
||||||
|
|
||||||
// For each key
|
// For each key
|
||||||
for (int i = simd_gid; i < N; i += BN) {
|
for (int i = simd_gid; i < N; i += BN) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Read the next key
|
||||||
|
for (int j = 0; j < qk_per_thread; j++) {
|
||||||
|
k[b][j] = keys[j];
|
||||||
|
}
|
||||||
|
|
||||||
bool use_key = true;
|
bool use_key = true;
|
||||||
if (do_causal) {
|
if (do_causal) {
|
||||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||||
@ -95,15 +108,10 @@ template <typename T, int D, int V = D>
|
|||||||
use_key = bmask[0];
|
use_key = bmask[0];
|
||||||
}
|
}
|
||||||
if (use_key) {
|
if (use_key) {
|
||||||
// Read the key
|
|
||||||
for (int j = 0; j < qk_per_thread; j++) {
|
|
||||||
k[j] = keys[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the i-th score
|
// Compute the i-th score
|
||||||
U score = 0;
|
U score = 0;
|
||||||
for (int j = 0; j < qk_per_thread; j++) {
|
for (int j = 0; j < qk_per_thread; j++) {
|
||||||
score += q[j] * k[j];
|
score += q[j] * k[a][j];
|
||||||
}
|
}
|
||||||
score = simd_sum(score);
|
score = simd_sum(score);
|
||||||
if (float_mask) {
|
if (float_mask) {
|
||||||
@ -133,6 +141,11 @@ template <typename T, int D, int V = D>
|
|||||||
if (float_mask) {
|
if (float_mask) {
|
||||||
fmask += BN * mask_kv_seq_stride;
|
fmask += BN * mask_kv_seq_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Swap read and write locations: the buffer just filled (b) becomes the one read (a)
|
||||||
|
short tmp = a;
|
||||||
|
a = b;
|
||||||
|
b = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each thread has a partial part of the output so we need to combine them.
|
// Each thread has a partial part of the output so we need to combine them.
|
||||||
@ -202,7 +215,7 @@ template <typename T, int D, int V = D>
|
|||||||
typedef float U;
|
typedef float U;
|
||||||
|
|
||||||
thread U q[qk_per_thread];
|
thread U q[qk_per_thread];
|
||||||
thread U k[qk_per_thread];
|
thread U k[2][qk_per_thread];
|
||||||
thread U o[v_per_thread];
|
thread U o[v_per_thread];
|
||||||
|
|
||||||
threadgroup U outputs[BN * BD];
|
threadgroup U outputs[BN * BD];
|
||||||
@ -248,8 +261,21 @@ template <typename T, int D, int V = D>
|
|||||||
U max_score = -1e9;
|
U max_score = -1e9;
|
||||||
U sum_exp_score = 0;
|
U sum_exp_score = 0;
|
||||||
|
|
||||||
|
// Read the first key
|
||||||
|
short a = 0, b = 1;
|
||||||
|
for (int j = 0; j < qk_per_thread; j++) {
|
||||||
|
k[a][j] = keys[j];
|
||||||
|
}
|
||||||
|
keys += inner_k_stride;
|
||||||
|
|
||||||
// For each key
|
// For each key
|
||||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Read the next key
|
||||||
|
for (int j = 0; j < qk_per_thread; j++) {
|
||||||
|
k[b][j] = keys[j];
|
||||||
|
}
|
||||||
|
|
||||||
bool use_key = true;
|
bool use_key = true;
|
||||||
if (do_causal) {
|
if (do_causal) {
|
||||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||||
@ -257,15 +283,10 @@ template <typename T, int D, int V = D>
|
|||||||
use_key = bmask[0];
|
use_key = bmask[0];
|
||||||
}
|
}
|
||||||
if (use_key) {
|
if (use_key) {
|
||||||
// Read the key
|
|
||||||
for (int i = 0; i < qk_per_thread; i++) {
|
|
||||||
k[i] = keys[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the i-th score
|
// Compute the i-th score
|
||||||
U score = 0;
|
U score = 0;
|
||||||
for (int i = 0; i < qk_per_thread; i++) {
|
for (int i = 0; i < qk_per_thread; i++) {
|
||||||
score += q[i] * k[i];
|
score += q[i] * k[a][i];
|
||||||
}
|
}
|
||||||
score = simd_sum(score);
|
score = simd_sum(score);
|
||||||
if (float_mask) {
|
if (float_mask) {
|
||||||
@ -295,6 +316,11 @@ template <typename T, int D, int V = D>
|
|||||||
if (float_mask) {
|
if (float_mask) {
|
||||||
fmask += BN * blocks * mask_kv_seq_stride;
|
fmask += BN * blocks * mask_kv_seq_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Swap read and write locations: the buffer just filled (b) becomes the one read (a)
|
||||||
|
short tmp = a;
|
||||||
|
a = b;
|
||||||
|
b = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each thread has a partial part of the output so we need to combine them.
|
// Each thread has a partial part of the output so we need to combine them.
|
||||||
|
Loading…
Reference in New Issue
Block a user