// Source mirrored from https://github.com/ml-explore/mlx.git (synced 2025-06-24)
// Copyright © 2024 Apple Inc.
|
|
|
|
#include <metal_simdgroup>
|
|
|
|
using namespace metal;
|
|
|
|
constant bool has_mask [[function_constant(20)]];
|
|
constant bool query_transposed [[function_constant(21)]];
|
|
constant bool do_causal [[function_constant(22)]];
|
|
constant bool bool_mask [[function_constant(23)]];
|
|
constant bool float_mask [[function_constant(24)]];
|
|
|
|
// One-pass scaled-dot-product attention for short query sequences
// ("vector" / decoding attention, one query row per threadgroup).
//
// Launch layout: tid.x = query head, tid.y = query sequence position.
// The threadgroup is BN (=32) simdgroups of BD (=32) lanes. Simdgroup g
// handles keys g, g+BN, g+2*BN, ...; each lane holds D/BD query/key
// elements and V/BD value elements, so D and V must be divisible by 32.
//
// Template parameters:
//   T - element type of queries/keys/values and the output.
//   D - query/key head dimension.
//   V - value head dimension (defaults to D).
//
// Scores are accumulated with an online softmax (running max and running
// sum of exponentials), then the per-simdgroup partials are combined
// through threadgroup memory.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant size_t& k_head_stride [[buffer(6)]],
    const constant size_t& k_seq_stride [[buffer(7)]],
    const constant size_t& v_head_stride [[buffer(8)]],
    const constant size_t& v_seq_stride [[buffer(9)]],
    const constant float& scale [[buffer(10)]],
    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
    const device T* fmask [[buffer(12), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
    [[buffer(13), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
    [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride
    [[buffer(15), function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);

  // Accumulate in float regardless of T.
  typedef float U;

  thread U q[qk_per_thread];
  // Double buffer for keys: compute with k[a] while prefetching into k[b].
  thread U k[2][qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor; // GQA: several Q heads share a KV head
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }

  out += o_offset * V + simd_gid * v_per_thread;

  // Read the query (pre-multiplied by scale) and 0 the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // Read the first key
  short a = 0, b = 1;
  for (int j = 0; j < qk_per_thread; j++) {
    k[a][j] = keys[j];
  }
  keys += inner_k_stride;

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Prefetch the next key into the write buffer
    for (int j = 0; j < qk_per_thread; j++) {
      k[b][j] = keys[j];
    }

    bool use_key = true;
    if (do_causal) {
      // Causal: key i is visible when it does not lie past this query row.
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    }
    if (use_key) {
      // Compute the i-th score: dot(q, k) reduced across the simdgroup
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[a][j];
      }
      score = simd_sum(score);
      if (float_mask) {
        // Clamp so a -inf additive mask cannot poison the running max.
        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
      }

      // Update the online-softmax accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (bool_mask) {
      bmask += BN * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * mask_kv_seq_stride;
    }

    // Swap read and write buffers so the prefetched key is consumed next.
    // BUG FIX: the previous code did `b = a; a = tmp;` (with tmp == a),
    // which left `a` unchanged and discarded every prefetched key, so all
    // iterations scored against the key loaded before the loop.
    short tmp = a;
    a = b;
    b = tmp;
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp across simdgroups
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now aggregate the outputs: transpose through threadgroup memory so each
  // simdgroup can reduce one slice with simd_sum.
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output (lane 0 of each simdgroup owns a contiguous slice)
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
|
|
|
|
// First pass of two-pass SDPA for long key sequences.
//
// The key axis is split into `blocks` (=32) slabs; tid.z selects the slab.
// Each threadgroup (BN=8 simdgroups of BD=32 lanes) computes a partial
// online-softmax attention over its slab and writes, per (query, block):
//   out  - the UNNORMALIZED partial output (float, length V),
//   sums - the partial sum of exponentials,
//   maxs - the partial running max,
// which sdpa_vector_2pass_2 then reduces into the final result.
//
// Template parameters:
//   T - element type of queries/keys/values.
//   D - query/key head dimension (divisible by BD = 32).
//   V - value head dimension (defaults to D).
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant size_t& k_head_stride [[buffer(8)]],
    const constant size_t& k_seq_stride [[buffer(9)]],
    const constant size_t& v_head_stride [[buffer(10)]],
    const constant size_t& v_seq_stride [[buffer(11)]],
    const constant float& scale [[buffer(12)]],
    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
    const device T* fmask [[buffer(14), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
    [[buffer(15), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
    [[buffer(16), function_constant(has_mask)]],
    const constant int& mask_head_stride
    [[buffer(17), function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);
  constexpr int blocks = 32;

  // Accumulate in float regardless of T.
  typedef float U;

  thread U q[qk_per_thread];
  // Double buffer for keys: compute with k[a] while prefetching into k[b].
  thread U k[2][qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  const int kv_head_idx = head_idx / gqa_factor; // GQA head mapping

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride +
      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride +
      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query (pre-multiplied by scale) and 0 the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  // Finite sentinel (not -INFINITY) so empty blocks still write usable maxs.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // Read the first key
  short a = 0, b = 1;
  for (int j = 0; j < qk_per_thread; j++) {
    k[a][j] = keys[j];
  }
  keys += inner_k_stride;

  // For each key in this block's strided slab
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Prefetch the next key into the write buffer
    for (int j = 0; j < qk_per_thread; j++) {
      k[b][j] = keys[j];
    }

    bool use_key = true;
    if (do_causal) {
      // Causal: key i is visible when it does not lie past this query row.
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    }
    if (use_key) {
      // Compute the i-th score: dot(q, k) reduced across the simdgroup
      U score = 0;
      for (int i = 0; i < qk_per_thread; i++) {
        score += q[i] * k[a][i];
      }
      score = simd_sum(score);
      if (float_mask) {
        // NOTE(review): unlike sdpa_vector, this path does not clamp the
        // additive mask with Limits<U>::finite_min — a -inf mask entry will
        // propagate into max_score here; confirm whether that is intended.
        score += fmask[0];
      }

      // Update the online-softmax accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int i = 0; i < v_per_thread; i++) {
        o[i] = o[i] * factor + exp_score * values[i];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (bool_mask) {
      bmask += BN * blocks * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * blocks * mask_kv_seq_stride;
    }

    // Swap read and write buffers so the prefetched key is consumed next.
    // BUG FIX: the previous code did `b = a; a = tmp;` (with tmp == a),
    // which left `a` unchanged and discarded every prefetched key, so all
    // iterations scored against the key loaded before the loop.
    short tmp = a;
    a = b;
    b = tmp;
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp across simdgroups
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Only BN (=8) simdgroups exist, so lanes >= BN contribute neutral values.
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max for the second pass
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now aggregate the (still unnormalized) outputs across simdgroups
  for (int i = 0; i < v_per_thread; i++) {
    // Rescale each simdgroup's partial to the shared max before summing.
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the partial output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      out[i] = static_cast<T>(output);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
|
|
|
|
// Second pass of two-pass SDPA: combine the `blocks` (=32) partial results
// produced by sdpa_vector_2pass_1 into the final normalized output.
//
// Launch layout: tid.x = head, tid.y = query sequence position; one
// threadgroup of BN (=32) simdgroups by BD (=32) lanes per (head, query).
// Lane `simd_lid` owns block `simd_lid`'s softmax statistics; simdgroup
// `simd_gid` reads block `simd_gid`'s partial output slice.
//
// Template parameters:
//   T - output element type.
//   D - head dimension (must be divisible by BD = 32).
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32; // must match `blocks` in sdpa_vector_2pass_1

  // Accumulate in float.
  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int n_heads = tpg.x;
  const int q_offset = n_heads * q_seq_idx + head_idx;
  // Simdgroup simd_gid reads the partial output of block simd_gid.
  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += q_offset * blocks;
  maxs += q_offset * blocks;
  out += q_offset * D + simd_gid * elem_per_thread;

  // First everybody reads the max and sum_exp: lane simd_lid holds block
  // simd_lid's stats, which are combined into global values via simd ops.
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  // Rescale factor from each block's local max to the global max.
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Now read the block into registers and then use shared memory to transpose
  // it
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    // Transpose through threadgroup memory so that, after the barrier, each
    // simdgroup holds one output element across all blocks; lane simd_lid's
    // `factor` then rescales block simd_lid's contribution before the
    // simd_sum reduction and final normalization by sum_exp_score.
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output (lane 0 of each simdgroup owns a contiguous slice)
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
|