diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py new file mode 100644 index 000000000..058f495dc --- /dev/null +++ b/benchmarks/python/sdpa_vector_bench.py @@ -0,0 +1,49 @@ +import argparse +import math + +import mlx.core as mx +from time_utils import time_fn + +L = 1024 +H = 32 +H_k = 32 // 4 +D = 128 + + +def attention(q, k, v): + B, Hq, L, D = q.shape + _, Hk, S, _ = k.shape + q = q.reshape(B, Hk, Hq // Hk, L, D) + k = k[:, :, None, :, :] + v = v[:, :, None, :, :] + s = q @ k.transpose(0, 1, 2, 4, 3) + p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) + o = p @ v + return o.reshape(B, Hq, L, D) + + +def sdpa(q, k, v): + return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + + +def time_self_attention_primitives(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)) + k = mx.random.uniform(shape=(1, H_k, L, D)) + v = mx.random.uniform(shape=(1, H_k, L, D)) + mx.eval(q, k, v) + time_fn(attention, q, k, v) + + +def time_self_attention_sdpa(): + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 1, D)) + k = mx.random.uniform(shape=(1, H_k, L, D)) + v = mx.random.uniform(shape=(1, H_k, L, D)) + mx.eval(q, k, v) + time_fn(sdpa, q, k, v) + + +if __name__ == "__main__": + time_self_attention_sdpa() + time_self_attention_primitives() diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4d637a154..28f34535c 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -30,8 +30,9 @@ build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) build_kernel(rope) -build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h - steel/defines.h steel/gemm/transforms.h steel/utils.h) +build_kernel( + scaled_dot_product_attention scaled_dot_product_attention_params.h + sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h) set(STEEL_HEADERS steel/defines.h diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 8d3506c34..27fbb765b 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -1,11 +1,11 @@ -#include #include +#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" +#include "mlx/backend/metal/kernels/sdpa_vector.h" #include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" -#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/utils.h" -#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" using namespace metal; using namespace mlx::steel; @@ -886,6 +886,9 @@ template < } } +// clang-format off + +// SDPA full instantiations #define instantiate_fast_inference_self_attention_kernel( \ itype, otype, bm, bn, bk, wm, wn) \ template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ @@ -922,548 +925,28 @@ instantiate_fast_inference_self_attention_kernel( instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); -template < - typename T, - typename T2, - typename T4, - uint16_t TILE_SIZE_CONST, - uint16_t NSIMDGROUPS> -[[kernel]] void fast_inference_sdpa_compute_partials_template( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V 
[[buffer(2)]], - const device uint64_t& L [[buffer(3)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], - device float* O_partials [[buffer(5)]], - device float* p_lse [[buffer(6)]], - device float* p_maxes [[buffer(7)]], - threadgroup T* threadgroup_block [[threadgroup(0)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - constexpr const size_t DK = 128; - constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; - constexpr const size_t THREADS_PER_SIMDGROUP = 32; - constexpr const uint iter_offset = NSIMDGROUPS * 4; - const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; - uint kv_head_offset_factor = tid.x; - if (is_gqa) { - int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; - kv_head_offset_factor = tid.x / q_kv_head_ratio; - } - constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4; - constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = - TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); - constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; - constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * - SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * - NSIMDGROUPS; +// SDPA vector instantiations +#define instantiate_sdpa_vector(type, head_dim) \ + template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \ + [[kernel]] void sdpa_vector( \ + const device type* queries [[buffer(0)]], \ + const device type* keys [[buffer(1)]], \ + const device type* values [[buffer(2)]], \ + device type* out [[buffer(3)]], \ + const constant int& gqa_factor, \ + const constant int& N, \ + const constant size_t& k_stride, \ + const constant float& scale, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); - threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; -#pragma clang loop unroll(full) - for (uint i = 0; i < 8; i++) { - smemFlush - [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + - i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - // TODO: multiple query sequence length for speculative decoding - const uint tgroup_query_head_offset = - tid.x * DK + tid.z * (params.N_Q_HEADS * DK); +#define instantiate_sdpa_vector_heads(type) \ + instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 96) \ + instantiate_sdpa_vector(type, 128) - const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; - const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; - - const device T* baseK = - K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; - const device T* baseQ = Q + tgroup_query_head_offset; - - device T4* simdgroupQueryData = (device T4*)baseQ; - - constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; - float threadAccum[ACCUM_PER_GROUP]; - -#pragma clang loop unroll(full) - for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; - threadAccumIndex++) { - threadAccum[threadAccumIndex] = -INFINITY; - } - - uint KROW_ACCUM_INDEX = 0; - - const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; - const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; - const bool LAST_TILE_ALIGNED = - (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); - - T4 
thread_data_x4; - T4 thread_data_y4; - if (!LAST_TILE || LAST_TILE_ALIGNED) { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); -#pragma clang loop unroll(full) - for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; - KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseK + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; - } - } else { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); - const uint START_ROW = tid.y * TILE_SIZE_CONST; - const device T* baseKThisHead = - K + tgroup_k_batch_offset + tgroup_k_head_offset; - - for (size_t KROW = START_ROW + simd_group_id; KROW < L; - KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseKThisHead + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; - } - } - threadgroup float* smemP = (threadgroup float*)threadgroup_block; - -#pragma clang loop unroll(full) - for (size_t i = 0; i < P_VEC4; i++) { - thread_data_x4 = - T4(threadAccum[4 * i], - threadAccum[4 * i + 1], - threadAccum[4 * i + 2], - threadAccum[4 * i + 3]); - simdgroup_barrier(mem_flags::mem_none); - thread_data_y4 = simd_sum(thread_data_x4); - if (simd_lane_id == 0) { - const uint base_smem_p_offset = i * iter_offset + simd_group_id; - smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); - smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y); - smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); - smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - float groupMax; - float lse = 0.f; - - constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; - constexpr const size_t ACCUM_ARRAY_LENGTH = - TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; - float4 pvals[ACCUM_ARRAY_LENGTH]; - -#pragma clang loop unroll(full) - for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; - accum_array_iter++) { - pvals[accum_array_iter] = float4(-INFINITY); - } - - if (TILE_SIZE_CONST == 64) { - threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; - float2 vals = smemPtrFlt2[simd_lane_id]; - vals *= params.INV_ALPHA; - float maxval = max(vals.x, vals.y); - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); - - float2 expf_shifted = exp(vals - groupMax); - float sumExpLocal = expf_shifted.x + expf_shifted.y; - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - - lse = log(tgroupExpSum); - float2 local_p_hat = expf_shifted / tgroupExpSum; - pvals[0].x = local_p_hat.x; - pvals[0].y = local_p_hat.y; - smemPtrFlt2[simd_lane_id] = float2(0.f); - } - constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; - constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; - - if (TILE_SIZE_LARGER_THAN_64) { - float maxval = -INFINITY; - threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; - vals *= params.INV_ALPHA; - pvals[i] = 
vals; - maxval = fmax3(vals.x, vals.y, maxval); - maxval = fmax3(vals.z, vals.w, maxval); - } - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); - - float sumExpLocal = 0.f; -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = exp(pvals[i] - groupMax); - sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; - } - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - lse = log(tgroupExpSum); -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = pvals[i] / tgroupExpSum; - smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); - } - } - - threadgroup T* smemV = (threadgroup T*)threadgroup_block; - - const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; - const size_t v_head_offset = kv_head_offset_factor * L * DK; - - const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; - device T* baseV = (device T*)V + v_offset; - - threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV); - - if (!LAST_TILE || LAST_TILE_ALIGNED) { -#pragma clang loop unroll(full) - for (size_t col = 0; col < MATRIX_COLS; col++) { - uint matrix_load_loop_iter = 0; - constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; - - for (size_t tile_start = simd_group_id; - tile_start < TILE_SIZE_CONST_DIV_8; - tile_start += NSIMDGROUPS) { - simdgroup_matrix tmp; - ulong simdgroup_matrix_offset = - matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + - simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - ulong2 matrixOrigin = - ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); - simdgroup_load(tmp, baseV, DK, matrixOrigin, true); - const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); - const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); - matrix_load_loop_iter++; - }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - uint loop_iter = 0; - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - -#pragma clang loop unroll(full) - for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; - row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - - T row_sum = simd_sum(val); - oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = - float(row_sum); - loop_iter++; - } - } - - if (TILE_SIZE_CONST > 64) { - constexpr const size_t TILE_SIZE_CONST_DIV_128 = - (TILE_SIZE_CONST + 1) / 128; - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_iter = 0; - for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; - row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - - T row_sum = 0.f; - for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[i]); - T val = dot(p_local, v_local); - row_sum += val; - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - 
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = - float(row_sum); - loop_iter++; - } - } - } - } else { - const int32_t START_ROW = tid.y * TILE_SIZE_CONST; - const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; - const device T* baseVThisHead = V + v_batch_offset + v_head_offset; - constexpr const int ROWS_PER_ITER = 8; -#pragma clang loop unroll(full) - for (size_t col = 0; col < MATRIX_COLS; col++) { - uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - int32_t tile_start; - for (tile_start = - START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - tile_start < MAX_START_ROW; - tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { - simdgroup_matrix tmp; - ulong2 matrixOrigin = - ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); - simdgroup_load( - tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); - const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); - constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store( - tmp, - smemV, - elemsPerRowSmem, - matrixOriginSmem, - /* transpose */ false); - smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; - }; - - tile_start = - ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); - - const int32_t INT_L = int32_t(L); - for (int row_index = tile_start + simd_group_id; row_index < INT_L; - row_index += NSIMDGROUPS) { - if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { - const uint elems_per_row_gmem = DK; - const uint col_index_v_gmem = - col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; - const uint row_index_v_gmem = row_index; - - const uint elems_per_row_smem = TILE_SIZE_CONST; - const uint col_index_v_smem = row_index % TILE_SIZE_CONST; - const uint row_index_v_smem = simd_lane_id; - - const uint scalar_offset_gmem = - row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; - const uint scalar_offset_smem = - row_index_v_smem * elems_per_row_smem + col_index_v_smem; - T vdata = T(*(baseVThisHead + scalar_offset_gmem)); - smemV[scalar_offset_smem] = vdata; - smem_col_index += NSIMDGROUPS; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - for (size_t smem_row_index = simd_group_id; - smem_row_index < ROWS_PER_ITER; - smem_row_index += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - T row_sum = simd_sum(val); - oPartialSmem[smem_row_index] = float(row_sum); - } - } - - if (TILE_SIZE_CONST > 64) { - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_count = 0; - for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER; - row_index += NSIMDGROUPS) { - T row_sum = 0.f; - for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; - tile_iters++) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = - *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[tile_iters]); - row_sum += dot(p_local, v_local); - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = - float(row_sum); - 
loop_count++; - } - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (simd_group_id == 0) { - threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; - float4 vals = *(oPartialVec4 + simd_lane_id); - device float* oPartialGmem = - O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; - device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; - oPartialGmemVec4[simd_lane_id] = vals; - } - - if (simd_group_id == 0 && simd_lane_id == 0) { - const uint tileIndex = tid.y; - const uint gmem_partial_scalar_offset = - tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + - tileIndex; - p_lse[gmem_partial_scalar_offset] = lse; - p_maxes[gmem_partial_scalar_offset] = groupMax; - } -} - -#define instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, nsimdgroups) \ - template [[host_name("fast_inference_sdpa_compute_partials_" #itype \ - "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \ - fast_inference_sdpa_compute_partials_template< \ - itype, \ - itype2, \ - itype4, \ - tile_size, \ - nsimdgroups>( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - const device uint64_t& L [[buffer(3)]], \ - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ - device float* O_partials [[buffer(5)]], \ - device float* p_lse [[buffer(6)]], \ - device float* p_maxes [[buffer(7)]], \ - threadgroup itype* threadgroup_block [[threadgroup(0)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]]); - -// clang-format off -#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \ - itype, itype2, itype4, tile_size) \ - instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, 4) \ - instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, 8) // clang-format on - -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 512); - -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 512); - -template -void fast_inference_sdpa_reduce_tiles_template( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device T* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - constexpr const int DK = 128; - const ulong offset_rows = - tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; - const device float* p_lse_row = p_lse + offset_rows; - const device float* p_rowmax_row = p_maxes + offset_rows; - // reserve some number of registers. this constitutes an assumption on max - // value of KV TILES. 
- constexpr const uint8_t reserve = 128; - float p_lse_regs[reserve]; - float p_rowmax_regs[reserve]; - float weights[reserve]; - - float true_max = -INFINITY; - for (size_t i = 0; i < params.KV_TILES; i++) { - p_lse_regs[i] = float(*(p_lse_row + i)); - p_rowmax_regs[i] = float(*(p_rowmax_row + i)); - true_max = fmax(p_rowmax_regs[i], true_max); - weights[i] = exp(p_lse_regs[i]); - } - - float denom = 0.f; - for (size_t i = 0; i < params.KV_TILES; i++) { - weights[i] *= exp(p_rowmax_regs[i] - true_max); - denom += weights[i]; - } - - const device float* O_partials_with_offset = O_partials + - tid.z * params.N_Q_HEADS * DK * params.KV_TILES + - tid.x * DK * params.KV_TILES; - - float o_value = 0.f; - for (size_t i = 0; i < params.KV_TILES; i++) { - float val = *(O_partials_with_offset + i * DK + lid.x); - o_value += val * weights[i] / denom; - } - device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; - O_gmem[lid.x] = T(o_value); - return; -} - -kernel void fast_inference_sdpa_reduce_tiles_float( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device float* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - fast_inference_sdpa_reduce_tiles_template( - O_partials, p_lse, p_maxes, params, O, tid, lid); -} - -kernel void fast_inference_sdpa_reduce_tiles_half( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device half* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - fast_inference_sdpa_reduce_tiles_template( - O_partials, p_lse, p_maxes, params, O, tid, lid); -} +instantiate_sdpa_vector_heads(float) +instantiate_sdpa_vector_heads(bfloat16_t) +instantiate_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h new file mode 100644 index 000000000..4d4a9180b --- /dev/null +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -0,0 +1,115 @@ +// Copyright © 2024 Apple Inc. 
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+template <typename T, int D>
+[[kernel]] void sdpa_vector(
+    const device T* queries [[buffer(0)]],
+    const device T* keys [[buffer(1)]],
+    const device T* values [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    const constant int& gqa_factor,
+    const constant int& N,
+    const constant size_t& k_stride,
+    const constant float& scale,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int BN = 32;
+  constexpr int BD = 32;
+  constexpr int elem_per_thread = D / BD;
+
+  const int stride = BN * D;
+
+  typedef float U;
+
+  thread U q[elem_per_thread];
+  thread U k[elem_per_thread];
+  thread U o[elem_per_thread];
+
+  threadgroup U outputs[BN * BD];
+  threadgroup U max_scores[BN];
+  threadgroup U sum_exp_scores[BN];
+
+  // Adjust positions
+  const int head_idx = tid.y;
+  const int kv_head_idx = head_idx / gqa_factor;
+  queries += head_idx * D + simd_lid * elem_per_thread;
+  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
+  values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
+  out += head_idx * D + simd_gid * elem_per_thread;
+
+  // Read the query and 0 the output accumulator
+  for (int i = 0; i < elem_per_thread; i++) {
+    q[i] = static_cast<U>(scale) * queries[i];
+  }
+  for (int i = 0; i < elem_per_thread; i++) {
+    o[i] = 0;
+  }
+
+  U max_score = -INFINITY;
+  U sum_exp_score = 0;
+
+  // For each key
+  for (int i = simd_gid; i < N; i += BN) {
+    // Read the key
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = keys[i];
+    }
+
+    // Compute the i-th score
+    U score = 0;
+    for (int i = 0; i < elem_per_thread; i++) {
+      score += q[i] * k[i];
+    }
+    score = simd_sum(score);
+
+    // Update the accumulators
+    U new_max = max(max_score, score);
+    U factor = fast::exp(max_score - new_max);
+    U exp_score = fast::exp(score - new_max);
+
+    max_score = new_max;
+    sum_exp_score = sum_exp_score * factor + exp_score;
+
+    // Update the output accumulator
+    for (int i = 0; i < elem_per_thread; i++) {
+      o[i] = o[i] * factor + exp_score * values[i];
+    }
+
+    // Move the pointers to the next kv
+    keys += stride;
+    values += stride;
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Each thread has a partial part of the output so we need to combine them.
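+  // At this point each simdgroup g holds an unnormalized partial output o_g
+  // for the keys it visited, together with its running maximum m_g and sum of
+  // exponentials s_g. The partials are merged with the usual streaming-softmax
+  // rescaling: with m = max_g m_g,
+  //   out = (sum_g exp(m_g - m) * o_g) / (sum_g exp(m_g - m) * s_g).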
+ + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 92d3ee05a..31811bb7c 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -1,20 +1,13 @@ -// -// scaled_dot_product_attention.cpp -// mlx +// Copyright © 2024 Apple Inc. -#include #include -#include #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" -#include "mlx/backend/metal/metal.h" -#include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" -#include "mlx/primitives.h" -#include "mlx/utils.h" namespace mlx::core::fast { @@ -26,8 +19,7 @@ void sdpa_full_self_attention_metal( const array& k, const array& v, const float alpha, - array& out, - std::vector& temporaries) { + array& out) { std::ostringstream kname_self_attention; kname_self_attention << "steel_gemm_attention_"; @@ -148,130 +140,58 @@ void sdpa_full_self_attention_metal( MTL::Size group_dims = MTL::Size(32, wm, wn); compute_encoder->dispatchThreadgroups(grid_dims, group_dims); - - d.get_command_buffer(s.index)->addCompletedHandler( - [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); }); - return; } -void sdpa_metal( +void sdpa_vector( const Stream& s, metal::Device& d, const array& q, const array& k, const array& v, - const array& p_lse, - const array& p_rowmaxes, - const array& o_partial, - const uint heads, - const uint tile_size, - const uint n_tiles, - const float alpha, array& out, - std::vector& temporaries) { - std::ostringstream kname_partials; + float scale) { + // Set the kernel name + std::string kname; + kname.reserve(64); + kname += "sdpa_vector_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); - kname_partials << "fast_inference_sdpa_compute_partials_"; + // Compute the necessary sizes + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int B = q.shape(0) * q.shape(1); + size_t stride = k.strides()[1]; + MTL::Size group_dims(1024, 1, 1); + MTL::Size grid_dims(1, B, 1); - std::ostringstream kname_reduce; - std::string delimiter = "_"; - kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter; - - for (const auto& arr : {k, v, out}) { - if (arr.dtype() != q.dtype()) { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o"); - } - } - - if (q.dtype() == float32) { - kname_partials << "float" + delimiter; - kname_reduce << "float"; - } else if (q.dtype() == float16) { - kname_partials << "half" + 
delimiter; - kname_reduce << "half"; - } else { - throw std::runtime_error( - "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16."); - } - - std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter; - - uint nsimd = 8; - std::string kname_suffix_nsimdgroups = std::to_string(nsimd); - - // maximum number of splits == 128 at the moment (reserved tile registers in - // reduction kernel). this is arbitrary and could be changed in the shader. - - std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups; - kname_partials << kname_suffix; + // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname_partials.str()); + auto kernel = d.get_kernel(kname); compute_encoder->setComputePipelineState(kernel); - constexpr const uint batch = 1; - MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch); - MTL::Size group_dims = MTL::Size(32, nsimd, 1); - - const uint64_t KV_sequence_length = k.shape(-2); - const uint query_sequence_length = q.shape(-2); - const uint n_q_heads = q.shape(1); - const uint n_kv_heads = k.shape(1); - - MLXScaledDotProductAttentionParams params{ - query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha}; - - compute_encoder.set_input_array(q, 0); + // Set its arguments + compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); - compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3); - compute_encoder->setBytes( - ¶ms, sizeof(MLXScaledDotProductAttentionParams), 4); - compute_encoder.set_input_array(o_partial, 5); - compute_encoder.set_input_array(p_lse, 6); - compute_encoder.set_input_array(p_rowmaxes, 7); + compute_encoder.set_output_array(out, 3); + compute_encoder->setBytes(&gqa_factor, sizeof(int), 4); + compute_encoder->setBytes(&N, sizeof(int), 5); + compute_encoder->setBytes(&stride, sizeof(size_t), 6); + compute_encoder->setBytes(&scale, sizeof(float), 7); - constexpr const uint tgroupMemorySize = 32768; - compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0); + // Launch compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - - { - auto kernel_accum = d.get_kernel(kname_reduce.str()); - compute_encoder->setComputePipelineState(kernel_accum); - compute_encoder.set_input_array(o_partial, 0); - compute_encoder.set_input_array(p_lse, 1); - compute_encoder.set_input_array(p_rowmaxes, 2); - compute_encoder->setBytes( - ¶ms, sizeof(MLXScaledDotProductAttentionParams), 3); - compute_encoder.set_output_array(out, 4); - - MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch); - MTL::Size group_dims_reduce = MTL::Size(128, 1, 1); - - compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce); - - d.get_command_buffer(s.index)->addCompletedHandler( - [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); }); - return; - } } + } // namespace void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { - assert(inputs.size() >= 3); - if (!issubdtype(out.dtype(), floating)) { - throw std::runtime_error( - "[ScaledDotProductAttention] Does not yet support non-floating point types."); - } + assert(inputs.size() == 3); - if (inputs.size() == 4) { - out = fallback_(inputs)[0]; - return; - } - - out.set_data(allocator::malloc_or_wait(out.nbytes())); auto& s = stream(); auto& d = metal::device(s.device); @@ -279,84 +199,75 @@ void 
ScaledDotProductAttention::eval_gpu( auto& k_pre = inputs[1]; auto& v_pre = inputs[2]; auto& o = out; - ///////////////////////////////////////////////////////////////////////////// - // Init checks and prep - // Keep a vector with copies to be cleared in the completed buffer to release - // the arrays - std::vector temporaries; - auto check_transpose = [&temporaries, &s](const array& arr) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (stx == arr.shape(-1) && sty == 1) { - return arr; - } else { + std::vector copies; + + // Define some copy functions to ensure the layout of the inputs is as + // expected. + auto copy_unless = [&copies, &s](auto predicate, const array& arr) { + if (!predicate(arr)) { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); - temporaries.push_back(arr_copy); - size_t stx = arr.shape(-1); + copies.push_back(arr_copy); return arr_copy; + } else { + return arr; } }; - auto q = check_transpose(q_pre); - auto k = check_transpose(k_pre); - auto v = check_transpose(v_pre); + // Checks if arr is fully row contiguous + auto is_contiguous = [](const array& arr) { + return arr.flags().row_contiguous; + }; - const int heads = q.shape(-3); + // Returns true if the array is row contiguous except the sequence length + // dimension that can be sliced but with step=1. + auto is_contiguous_except_seq_len = [](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return strides[3] == 1 && strides[2] == shape[3] && + strides[0] == strides[1] * shape[1]; + }; - uint query_sequence_length = q.shape(-2); - if (query_sequence_length >= 16) { - return sdpa_full_self_attention_metal( - s, d, q, k, v, scale_, out, temporaries); - } - int tile_size = 64; - const int kv_seq_len = k.shape(-2); - if (kv_seq_len > 8000) { - tile_size = 128; - } - if (kv_seq_len > 16000) { - tile_size = 256; - } - if (kv_seq_len > 32000) { - tile_size = 512; + // Checks that the last two dims are row contiguous. 
+ auto is_matrix_contiguous = [](const array& arr) { + auto& strides = arr.strides(); + auto& shape = arr.shape(); + return strides[3] == 1 && strides[2] == shape[3]; + }; + + // We are in vector mode ie single query + if (q_pre.shape(2) == 1) { + auto q = copy_unless(is_contiguous, q_pre); + auto k = copy_unless(is_contiguous_except_seq_len, k_pre); + auto v = copy_unless(is_contiguous_except_seq_len, v_pre); + + // Donate the query if possible + if (q.is_donatable()) { + o.move_shared_buffer(q); + } else { + o.set_data(allocator::malloc_or_wait(o.nbytes())); + } + + sdpa_vector(s, d, q, k, v, o, scale_); } - const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size; + // Full attention mode + else { + auto q = copy_unless(is_matrix_contiguous, q_pre); + auto k = copy_unless(is_matrix_contiguous, k_pre); + auto v = copy_unless(is_matrix_contiguous, v_pre); + o.set_data(allocator::malloc_or_wait(o.nbytes())); - array o_partials( - {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)}, - float32, - nullptr, - {}); - o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes())); + sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); + } - array p_lse( - {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {}); - array p_rowmaxes( - {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {}); - p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes())); - p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes())); - - temporaries.push_back(p_lse); - temporaries.push_back(p_rowmaxes); - temporaries.push_back(o_partials); - - return sdpa_metal( - s, - d, - q, - k, - v, - p_lse, - p_rowmaxes, - o_partials, - heads, - tile_size, - n_tiles, - scale_, - out, - temporaries); + if (!copies.empty()) { + d.get_command_buffer(s.index)->addCompletedHandler( + [copies = std::move(copies)](MTL::CommandBuffer*) mutable { + copies.clear(); + }); + } } } // namespace mlx::core::fast diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 6fd4862d4..d203b5bde 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -618,40 +618,38 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); + const size_t value_head_dim = v.shape(-1); const size_t query_head_dim = q.shape(-1); - const bool supported_head_dim = - query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128; - - const bool supported_head_dim_self_attn = - query_head_dim == 64 || query_head_dim == 128; const size_t query_sequence_length = q.shape(2); - const bool supports_full_self_attention = query_sequence_length >= 16 && - !mask.has_value() && supported_head_dim_self_attn && + + bool implementation_supports_use_case = query_head_dim == value_head_dim; + + const bool sdpa_vector_supported_head_dim = + query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128; + const bool sdpa_full_supported_head_dim = + query_head_dim == 64 || query_head_dim == 128; + + const bool supports_sdpa_full = query_sequence_length >= threshold && + !mask.has_value() && sdpa_full_supported_head_dim && n_q_heads == n_kv_heads && final_type != bfloat16 && stream.device == Device::gpu; - // fast decoding gpu shader - bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 && - !mask.has_value() && supported_head_dim && final_type != bfloat16 && + const bool supports_sdpa_vector = query_sequence_length == 1 && + !mask.has_value() && sdpa_vector_supported_head_dim && stream.device == Device::gpu; - bool implementation_supports_use_case = - supports_sdpa || supports_full_self_attention; - // sdpa 
gpu shader is disabled except for memory efficient opt-in
-  const int seq_for_threshold = queries.shape(2);
-  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
-  implementation_supports_use_case &= use_memory_efficient_impl;
+  implementation_supports_use_case &=
+      supports_sdpa_full || supports_sdpa_vector;
 
   if (implementation_supports_use_case) {
     auto out_shape =
         std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
-    auto out = array(
+    return array(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(
             stream, fallback, scale, false),
         {q, k, v});
-    return out;
   }
 
   if (mask.has_value()) {
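
For reference, a minimal usage sketch (mirroring the benchmark at the top of this diff; the shapes and scale below are illustrative and not part of the patch) of inputs that should select the new single-query sdpa_vector path, since the query sequence length is 1, the head dimension is 128, and no mask is given:

import math

import mlx.core as mx

# One query token with 32 query heads sharing 8 KV heads (GQA) and a KV cache
# of length 1024; on the GPU this should dispatch to the sdpa_vector kernel.
q = mx.random.uniform(shape=(1, 32, 1, 128))
k = mx.random.uniform(shape=(1, 8, 1024, 128))
v = mx.random.uniform(shape=(1, 8, 1024, 128))
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(128))
mx.eval(o)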