Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Add boolean mask support in vector SDPA (#1757)
@@ -4,6 +4,8 @@
 
 using namespace metal;
 
+constant bool has_mask [[function_constant(20)]];
+
 template <typename T, int D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
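
The new `has_mask` symbol is a Metal function constant (index 20), so the mask buffer, its strides, and every `if (has_mask)` branch below are only compiled into pipelines that the host specializes with the constant set to true; the unmasked variant pays no extra cost. A rough host-side sketch of such a specialization using metal-cpp follows; the library handle and kernel instantiation name are placeholders, not MLX's actual dispatch code.

// Hypothetical host-side specialization via metal-cpp; the kernel name and
// library handle are placeholders, not taken from this diff.
#include <Metal/Metal.hpp>

MTL::Function* make_sdpa_function(MTL::Library* lib, bool has_mask) {
  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  // Index 20 matches `constant bool has_mask [[function_constant(20)]]`.
  consts->setConstantValue(&has_mask, MTL::DataTypeBool, 20);
  NS::Error* error = nullptr;
  MTL::Function* fn = lib->newFunction(
      NS::String::string("sdpa_vector_instantiation", NS::UTF8StringEncoding),
      consts, &error);
  consts->release();
  return fn;  // nullptr with `error` populated if specialization fails
}
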
@@ -15,6 +17,9 @@ template <typename T, int D>
     const constant size_t& k_stride,
     const constant size_t& v_stride,
     const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -39,6 +44,9 @@ template <typename T, int D>
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
   values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
+  }
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
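
With the pointer adjustment above, each simdgroup positions its mask pointer at the first key it will process, and the loop in the next hunk advances it by BN key positions per iteration, so simdgroup simd_gid only ever reads the mask entries for keys simd_gid, simd_gid + BN, and so on. A small CPU model of that walk; the concrete values of BN, N, and the strides are illustrative, and a mask_head_stride of 0 would broadcast one mask row across all heads.

// CPU model of the per-simdgroup mask pointer walk; constants are illustrative.
#include <cstdio>

int main() {
  const int BN = 8;                // simdgroups per threadgroup (illustrative)
  const int N = 20;                // number of keys
  const int mask_seq_stride = 1;   // elements between consecutive key positions
  const int mask_head_stride = N;  // elements between heads (0 broadcasts one row)
  const int head_idx = 2;

  for (int simd_gid = 0; simd_gid < BN; ++simd_gid) {
    // Matches `mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;`
    long offset = head_idx * (long)mask_head_stride + simd_gid * (long)mask_seq_stride;
    for (int i = simd_gid; i < N; i += BN) {
      std::printf("simdgroup %d reads mask[%ld] for key %d\n", simd_gid, offset, i);
      offset += BN * (long)mask_seq_stride;  // matches `mask += BN * mask_seq_stride;`
    }
  }
  return 0;
}
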
@@ -54,34 +62,39 @@ template <typename T, int D>
 
   // For each key
   for (int i = simd_gid; i < N; i += BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int j = 0; j < elem_per_thread; j++) {
+        k[j] = keys[j];
+      }
 
-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int j = 0; j < elem_per_thread; j++) {
+        score += q[j] * k[j];
+      }
+      score = simd_sum(score);
 
-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
 
-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
 
-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
+      // Update the output accumulator
+      for (int j = 0; j < elem_per_thread; j++) {
+        o[j] = o[j] * factor + exp_score * values[j];
+      }
     }
 
     // Move the pointers to the next kv
     keys += stride;
     values += stride;
+    if (has_mask) {
+      mask += BN * mask_seq_stride;
+    }
   }
 
   // Each thread has a partial part of the output so we need to combine them.
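
The loop above is the usual streaming (online) softmax: every visited key updates a running maximum, a running sum of exponentials, and an unnormalized output accumulator, and with the mask enabled a false entry now skips all of that work for its key. A minimal scalar C++ sketch of the same accumulation, assuming the `scale` factor has already been folded into the query (the part of the kernel that loads and scales `q` is not shown in this diff); the kernel defers the final normalization to its combination step, while here it is done inline, so a row with every key masked would divide by zero.

// Scalar reference for one (query, head): streaming softmax with a boolean mask.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> masked_sdpa_row(
    const std::vector<float>& q,                    // [D], pre-scaled query
    const std::vector<std::vector<float>>& keys,    // [N][D]
    const std::vector<std::vector<float>>& values,  // [N][D]
    const std::vector<bool>* mask) {                // [N], or nullptr for no mask
  const size_t D = q.size();
  const size_t N = keys.size();
  std::vector<float> o(D, 0.0f);
  float max_score = -INFINITY;
  float sum_exp_score = 0.0f;
  for (size_t i = 0; i < N; ++i) {
    if (mask && !(*mask)[i]) continue;  // skip masked keys entirely
    float score = 0.0f;
    for (size_t d = 0; d < D; ++d) score += q[d] * keys[i][d];
    float new_max = std::max(max_score, score);
    float factor = std::exp(max_score - new_max);  // rescale previous state
    float exp_score = std::exp(score - new_max);
    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
    for (size_t d = 0; d < D; ++d)
      o[d] = o[d] * factor + exp_score * values[i][d];
  }
  for (size_t d = 0; d < D; ++d) o[d] /= sum_exp_score;  // normalize at the end
  return o;
}
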
@@ -126,6 +139,9 @@ template <typename T, int D>
     const constant size_t& k_stride,
     const constant size_t& v_stride,
     const constant float& scale,
+    const device bool* mask [[function_constant(has_mask)]],
+    const constant int& mask_seq_stride [[function_constant(has_mask)]],
+    const constant int& mask_head_stride [[function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -155,6 +171,10 @@ template <typename T, int D>
   values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
       simd_lid * elem_per_thread;
   out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  if (has_mask) {
+    mask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_seq_stride;
+  }
   sums += head_idx * blocks + block_idx;
   maxs += head_idx * blocks + block_idx;
 
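
In the two-pass variant, the addressing above has each block write an unnormalized partial output plus one running maximum (into maxs) and one sum of exponentials (into sums). The second pass is not part of this diff, but partials of that form can be merged by rebasing every block onto a global maximum; a hedged sketch of that generic combination follows.

// Generic merge of per-block softmax partials (the actual second-pass kernel
// is not shown in this diff). partial_o holds unnormalized accumulators.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> combine_partials(
    const std::vector<std::vector<float>>& partial_o,  // [blocks][D]
    const std::vector<float>& maxs,                     // [blocks]
    const std::vector<float>& sums) {                   // [blocks]
  const size_t D = partial_o.front().size();
  float global_max = *std::max_element(maxs.begin(), maxs.end());
  std::vector<float> o(D, 0.0f);
  float denom = 0.0f;
  for (size_t b = 0; b < partial_o.size(); ++b) {
    float w = std::exp(maxs[b] - global_max);  // rebase block b onto the global max
    denom += w * sums[b];
    for (size_t d = 0; d < D; ++d) o[d] += w * partial_o[b][d];
  }
  for (size_t d = 0; d < D; ++d) o[d] /= denom;
  return o;
}
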
@@ -171,34 +191,39 @@ template <typename T, int D>
 
   // For each key
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
-    // Read the key
-    for (int i = 0; i < elem_per_thread; i++) {
-      k[i] = keys[i];
-    }
+    if (!has_mask || mask[0]) {
+      // Read the key
+      for (int i = 0; i < elem_per_thread; i++) {
+        k[i] = keys[i];
+      }
 
-    // Compute the i-th score
-    U score = 0;
-    for (int i = 0; i < elem_per_thread; i++) {
-      score += q[i] * k[i];
-    }
-    score = simd_sum(score);
+      // Compute the i-th score
+      U score = 0;
+      for (int i = 0; i < elem_per_thread; i++) {
+        score += q[i] * k[i];
+      }
+      score = simd_sum(score);
 
-    // Update the accumulators
-    U new_max = max(max_score, score);
-    U factor = fast::exp(max_score - new_max);
-    U exp_score = fast::exp(score - new_max);
+      // Update the accumulators
+      U new_max = max(max_score, score);
+      U factor = fast::exp(max_score - new_max);
+      U exp_score = fast::exp(score - new_max);
 
-    max_score = new_max;
-    sum_exp_score = sum_exp_score * factor + exp_score;
+      max_score = new_max;
+      sum_exp_score = sum_exp_score * factor + exp_score;
 
-    // Update the output accumulator
-    for (int i = 0; i < elem_per_thread; i++) {
-      o[i] = o[i] * factor + exp_score * values[i];
+      // Update the output accumulator
+      for (int i = 0; i < elem_per_thread; i++) {
+        o[i] = o[i] * factor + exp_score * values[i];
+      }
     }
 
     // Move the pointers to the next kv
     keys += blocks * stride;
     values += blocks * stride;
+    if (has_mask) {
+      mask += BN * blocks * mask_seq_stride;
+    }
   }
 
   // Each thread has a partial part of the output so we need to combine them.