From 073076ac7dce7cc36d4753d8cbb7bdbdf17a0f68 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 18 Nov 2024 17:31:53 -0800
Subject: [PATCH] 2-Pass Sdpa Inference Kernel (#1597)

---
 benchmarks/python/sdpa_vector_bench.py       |  45 +++--
 .../scaled_dot_product_attention.metal       |  19 +-
 mlx/backend/metal/kernels/sdpa_vector.h      | 182 +++++++++++++++++-
 .../metal/scaled_dot_product_attention.cpp   | 101 +++++++++-
 4 files changed, 310 insertions(+), 37 deletions(-)

diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index 058f495dc..291f87203 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -4,42 +4,51 @@ import math
 import mlx.core as mx
 from time_utils import time_fn

-L = 1024
+L = 16384
 H = 32
-H_k = 32 // 4
+H_k = H // 4
 D = 128
+dtype = mx.float16
+loops = 10


 def attention(q, k, v):
-    B, Hq, L, D = q.shape
-    _, Hk, S, _ = k.shape
-    q = q.reshape(B, Hk, Hq // Hk, L, D)
-    k = k[:, :, None, :, :]
-    v = v[:, :, None, :, :]
-    s = q @ k.transpose(0, 1, 2, 4, 3)
-    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-    o = p @ v
-    return o.reshape(B, Hq, L, D)
+    def _sdpa(q, k, v):
+        B, Hq, L, D = q.shape
+        _, Hk, S, _ = k.shape
+        q = q.reshape(B, Hk, Hq // Hk, L, D)
+        k = k[:, :, None, :, :]
+        v = v[:, :, None, :, :]
+        s = q @ k.transpose(0, 1, 2, 4, 3)
+        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
+        o = p @ v
+        return o.reshape(B, Hq, L, D)
+
+    for i in range(loops):
+        q = _sdpa(q, k, v)
+    return q


 def sdpa(q, k, v):
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    for i in range(loops):
+        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
+    return q


 def time_self_attention_primitives():
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
+    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
     mx.eval(q, k, v)
     time_fn(attention, q, k, v)


 def time_self_attention_sdpa():
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
+    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
+    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
     mx.eval(q, k, v)
     time_fn(sdpa, q, k, v)

diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index 478bf2207..014231285 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -926,21 +926,10 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
 instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);

 // SDPA vector instantiations
-#define instantiate_sdpa_vector(type, head_dim) \
-  template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
-  [[kernel]] void sdpa_vector<type, head_dim>( \
-      const device type* queries [[buffer(0)]], \
-      const device type* keys [[buffer(1)]], \
-      const device type* values [[buffer(2)]], \
-      device type* out [[buffer(3)]], \
-      const constant int& gqa_factor, \
-      const constant int& N, \
-      const constant size_t& k_stride, \
-      const constant size_t& v_stride, \
-      const constant float& scale, \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
+#define instantiate_sdpa_vector(type, head_dim) \
+  instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
+  instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
+  instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)

 #define instantiate_sdpa_vector_heads(type) \
   instantiate_sdpa_vector(type, 64) \

diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index 5ef316811..8b6af638e 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -21,8 +21,7 @@ template <typename T, int D>
   constexpr int BN = 32;
   constexpr int BD = 32;
   constexpr int elem_per_thread = D / BD;
-
-  const int stride = BN * D;
+  constexpr int stride = BN * D;

   typedef float U;

@@ -84,7 +83,6 @@ template <typename T, int D>
     keys += stride;
     values += stride;
   }
-  threadgroup_barrier(mem_flags::mem_threadgroup);

   // Each thread has a partial part of the output so we need to combine them.

@@ -114,3 +112,181 @@ template <typename T, int D>
     }
   }
 }
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_1(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device float* out [[buffer(3)]],
+    device float* sums [[buffer(4)]],
+    device float* maxs [[buffer(5)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant size_t& v_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 8;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int stride = BN * D;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int block_idx = tid.z;
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + simd_lid * elem_per_thread;
+  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
+      simd_lid * elem_per_thread;
+  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+  sums += head_idx * blocks + block_idx;
+  maxs += head_idx * blocks + block_idx;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < elem_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -1e9;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
+    // Read the key
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = keys[i];
+    }
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = simd_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * values[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += blocks * stride;
+    values += blocks * stride;
+  }
+
+  // Each thread has a partial part of the output so we need to combine them.
+
+  // First let's communicate the max and sum_exp
+  if (simd_lid == 0) {
+    max_scores[simd_gid] = max_score;
+    sum_exp_scores[simd_gid] = sum_exp_score;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
+  sum_exp_score = simd_sum(sum_exp_score * factor);
+
+  // Write the sum and new max
+  if (simd_gid == 0) {
+    sums[0] = sum_exp_score;
+    maxs[0] = new_max;
+  }
+
+  // Now we need to aggregate all the outputs
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BN + simd_gid] =
+        o[i] * fast::exp(max_scores[simd_gid] - new_max);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // And write the output
+    if (simd_gid == 0) {
+      U output = outputs[simd_lid * BN];
+      for (int j = 1; j < BN; j++) {
+        output += outputs[simd_lid * BN + j];
+      }
+      out[i] = static_cast<U>(output);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector_2pass_2(
+    const device float* partials [[buffer(0)]],
+    const device float* sums [[buffer(1)]],
+    const device float* maxs [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+  constexpr int blocks = 32;
+
+  typedef float U;
+
+  thread U o[elem_per_thread];
+  threadgroup U outputs[BN * BD];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
+  sums += head_idx * blocks;
+  maxs += head_idx * blocks;
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // First everybody reads the max and sum_exp
+  U max_score = maxs[simd_lid];
+  U new_max = simd_max(max_score);
+  U factor = fast::exp(max_score - new_max);
+  U sum_exp_score = simd_sum(sums[simd_lid] * factor);
+
+  // Now read the block into registers and then use shared memory to transpose
+  // it
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = partials[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    outputs[simd_lid * BD + simd_gid] = o[i];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // And write the output
+  if (simd_lid == 0) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      out[i] = static_cast<T>(o[i]);
+    }
+  }
+}

diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index acc20b323..3edda16a3 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -8,6 +8,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core::fast {

@@ -184,6 +185,94 @@ void sdpa_vector(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }

+void sdpa_vector_2pass(
+    const Stream& s,
+    metal::Device& d,
+    const array& q,
+    const array& k,
+    const array& v,
+    array& out,
+    float scale) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(64);
+  kname += "sdpa_vector_2pass_1_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int blocks = 32;
+  int B = q.shape(0) * q.shape(1);
+  size_t k_stride = k.strides()[1];
+  size_t v_stride = v.strides()[1];
+  MTL::Size group_dims(8 * 32, 1, 1);
+  MTL::Size grid_dims(1, B, blocks);
+
+  // Allocate the intermediates
+  std::vector<int> intermediate_shape;
+  intermediate_shape.reserve(out.ndim() + 1);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
+  intermediate_shape.push_back(blocks);
+  intermediate_shape.push_back(out.shape().back());
+  array intermediate(intermediate_shape, float32, nullptr, {});
+  intermediate_shape.pop_back();
+  array sums(intermediate_shape, float32, nullptr, {});
+  array maxs(std::move(intermediate_shape), float32, nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
+  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  d.add_temporary(intermediate, s.index);
+  d.add_temporary(sums, s.index);
+  d.add_temporary(maxs, s.index);
+
+  // Get the kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(v, 2);
+  compute_encoder.set_output_array(intermediate, 3);
+  compute_encoder.set_output_array(sums, 4);
+  compute_encoder.set_output_array(maxs, 5);
+  compute_encoder.set_bytes(gqa_factor, 6);
+  compute_encoder.set_bytes(N, 7);
+  compute_encoder.set_bytes(k_stride, 8);
+  compute_encoder.set_bytes(v_stride, 9);
+  compute_encoder.set_bytes(scale, 10);
+
+  // Launch
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+
+  // Final pass
+  kname.clear();
+  kname += "sdpa_vector_2pass_2_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));
+
+  // Get the kernel
+  kernel = d.get_kernel(kname);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  // Set its arguments
+  compute_encoder.set_input_array(intermediate, 0);
+  compute_encoder.set_input_array(sums, 1);
+  compute_encoder.set_input_array(maxs, 2);
+  compute_encoder.set_output_array(out, 3);
+
+  // Launch
+  group_dims = MTL::Size(1024, 1, 1);
+  grid_dims = MTL::Size(1, B, 1);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
 } // namespace

 void ScaledDotProductAttention::eval_gpu(
@@ -249,7 +338,17 @@ void ScaledDotProductAttention::eval_gpu(
     } else {
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
     }
-    sdpa_vector(s, d, q, k, v, o, scale_);
+
+    // We route to the 2 pass fused attention if
+    // - The device is large and the sequence length long
+    // - The sequence length is even longer and we have gqa
+    char devc = d.get_architecture().back();
+    if ((devc == 'd' && k.shape(2) >= 1024) ||
+        (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
+      sdpa_vector_2pass(s, d, q, k, v, o, scale_);
+    } else {
+      sdpa_vector(s, d, q, k, v, o, scale_);
+    }
   }

   // Full attention mode
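
Note on the reduction the new kernels implement: the key/value sequence is split into blocks = 32 chunks; sdpa_vector_2pass_1 produces, per block, an unnormalized partial output together with that block's max score and sum of exponentials, and sdpa_vector_2pass_2 rescales every block to the global max and normalizes once. Below is a minimal NumPy sketch of that decomposition, assuming one head, one query position, and a contiguous block partition (the kernel actually strides blocks across the sequence, which does not change the math); all names and shapes here are illustrative and not part of the patch.

import numpy as np

def sdpa_two_pass_reference(q, k, v, scale, blocks=32):
    # q: (D,), k/v: (N, D) for a single head and a single query position.
    # N is assumed divisible by `blocks` in this simplified sketch.
    n, d = k.shape
    k_blocks = k.reshape(blocks, n // blocks, d)
    v_blocks = v.reshape(blocks, n // blocks, d)

    # Pass 1: per block, an unnormalized partial output plus the block's
    # max score and sum of exponentials (the roles of out/sums/maxs above).
    partials, sums, maxs = [], [], []
    for kb, vb in zip(k_blocks, v_blocks):
        s = scale * (kb @ q)
        m = s.max()
        p = np.exp(s - m)
        partials.append(p @ vb)
        sums.append(p.sum())
        maxs.append(m)
    partials = np.stack(partials)
    sums = np.array(sums)
    maxs = np.array(maxs)

    # Pass 2: rescale every block to the global max and normalize once.
    global_max = maxs.max()
    alpha = np.exp(maxs - global_max)
    denom = (sums * alpha).sum()
    return (partials * alpha[:, None]).sum(axis=0) / denom

# Quick consistency check against single-pass softmax attention.
rng = np.random.default_rng(0)
d, n = 128, 4096
q = rng.normal(size=(d,))
k = rng.normal(size=(n, d))
v = rng.normal(size=(n, d))
scale = 1.0 / np.sqrt(d)
s = scale * (k @ q)
w = np.exp(s - s.max())
ref = (w / w.sum()) @ v
assert np.allclose(sdpa_two_pass_reference(q, k, v, scale), ref)

Because each block is only ever rescaled by the exponential of a max difference, the combined result matches single-pass softmax attention up to floating-point rounding, which is what lets the second kernel run as a cheap reduction over 32 partials per head.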
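For shapes, a hedged usage sketch of inputs that would take the new 2-pass path under the routing added to ScaledDotProductAttention::eval_gpu: a GQA decode step (single query position) with a long key/value sequence. The thresholds below come from the diff and the shapes mirror the updated benchmark; everything else is illustrative.

import mlx.core as mx

B, H, H_k, D, S = 1, 32, 8, 128, 16384  # GQA decode shapes, as in the benchmark
q = mx.random.uniform(shape=(B, H, 1, D)).astype(mx.float16)
k = mx.random.uniform(shape=(B, H_k, S, D)).astype(mx.float16)
v = mx.random.uniform(shape=(B, H_k, S, D)).astype(mx.float16)

# On the Metal backend, H_k < H (GQA) together with S >= 4096 satisfies the
# second routing clause, so this call should use sdpa_vector_2pass; on large
# devices the first clause already routes at S >= 1024. Shorter sequences
# keep using the existing single-pass sdpa_vector kernel.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)
mx.eval(out)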