revert sdpa

This commit is contained in:
Alex Barron 2024-10-22 20:10:36 -07:00
parent 047a584e3d
commit 6649244686
2 changed files with 13 additions and 17 deletions

View File

@ -16,11 +16,9 @@ template <typename T, int D>
const constant float& scale, const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]], uint simd_lid [[thread_index_in_simdgroup]]) {
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 32; constexpr int BN = 32;
constexpr int BD = 4; constexpr int BD = 32;
constexpr int elem_per_thread = D / BD; constexpr int elem_per_thread = D / BD;
const int stride = BN * D; const int stride = BN * D;
@ -38,9 +36,9 @@ template <typename T, int D>
// Adjust positions // Adjust positions
const int head_idx = tid.y; const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor; const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + quad_lid * elem_per_thread; queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
@ -55,7 +53,7 @@ template <typename T, int D>
U sum_exp_score = 0; U sum_exp_score = 0;
// For each key // For each key
for (int i = quad_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
// Read the key // Read the key
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i]; k[i] = keys[i];
@ -66,7 +64,7 @@ template <typename T, int D>
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = quad_sum(score); score = simd_sum(score);
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
@ -90,10 +88,9 @@ template <typename T, int D>
// Each thread has a partial part of the output so we need to combine them. // Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp // First let's communicate the max and sum_exp
// Each quadgroup communicates it's max score if (simd_lid == 0) {
if (quad_lid == 0) { max_scores[simd_gid] = max_score;
max_scores[quad_gid] = max_score; sum_exp_scores[simd_gid] = sum_exp_score;
sum_exp_scores[quad_gid] = sum_exp_score;
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid]; max_score = max_scores[simd_lid];
@ -103,10 +100,9 @@ template <typename T, int D>
// Now we need to aggregate all the outputs // Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
// 128 threads with 32 values per thread outputs[simd_lid * BD + simd_gid] = o[i];
outputs[simd_gid * BN + simd_lid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }

View File

@ -165,7 +165,7 @@ void sdpa_vector(
int N = k.shape(2); int N = k.shape(2);
int B = q.shape(0) * q.shape(1); int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1]; size_t stride = k.strides()[1];
MTL::Size group_dims(128, 1, 1); MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1); MTL::Size grid_dims(1, B, 1);
// Get the kernel // Get the kernel