Add boolean mask support in vector SDPA (#1757)

This commit is contained in:
Awni Hannun
2025-01-07 20:24:53 -08:00
committed by GitHub
parent 516ded618b
commit d1766f2c70
5 changed files with 226 additions and 74 deletions

View File

@@ -4,6 +4,8 @@
using namespace metal;
constant bool has_mask [[function_constant(20)]];
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
@@ -15,6 +17,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -39,6 +44,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
}
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
@@ -54,34 +62,39 @@ template <typename T, int D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!has_mask || mask[0]) {
// Read the key
for (int j = 0; j < elem_per_thread; j++) {
k[j] = keys[j];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Compute the i-th score
U score = 0;
for (int j = 0; j < elem_per_thread; j++) {
score += q[j] * k[j];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int j = 0; j < elem_per_thread; j++) {
o[j] = o[j] * factor + exp_score * values[j];
}
}
// Move the pointers to the next kv
keys += stride;
values += stride;
if (has_mask) {
mask += BN * mask_seq_stride;
}
}
// Each thread has a partial part of the output so we need to combine them.
@@ -126,6 +139,9 @@ template <typename T, int D>
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -155,6 +171,10 @@ template <typename T, int D>
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
}
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
@@ -171,34 +191,39 @@ template <typename T, int D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
if (!has_mask || mask[0]) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
if (has_mask) {
mask += BN * blocks * mask_seq_stride;
}
}
// Each thread has a partial part of the output so we need to combine them.