// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

// Limits<T> (finite_min / finite_max) is provided by the MLX kernel utilities.
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];

template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant size_t& k_head_stride [[buffer(6)]],
    const constant size_t& k_seq_stride [[buffer(7)]],
    const constant size_t& v_head_stride [[buffer(8)]],
    const constant size_t& v_seq_stride [[buffer(9)]],
    const constant float& scale [[buffer(10)]],
    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
    const device T* fmask [[buffer(12), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
        [[buffer(13), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
        [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride
        [[buffer(15), function_constant(has_mask)]],
    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
    const constant int& num_q_heads
        [[buffer(17), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  out += o_offset * V + simd_gid * v_per_thread;

  // Read the query and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = Limits<U>::finite_min;
  U sum_exp_score = 0;
  if (has_sinks && simd_gid == 0) {
    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);
      if (float_mask) {
        score += static_cast<U>(fmask[0]);
      }

      // Update the accumulators (online softmax rescaling)
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (bool_mask) {
      bmask += BN * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * mask_kv_seq_stride;
    }
  }

  // Each simdgroup now holds a partial output, so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}

template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant size_t& k_head_stride [[buffer(8)]],
    const constant size_t& k_seq_stride [[buffer(9)]],
    const constant size_t& v_head_stride [[buffer(10)]],
    const constant size_t& v_seq_stride [[buffer(11)]],
    const constant float& scale [[buffer(12)]],
    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
    const device T* fmask [[buffer(14), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
        [[buffer(15), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
        [[buffer(16), function_constant(has_mask)]],
    const constant int& mask_head_stride
        [[buffer(17), function_constant(has_mask)]],
    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
    const constant int& num_q_heads
        [[buffer(19), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);
  constexpr int blocks = 32;

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride +
      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride +
      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = Limits<U>::finite_min;
  U sum_exp_score = 0;
  if (has_sinks && block_idx == 0 && simd_gid == 0) {
    int q_head_idx = q_batch_head_idx % num_q_heads;
    max_score = static_cast<U>(sinks[q_head_idx]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key (indexed by j to avoid shadowing the key index i)
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);
      if (float_mask) {
        score += static_cast<U>(fmask[0]);
      }

      // Update the accumulators (online softmax rescaling)
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (bool_mask) {
      bmask += BN * blocks * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * blocks * mask_kv_seq_stride;
    }
  }

  // Each simdgroup now holds a partial output, so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      out[i] = static_cast<float>(output);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}

template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int q_offset = head_idx * tpg.y + q_seq_idx;
  partials +=
      q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += q_offset * blocks;
  maxs += q_offset * blocks;
  out += q_offset * D + simd_gid * elem_per_thread;

  // First everybody reads the max and sum_exp
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Now read the block into registers and then use shared memory to
  // transpose it
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
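
// ---------------------------------------------------------------------------
// Illustrative instantiation (a sketch, not part of the original source):
// the kernels above are templates over the storage type T and the head
// dimensions D / V, so a concrete [[host_name]] variant has to be emitted for
// every configuration the host dispatches. The host names and the
// <half, 128, 128> configuration below are hypothetical examples.
template [[host_name("sdpa_vector_half_128")]] [[kernel]] decltype(
    sdpa_vector<half, 128, 128>) sdpa_vector<half, 128, 128>;
template [[host_name("sdpa_vector_2pass_1_half_128")]] [[kernel]] decltype(
    sdpa_vector_2pass_1<half, 128, 128>) sdpa_vector_2pass_1<half, 128, 128>;
template [[host_name("sdpa_vector_2pass_2_half_128")]] [[kernel]] decltype(
    sdpa_vector_2pass_2<half, 128>) sdpa_vector_2pass_2<half, 128>;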