diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 38b75413d..2a1cae6db 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -16,11 +16,9 @@ template const constant float& scale, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { + uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int BN = 32; - constexpr int BD = 4; + constexpr int BD = 32; constexpr int elem_per_thread = D / BD; const int stride = BN * D; @@ -38,9 +36,9 @@ template // Adjust positions const int head_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + quad_lid * elem_per_thread; - keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; - values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; + queries += head_idx * D + simd_lid * elem_per_thread; + keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -55,7 +53,7 @@ template U sum_exp_score = 0; // For each key - for (int i = quad_gid; i < N; i += BN) { + for (int i = simd_gid; i < N; i += BN) { // Read the key for (int i = 0; i < elem_per_thread; i++) { k[i] = keys[i]; @@ -66,7 +64,7 @@ template for (int i = 0; i < elem_per_thread; i++) { score += q[i] * k[i]; } - score = quad_sum(score); + score = simd_sum(score); // Update the accumulators U new_max = max(max_score, score); @@ -90,10 +88,9 @@ template // Each thread has a partial part of the output so we need to combine them. // First let's communicate the max and sum_exp - // Each quadgroup communicates it's max score - if (quad_lid == 0) { - max_scores[quad_gid] = max_score; - sum_exp_scores[quad_gid] = sum_exp_score; + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; } threadgroup_barrier(mem_flags::mem_threadgroup); max_score = max_scores[simd_lid]; @@ -103,10 +100,9 @@ template // Now we need to aggregate all the outputs for (int i = 0; i < elem_per_thread; i++) { - // 128 threads with 32 values per thread - outputs[simd_gid * BN + simd_lid] = o[i]; + outputs[simd_lid * BD + simd_gid] = o[i]; threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; threadgroup_barrier(mem_flags::mem_threadgroup); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index c5251556f..6643c380d 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -165,7 +165,7 @@ void sdpa_vector( int N = k.shape(2); int B = q.shape(0) * q.shape(1); size_t stride = k.strides()[1]; - MTL::Size group_dims(128, 1, 1); + MTL::Size group_dims(1024, 1, 1); MTL::Size grid_dims(1, B, 1); // Get the kernel