mlx/mlx/backend/metal/kernels/sdpa_vector.h
2025-04-22 00:19:11 -07:00

419 lines
13 KiB
C++

// Copyright © 2024 Apple Inc.
#include <metal_simdgroup>
using namespace metal;
// Function constants used to specialize the kernels in this file.
// has_mask: gates the mask stride arguments (mask_kv_seq_stride,
// mask_q_seq_stride, mask_head_stride) of the attention kernels.
constant bool has_mask [[function_constant(20)]];
// query_transposed: selects which of the two query offset formulas is used
// when locating the query vector.
constant bool query_transposed [[function_constant(21)]];
// do_causal: when set, a key at index i is only used if
// i <= N - (number of query positions) + q_seq_idx; takes precedence over the
// boolean mask.
constant bool do_causal [[function_constant(22)]];
// bool_mask: gates the `bmask` buffer; a false entry skips the key.
constant bool bool_mask [[function_constant(23)]];
// float_mask: gates the `fmask` buffer; its entries are added to the scores.
constant bool float_mask [[function_constant(24)]];
/// Single-pass scaled dot-product attention for one query position.
///
/// Each threadgroup handles the (head, query position) pair given by
/// (tid.x, tid.y). Simdgroup `simd_gid` visits keys simd_gid, simd_gid + BN,
/// ... and keeps a running max, sum of exponentials and weighted value sum
/// (online softmax) in float. The per-simdgroup partial results are then
/// merged through threadgroup memory and written to `out` as T.
///
/// Launch assumption (derived from the indexing below; confirm against the
/// host dispatch): BN = 32 simdgroups of BD = 32 lanes per threadgroup, with
/// tpg.x = number of query heads and tpg.y = number of query positions.
///
/// Template parameters:
///   T - element type of queries, keys, values, out and the float mask.
///   D - query/key head dimension; each lane handles D / BD elements.
///   V - value head dimension; each lane handles V / BD elements.
///
/// Arguments:
///   queries, keys, values - attention inputs.
///   out                   - output, V elements per (query position, head).
///   gqa_factor            - number of query heads per key/value head.
///   N                     - number of keys/values.
///   k_head_stride, k_seq_stride, v_head_stride, v_seq_stride
///                         - element strides of keys/values.
///   scale                 - multiplier applied to the query on load.
///   bmask                 - boolean mask (only bound when bool_mask).
///   fmask                 - additive mask (only bound when float_mask).
///   mask_kv_seq_stride, mask_q_seq_stride, mask_head_stride
///                         - element strides of the mask (only with has_mask).
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant size_t& k_head_stride [[buffer(6)]],
    const constant size_t& k_seq_stride [[buffer(7)]],
    const constant size_t& v_head_stride [[buffer(8)]],
    const constant size_t& v_seq_stride [[buffer(9)]],
    const constant float& scale [[buffer(10)]],
    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
    const device T* fmask [[buffer(12), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
    [[buffer(13), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
    [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride
    [[buffer(15), function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);

  typedef float U;

  thread U q[qk_per_thread];
  // Two key buffers: k[a] is the key being scored, k[b] receives the key for
  // the next iteration.
  thread U k[2][qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  out += o_offset * V + simd_gid * v_per_thread;

  // Read the query (pre-multiplied by scale) and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // Read the first key, but only if this simdgroup has a key to process so
  // that we never read past the last key.
  short a = 0, b = 1;
  if (int(simd_gid) < N) {
    for (int j = 0; j < qk_per_thread; j++) {
      k[a][j] = keys[j];
    }
  }
  keys += inner_k_stride;

  // For each key.
  // No threadgroup barrier is issued inside this loop: simdgroups run a
  // different number of iterations whenever N is not a multiple of BN, and
  // nothing in the loop touches threadgroup memory.
  for (int i = simd_gid; i < N; i += BN) {
    // Prefetch the key of the next iteration, if there is one
    if (i + BN < N) {
      for (int j = 0; j < qk_per_thread; j++) {
        k[b][j] = keys[j];
      }
    }

    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    }
    if (use_key) {
      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[a][j];
      }
      score = simd_sum(score);
      if (float_mask) {
        // Clamp the mask to a finite value before adding it to the score
        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
      }

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);
      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (bool_mask) {
      bmask += BN * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * mask_kv_seq_stride;
    }

    // Swap read and write locations so the prefetched key becomes current
    const short tmp = a;
    a = b;
    b = tmp;
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Lane l now holds the statistics of simdgroup l
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs: transpose through threadgroup
  // memory so each simdgroup reduces one output element across simdgroups.
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
/// First pass of the two-pass scaled dot-product attention for one query
/// position.
///
/// Each threadgroup handles the (head, query position, key block) triple
/// given by (tid.x, tid.y, tid.z). Within block `block_idx`, simdgroup
/// `simd_gid` visits keys block_idx * BN + simd_gid, then every
/// blocks * BN keys after that, keeping a running max, sum of exponentials
/// and weighted value sum (online softmax) in float. The simdgroups of the
/// threadgroup are merged and the kernel writes, per block:
///   out  - the unnormalized weighted value sum relative to the block max,
///          laid out as [o_offset][block][V];
///   sums - the sum of exponentials, laid out as [o_offset][block];
///   maxs - the max score, laid out as [o_offset][block].
///
/// Launch assumption (derived from the indexing below; confirm against the
/// host dispatch): BN = 8 simdgroups of BD = 32 lanes per threadgroup,
/// tpg.x = number of query heads, tpg.y = number of query positions and
/// `blocks` = 32 threadgroups along z.
///
/// Template parameters:
///   T - element type of queries, keys, values and the float mask.
///   D - query/key head dimension; each lane handles D / BD elements.
///   V - value head dimension; each lane handles V / BD elements.
///
/// Arguments:
///   queries, keys, values - attention inputs.
///   out, sums, maxs       - per-block partial results (see above).
///   gqa_factor            - number of query heads per key/value head.
///   N                     - number of keys/values.
///   k_head_stride, k_seq_stride, v_head_stride, v_seq_stride
///                         - element strides of keys/values.
///   scale                 - multiplier applied to the query on load.
///   bmask                 - boolean mask (only bound when bool_mask).
///   fmask                 - additive mask (only bound when float_mask).
///   mask_kv_seq_stride, mask_q_seq_stride, mask_head_stride
///                         - element strides of the mask (only with has_mask).
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant size_t& k_head_stride [[buffer(8)]],
    const constant size_t& k_seq_stride [[buffer(9)]],
    const constant size_t& v_head_stride [[buffer(10)]],
    const constant size_t& v_seq_stride [[buffer(11)]],
    const constant float& scale [[buffer(12)]],
    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
    const device T* fmask [[buffer(14), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
    [[buffer(15), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
    [[buffer(16), function_constant(has_mask)]],
    const constant int& mask_head_stride
    [[buffer(17), function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);
  constexpr int blocks = 32;
  // Distance, in keys, between two consecutive keys of the same simdgroup
  constexpr int key_step = blocks * BN;

  typedef float U;

  thread U q[qk_per_thread];
  // Two key buffers: k[a] is the key being scored, k[b] receives the key for
  // the next iteration.
  thread U k[2][qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  const int kv_head_idx = head_idx / gqa_factor;
  const int first_key = block_idx * BN + int(simd_gid);

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride +
      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride +
      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query (pre-multiplied by scale) and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -1e9;
  U sum_exp_score = 0;

  // Read the first key, but only if this simdgroup has a key to process so
  // that we never read past the last key.
  short a = 0, b = 1;
  if (first_key < N) {
    for (int j = 0; j < qk_per_thread; j++) {
      k[a][j] = keys[j];
    }
  }
  // Advance by the same stride as the loop below so the prefetch targets the
  // key of the next iteration of this threadgroup, not of the next block.
  keys += blocks * inner_k_stride;

  // For each key.
  // No threadgroup barrier is issued inside this loop: simdgroups run a
  // different number of iterations whenever N is not a multiple of the key
  // step, and nothing in the loop touches threadgroup memory.
  for (int i = first_key; i < N; i += key_step) {
    // Prefetch the key of the next iteration, if there is one
    if (i + key_step < N) {
      for (int j = 0; j < qk_per_thread; j++) {
        k[b][j] = keys[j];
      }
    }

    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    }
    if (use_key) {
      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[a][j];
      }
      score = simd_sum(score);
      if (float_mask) {
        score += fmask[0];
      }

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);
      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (bool_mask) {
      bmask += BN * blocks * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * blocks * mask_kv_seq_stride;
    }

    // Swap read and write locations so the prefetched key becomes current
    const short tmp = a;
    a = b;
    b = tmp;
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Only the first BN lanes map to a simdgroup; the rest contribute neutral
  // values to the reductions.
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    // Rescale this simdgroup's partial to the block-wide max before summing
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      // NOTE(review): `out` is a float buffer but the value is rounded through
      // T before being stored; confirm whether the rounding is intended.
      out[i] = static_cast<T>(output);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
/// Second pass of the two-pass attention: merges the per-block partial
/// results into the final output.
///
/// Each threadgroup handles the (head, query position) pair given by
/// (tid.x, tid.y). Lane l of every simdgroup reads the max and sum of block l,
/// rescales them to the overall max, and the block partials are then summed
/// and normalized through a threadgroup-memory transpose.
///
/// Launch assumption (derived from the indexing below; confirm against the
/// host dispatch): BN = 32 simdgroups of BD = 32 lanes per threadgroup and
/// `blocks` = 32 partial results per (head, query position).
///
/// Template parameters:
///   T - element type of the final output.
///   D - output dimension; each lane handles D / BD elements.
///
/// Arguments:
///   partials - per-block unnormalized outputs, [row][block][D].
///   sums     - per-block sums of exponentials, [row][block].
///   maxs     - per-block max scores, [row][block].
///   out      - final output, D elements per row.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  using AccT = float;

  thread AccT merged[elem_per_thread];
  threadgroup AccT tile[BN * BD];

  // Locate the row (query position major, head minor) this threadgroup owns
  const int head = tid.x;
  const int q_pos = tid.y;
  const int head_count = tpg.x;
  const int row = head_count * q_pos + head;

  partials += row * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += row * blocks;
  maxs += row * blocks;
  out += row * D + simd_gid * elem_per_thread;

  // Every lane picks up the statistics of "its" block and rescales them to
  // the maximum over all blocks.
  const AccT block_max = maxs[simd_lid];
  const AccT overall_max = simd_max(block_max);
  const AccT rescale = fast::exp(block_max - overall_max);
  const AccT denom = simd_sum(sums[simd_lid] * rescale);

  // Stage each partial element in threadgroup memory, transposed, then reduce
  // across blocks within the simdgroup and normalize.
  for (int e = 0; e < elem_per_thread; ++e) {
    tile[simd_lid * BD + simd_gid] = partials[e];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    merged[e] = simd_sum(tile[simd_gid * BD + simd_lid] * rescale) / denom;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Only the first lane of each simdgroup stores the result
  if (simd_lid != 0) {
    return;
  }
  for (int e = 0; e < elem_per_thread; ++e) {
    out[e] = static_cast<T>(merged[e]);
  }
}