mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 12:13:21 +08:00
revert sdpa
This commit is contained in:
parent
047a584e3d
commit
6649244686
@ -16,11 +16,9 @@ template <typename T, int D>
|
|||||||
const constant float& scale,
|
const constant float& scale,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]],
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
|
||||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
|
||||||
constexpr int BN = 32;
|
constexpr int BN = 32;
|
||||||
constexpr int BD = 4;
|
constexpr int BD = 32;
|
||||||
constexpr int elem_per_thread = D / BD;
|
constexpr int elem_per_thread = D / BD;
|
||||||
|
|
||||||
const int stride = BN * D;
|
const int stride = BN * D;
|
||||||
@ -38,9 +36,9 @@ template <typename T, int D>
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int head_idx = tid.y;
|
const int head_idx = tid.y;
|
||||||
const int kv_head_idx = head_idx / gqa_factor;
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
queries += head_idx * D + quad_lid * elem_per_thread;
|
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||||
keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
|
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||||
values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
|
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||||
out += head_idx * D + simd_gid * elem_per_thread;
|
out += head_idx * D + simd_gid * elem_per_thread;
|
||||||
|
|
||||||
// Read the query and 0 the output accumulator
|
// Read the query and 0 the output accumulator
|
||||||
@ -55,7 +53,7 @@ template <typename T, int D>
|
|||||||
U sum_exp_score = 0;
|
U sum_exp_score = 0;
|
||||||
|
|
||||||
// For each key
|
// For each key
|
||||||
for (int i = quad_gid; i < N; i += BN) {
|
for (int i = simd_gid; i < N; i += BN) {
|
||||||
// Read the key
|
// Read the key
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
k[i] = keys[i];
|
k[i] = keys[i];
|
||||||
@ -66,7 +64,7 @@ template <typename T, int D>
|
|||||||
for (int i = 0; i < elem_per_thread; i++) {
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
score += q[i] * k[i];
|
score += q[i] * k[i];
|
||||||
}
|
}
|
||||||
score = quad_sum(score);
|
score = simd_sum(score);
|
||||||
|
|
||||||
// Update the accumulators
|
// Update the accumulators
|
||||||
U new_max = max(max_score, score);
|
U new_max = max(max_score, score);
|
||||||
@ -90,10 +88,9 @@ template <typename T, int D>
|
|||||||
// Each thread has a partial part of the output so we need to combine them.
|
// Each thread has a partial part of the output so we need to combine them.
|
||||||
|
|
||||||
// First let's communicate the max and sum_exp
|
// First let's communicate the max and sum_exp
|
||||||
// Each quadgroup communicates it's max score
|
if (simd_lid == 0) {
|
||||||
if (quad_lid == 0) {
|
max_scores[simd_gid] = max_score;
|
||||||
max_scores[quad_gid] = max_score;
|
sum_exp_scores[simd_gid] = sum_exp_score;
|
||||||
sum_exp_scores[quad_gid] = sum_exp_score;
|
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
max_score = max_scores[simd_lid];
|
max_score = max_scores[simd_lid];
|
||||||
@ -103,10 +100,9 @@ template <typename T, int D>
|
|||||||
|
|
||||||
// Now we need to aggregate all the outputs
|
// Now we need to aggregate all the outputs
|
||||||
for (int i = 0; i < elem_per_thread; i++) {
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
// 128 threads with 32 values per thread
|
outputs[simd_lid * BD + simd_gid] = o[i];
|
||||||
outputs[simd_gid * BN + simd_lid] = o[i];
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
|
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ void sdpa_vector(
|
|||||||
int N = k.shape(2);
|
int N = k.shape(2);
|
||||||
int B = q.shape(0) * q.shape(1);
|
int B = q.shape(0) * q.shape(1);
|
||||||
size_t stride = k.strides()[1];
|
size_t stride = k.strides()[1];
|
||||||
MTL::Size group_dims(128, 1, 1);
|
MTL::Size group_dims(1024, 1, 1);
|
||||||
MTL::Size grid_dims(1, B, 1);
|
MTL::Size grid_dims(1, B, 1);
|
||||||
|
|
||||||
// Get the kernel
|
// Get the kernel
|
||||||
|
Loading…
Reference in New Issue
Block a user