From 0787724c44b870943386fe97ff709ab535f62c9c Mon Sep 17 00:00:00 2001
From: Brian Keene
Date: Tue, 5 Mar 2024 00:06:11 -0500
Subject: [PATCH] Fast Inference SDPA op (#735)

* Fast Inference SDPA op

Implements metal shaders for:

  o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding is supported via mlx primitives.
The Metal implementation is for the inference use case only.

The majority of the performance benefit appears to result from GQA &
reduced bandwidth requirements; there is approximate performance parity
for the MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun

* Update docstring per PR feedback

* Softmax in higher precision, ...

* Route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.

* Address linux build failure

* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun
Co-authored-by: atila
---
 ACKNOWLEDGMENTS.md                           |   2 +-
 mlx/backend/metal/CMakeLists.txt             |   1 +
 mlx/backend/metal/kernels/CMakeLists.txt     |   1 +
 .../scaled_dot_product_attention.metal       | 451 ++++++++++++++++++
 .../scaled_dot_product_attention_params.h    |  14 +
 .../metal/scaled_dot_product_attention.cpp   | 222 +++++++++
 mlx/backend/no_metal/primitives.cpp          |   1 +
 mlx/fast.cpp                                 | 143 ++++++
 mlx/fast.h                                   |  11 +
 mlx/fast_primitives.h                        |  30 ++
 python/src/fast.cpp                          |  40 ++
 python/tests/test_fast_sdpa.py               | 103 ++++
 12 files changed, 1018 insertions(+), 1 deletion(-)
 create mode 100644 mlx/backend/metal/kernels/scaled_dot_product_attention.metal
 create mode 100644 mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
 create mode 100644 mlx/backend/metal/scaled_dot_product_attention.cpp
 create mode 100644 python/tests/test_fast_sdpa.py

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index bde148fe8..289c1a067 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,7 +13,7 @@ MLX was developed with contributions from the following individuals:
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- +- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index fe764c494..b8d3b26fe 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -29,6 +29,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 23979b8ac..b265babbe 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -25,6 +25,7 @@ set( "random" "rope" "scan" + "scaled_dot_product_attention" "softmax" "sort" "ternary" diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal new file mode 100644 index 000000000..fb9f0a111 --- /dev/null +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -0,0 +1,451 @@ +#include +#include + +#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" +using namespace metal; + +template +[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]], + const device T *K [[buffer(1)]], + const device T *V [[buffer(2)]], + const device uint64_t& L [[buffer(3)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], + device float* O_partials [[buffer(5)]], + device float* p_lse [[buffer(6)]], + device float* p_maxes [[buffer(7)]], + threadgroup T* threadgroup_block [[threadgroup(0)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr const size_t DK = 128; + constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; + constexpr const size_t THREADS_PER_SIMDGROUP = 32; + constexpr const uint iter_offset = NSIMDGROUPS * 4; + const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; + uint kv_head_offset_factor = tid.x; + if(is_gqa) { + int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; + kv_head_offset_factor = tid.x / q_kv_head_ratio; + } + constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4; + constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); + constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; + constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS; + + threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; + #pragma clang loop unroll(full) + for(uint i = 0; i < 8; i++) { + smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // TODO: multiple query sequence length for speculative decoding + const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK); + + const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; + const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; + const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; + + const device T* baseK = K + tgroup_k_batch_offset + 
tgroup_k_tile_offset + tgroup_k_head_offset; + const device T* baseQ = Q + tgroup_query_head_offset; + + device T4* simdgroupQueryData = (device T4*)baseQ; + + constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; + float threadAccum[ACCUM_PER_GROUP]; + + #pragma clang loop unroll(full) + for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) { + threadAccum[threadAccumIndex] = -INFINITY; + } + + uint KROW_ACCUM_INDEX = 0; + + const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; + const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; + const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); + + T4 thread_data_x4; + T4 thread_data_y4; + if(!LAST_TILE || LAST_TILE_ALIGNED) { + thread_data_x4 = *(simdgroupQueryData + simd_lane_id); + #pragma clang loop unroll(full) + for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseK + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; + } + } else { + thread_data_x4 = *(simdgroupQueryData + simd_lane_id); + const uint START_ROW = tid.y * TILE_SIZE_CONST; + const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset; + + for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) { + const uint KROW_OFFSET = KROW * DK; + const device T* baseKRow = baseKThisHead + KROW_OFFSET; + device T4* keysData = (device T4*)baseKRow; + thread_data_y4 = *(keysData + simd_lane_id); + T kq_scalar = dot(thread_data_x4, thread_data_y4); + threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); + KROW_ACCUM_INDEX++; + } + } + threadgroup float* smemP = (threadgroup float*)threadgroup_block; + + #pragma clang loop unroll(full) + for(size_t i = 0; i < P_VEC4; i++) { + thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]); + simdgroup_barrier(mem_flags::mem_none); + thread_data_y4 = simd_sum(thread_data_x4); + if(simd_lane_id == 0) { + const uint base_smem_p_offset = i * iter_offset + simd_group_id; + smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); + smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y); + smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); + smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float groupMax; + float lse = 0.f; + + constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; + constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; + float4 pvals[ACCUM_ARRAY_LENGTH]; + + #pragma clang loop unroll(full) + for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) { + pvals[accum_array_iter] = float4(-INFINITY); + } + + if (TILE_SIZE_CONST == 64) { + threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; + float2 vals = smemPtrFlt2[simd_lane_id]; + vals *= params.INV_ALPHA; + float maxval = max(vals.x, vals.y); + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float2 expf_shifted = exp(vals - groupMax); + float sumExpLocal = expf_shifted.x + expf_shifted.y; + simdgroup_barrier(mem_flags::mem_none); + float 
tgroupExpSum = simd_sum(sumExpLocal); + + lse = log(tgroupExpSum); + float2 local_p_hat = expf_shifted / tgroupExpSum; + pvals[0].x = local_p_hat.x; + pvals[0].y = local_p_hat.y; + smemPtrFlt2[simd_lane_id] = float2(0.f); + } + constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; + constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; + + if (TILE_SIZE_LARGER_THAN_64) { + float maxval = -INFINITY; + threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; + #pragma clang loop unroll(full) + for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { + float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; + vals *= params.INV_ALPHA; + pvals[i] = vals; + maxval = fmax3(vals.x, vals.y, maxval); + maxval = fmax3(vals.z, vals.w, maxval); + } + simdgroup_barrier(mem_flags::mem_none); + groupMax = simd_max(maxval); + + float sumExpLocal = 0.f; + #pragma clang loop unroll(full) + for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { + pvals[i] = exp(pvals[i] - groupMax); + sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; + } + simdgroup_barrier(mem_flags::mem_none); + float tgroupExpSum = simd_sum(sumExpLocal); + lse = log(tgroupExpSum); + #pragma clang loop unroll(full) + for(int i = 0; i < TILE_SIZE_ITERS_128; i++) { + pvals[i] = pvals[i] / tgroupExpSum; + smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); + } + } + + threadgroup T* smemV = (threadgroup T*)threadgroup_block; + + const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; + const size_t v_head_offset = kv_head_offset_factor * L * DK; + + const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK; + const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; + device T* baseV = (device T*)V + v_offset; + + threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV); + + if (!LAST_TILE || LAST_TILE_ALIGNED) { + #pragma clang loop unroll(full) + for(size_t col = 0; col < MATRIX_COLS; col++) { + uint matrix_load_loop_iter = 0; + constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; + + for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) { + simdgroup_matrix tmp; + ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); + simdgroup_load(tmp, baseV, DK, matrixOrigin, true); + const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); + const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); + matrix_load_loop_iter++; + }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + uint loop_iter = 0; + threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + + #pragma clang loop unroll(full) + for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + + T row_sum = simd_sum(val); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum); + loop_iter++; + } + } + + if (TILE_SIZE_CONST > 64) { + constexpr const size_t 
TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128; + threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_iter = 0; + for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); + + T row_sum = 0.f; + for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[i]); + T val = dot(p_local, v_local); + row_sum += val; + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum); + loop_iter++; + } + } + } + } else { + const int32_t START_ROW = tid.y * TILE_SIZE_CONST; + const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; + const device T* baseVThisHead = V + v_batch_offset + v_head_offset; + constexpr const int ROWS_PER_ITER = 8; + #pragma clang loop unroll(full) + for(size_t col = 0; col < MATRIX_COLS; col++) { + uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; + int32_t tile_start; + for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { + simdgroup_matrix tmp; + ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); + simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); + const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); + constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; + simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false); + smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; + }; + + tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); + + const int32_t INT_L = int32_t(L); + for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) { + if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { + const uint elems_per_row_gmem = DK; + const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; + const uint row_index_v_gmem = row_index; + + const uint elems_per_row_smem = TILE_SIZE_CONST; + const uint col_index_v_smem = row_index % TILE_SIZE_CONST; + const uint row_index_v_smem = simd_lane_id; + + const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; + const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem; + T vdata = T(*(baseVThisHead + scalar_offset_gmem)); + smemV[scalar_offset_smem] = vdata; + smem_col_index += NSIMDGROUPS; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (TILE_SIZE_CONST == 64) { + T2 local_p_hat = T2(pvals[0].x, pvals[0].y); + threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + for(size_t smem_row_index = simd_group_id; + smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); + threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; + T2 v_local = *(smemV2 + simd_lane_id); + T val = dot(local_p_hat, v_local); + simdgroup_barrier(mem_flags::mem_none); + T row_sum = simd_sum(val); + oPartialSmem[smem_row_index] = float(row_sum); + } + } + + if (TILE_SIZE_CONST > 64) { + threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; + uint loop_count = 0; + 
for(size_t row_index = simd_group_id; + row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) { + T row_sum = 0.f; + for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) { + threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); + threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; + T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); + T4 p_local = T4(pvals[tile_iters]); + row_sum += dot(p_local, v_local); + + } + simdgroup_barrier(mem_flags::mem_none); + row_sum = simd_sum(row_sum); + oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum); + loop_count++; + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if(simd_group_id == 0) { + threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; + float4 vals = *(oPartialVec4 + simd_lane_id); + device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; + device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; + oPartialGmemVec4[simd_lane_id] = vals; + } + + if(simd_group_id == 0 && simd_lane_id == 0) { + const uint tileIndex = tid.y; + const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex; + p_lse[gmem_partial_scalar_offset] = lse; + p_maxes[gmem_partial_scalar_offset] = groupMax; + } +} + +#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \ +template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \ +[[kernel]] void fast_inference_sdpa_compute_partials_template( \ + const device itype *Q [[buffer(0)]], \ + const device itype *K [[buffer(1)]], \ + const device itype *V [[buffer(2)]], \ + const device uint64_t& L [[buffer(3)]], \ + const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ + device float* O_partials [[buffer(5)]], \ + device float* p_lse [[buffer(6)]], \ + device float* p_maxes [[buffer(7)]], \ + threadgroup itype *threadgroup_block [[threadgroup(0)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]]); + + +#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \ + instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \ + instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \ + +instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512); + +instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256); +instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512); + + +template +void fast_inference_sdpa_reduce_tiles_template( + const device float *O_partials [[buffer(0)]], + const device float *p_lse[[buffer(1)]], + const device float *p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device T* O [[buffer(4)]], + uint3 tid 
[[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + constexpr const int DK = 128; + const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; + const device float* p_lse_row = p_lse + offset_rows; + const device float* p_rowmax_row = p_maxes + offset_rows; + // reserve some number of registers. this constitutes an assumption on max value of KV TILES. + constexpr const uint8_t reserve = 128; + float p_lse_regs[reserve]; + float p_rowmax_regs[reserve]; + float weights[reserve]; + + float true_max = -INFINITY; + for(size_t i = 0; i < params.KV_TILES; i++) { + p_lse_regs[i] = float(*(p_lse_row + i)); + p_rowmax_regs[i] = float(*(p_rowmax_row + i)); + true_max = fmax(p_rowmax_regs[i], true_max); + weights[i] = exp(p_lse_regs[i]); + } + + float denom = 0.f; + for(size_t i = 0; i < params.KV_TILES; i++) { + weights[i] *= exp(p_rowmax_regs[i]-true_max); + denom += weights[i]; + } + + const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES; + + float o_value = 0.f; + for(size_t i = 0; i < params.KV_TILES; i++) { + float val = *(O_partials_with_offset + i * DK + lid.x); + o_value += val * weights[i] / denom; + } + device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; + O_gmem[lid.x] = T(o_value); + return; +} + + +kernel void fast_inference_sdpa_reduce_tiles_float( + const device float *O_partials [[buffer(0)]], + const device float *p_lse[[buffer(1)]], + const device float *p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device float* O [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) +{ + fast_inference_sdpa_reduce_tiles_template(O_partials, p_lse, p_maxes, params, + O, tid, lid); +} + +kernel void fast_inference_sdpa_reduce_tiles_half( + const device float *O_partials [[buffer(0)]], + const device float *p_lse[[buffer(1)]], + const device float *p_maxes [[buffer(2)]], + const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], + device half* O [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) +{ + fast_inference_sdpa_reduce_tiles_template(O_partials, p_lse, p_maxes, params, + O, tid, lid); +} diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h new file mode 100644 index 000000000..09b9defb8 --- /dev/null +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h @@ -0,0 +1,14 @@ +// +// scaled_dot_product_attention_params.h +// mlx + +#pragma once + +struct MLXScaledDotProductAttentionParams { + // Associated dimensions & transposition information + const uint QUERY_SEQUENCE_LENGTH = 1; + const uint N_Q_HEADS = 32; + const uint N_KV_HEADS = 32; + const uint KV_TILES = 1; + const float INV_ALPHA = 0.08838834764831843f; +}; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp new file mode 100644 index 000000000..28f2f162d --- /dev/null +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -0,0 +1,222 @@ +// +// scaled_dot_product_attention.cpp +// mlx + +#include +#include +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" +#include "mlx/backend/metal/metal.h" 
+#include "mlx/backend/metal/utils.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core::fast { + +namespace { + +void sdpa_metal( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const array& p_lse, + const array& p_rowmaxes, + const array& o_partial, + const uint heads, + const uint tile_size, + const uint n_tiles, + const float alpha, + array& out, + std::vector& temporaries) { + std::ostringstream kname_partials; + + kname_partials << "fast_inference_sdpa_compute_partials_"; + + std::ostringstream kname_reduce; + std::string delimiter = "_"; + kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter; + + for (const auto& arr : {k, v, out}) { + if (arr.dtype() != q.dtype()) { + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o"); + } + } + + if (q.dtype() == float32) { + kname_partials << "float" + delimiter; + kname_reduce << "float"; + } else if (q.dtype() == float16) { + kname_partials << "half" + delimiter; + kname_reduce << "half"; + } else { + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16."); + } + + std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter; + + uint nsimd = 8; + std::string kname_suffix_nsimdgroups = std::to_string(nsimd); + + // maximum number of splits == 128 at the moment (reserved tile registers in + // reduction kernel). this is arbitrary and could be changed in the shader. + + std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups; + kname_partials << kname_suffix; + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname_partials.str()); + compute_encoder->setComputePipelineState(kernel); + + constexpr const uint batch = 1; + MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch); + MTL::Size group_dims = MTL::Size(32, nsimd, 1); + + const uint64_t KV_sequence_length = k.shape(-2); + const uint query_sequence_length = q.shape(-2); + const uint n_q_heads = q.shape(1); + const uint n_kv_heads = k.shape(1); + + MLXScaledDotProductAttentionParams params{ + query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha}; + + set_array_buffer(compute_encoder, q, 0); + set_array_buffer(compute_encoder, k, 1); + set_array_buffer(compute_encoder, v, 2); + compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3); + compute_encoder->setBytes( + ¶ms, sizeof(MLXScaledDotProductAttentionParams), 4); + set_array_buffer(compute_encoder, o_partial, 5); + set_array_buffer(compute_encoder, p_lse, 6); + set_array_buffer(compute_encoder, p_rowmaxes, 7); + + constexpr const uint tgroupMemorySize = 32768; + compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + { + auto kernel_accum = d.get_kernel(kname_reduce.str()); + compute_encoder->setComputePipelineState(kernel_accum); + set_array_buffer(compute_encoder, o_partial, 0); + set_array_buffer(compute_encoder, p_lse, 1); + set_array_buffer(compute_encoder, p_rowmaxes, 2); + compute_encoder->setBytes( + ¶ms, sizeof(MLXScaledDotProductAttentionParams), 3); + set_array_buffer(compute_encoder, out, 4); + + MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch); + MTL::Size group_dims_reduce = MTL::Size(128, 1, 1); + + compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce); + + 
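+    // The completed handler registered below captures `temporaries` by value, which keeps
+    // the copied inputs and partial-result arrays alive until this command buffer finishes,
+    // so their buffers are not released while the kernels may still be reading them.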
d.get_command_buffer(s.index)->addCompletedHandler( + [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); }); + return; + } +} +} // namespace + +void ScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + array& out) { + assert(inputs.size() >= 3); + if (!is_floating_point(out.dtype())) { + throw std::runtime_error( + "[ScaledDotProductAttention] Does not yet support non-floating point types."); + } + + if (inputs.size() == 4) { + out = fallback_(inputs)[0]; + return; + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& v_pre = inputs[2]; + auto& o = out; + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector temporaries; + auto check_transpose = [&temporaries, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + temporaries.push_back(arr_copy); + size_t stx = arr.shape(-1); + return arr_copy; + } + }; + + auto q = check_transpose(q_pre); + auto k = check_transpose(k_pre); + auto v = check_transpose(v_pre); + + const int heads = q.shape(-3); + int tile_size = 64; + const int kv_seq_len = k.shape(-2); + if (kv_seq_len > 8000) { + tile_size = 128; + } + if (kv_seq_len > 16000) { + tile_size = 256; + } + if (kv_seq_len > 32000) { + tile_size = 512; + } + + const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size; + + array o_partials( + {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)}, + float32, + nullptr, + {}); + o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes())); + + array p_lse( + {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {}); + array p_rowmaxes( + {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {}); + p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes())); + p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes())); + + temporaries.push_back(p_lse); + temporaries.push_back(p_rowmaxes); + temporaries.push_back(o_partials); + + return sdpa_metal( + s, + d, + q, + k, + v, + p_lse, + p_rowmaxes, + o_partials, + heads, + tile_size, + n_tiles, + scale_, + out, + temporaries); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 4234eeb1c..bfe569041 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -99,6 +99,7 @@ NO_GPU(Transpose) namespace fast { NO_GPU_MULTI(RoPE) +NO_GPU(ScaledDotProductAttention) } // namespace fast } // namespace mlx::core diff --git a/mlx/fast.cpp b/mlx/fast.cpp index ee28138f1..232de2f31 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -127,4 +127,147 @@ bool RoPE::is_equivalent(const Primitive& other) const { offset_ == a_other.offset_); } +/** Computes: O = softmax(Q @ K.T) @ V **/ +array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::optional& mask, + StreamOrDevice s) { + for (const auto& tensor : {queries, keys, values}) { + if (tensor.ndim() != 4) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] input 
with shape " + << tensor.shape() << " expected to be rank 4"; + throw std::invalid_argument(msg.str()); + } + } + + const size_t batch_dim = queries.shape(0); + for (const auto& tensor : {keys, values}) { + if (tensor.shape(0) != batch_dim) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] mismatching batch dimension for input with shape " + << tensor.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + + // Q, K must have matching last dims (d_k aka 'head_dim'); + if (queries.shape(-1) != keys.shape(-1)) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " + << queries.shape() << " for keys shape " << keys.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + // K, V must have matching number of heads (n_kv_heads); + size_t n_q_heads = queries.shape(-3); + size_t n_kv_heads = keys.shape(-3); + + if (keys.shape(-3) != values.shape(-3)) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads " + << keys.shape(-3) << " for values with n_heads " << values.shape(-3) + << "."; + throw std::invalid_argument(msg.str()); + } + + // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1. + if (n_q_heads % n_kv_heads != 0) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads " + << n_q_heads << " for n_kv_heads " << n_kv_heads << "."; + throw std::invalid_argument(msg.str()); + } + + auto final_type = result_type({queries, keys, values}); + + auto q = astype(queries, final_type, s); + auto k = astype(keys, final_type, s); + auto v = astype(values, final_type, s); + + auto out_shape = + std::vector({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}); + + /* generic implementation for use cases that Metal implementation does not + * support. 
For non-supported cases listed below, use MLX primitives: + * * CPU implementation + * * batch size > 1 + * * query sequence length > 1 + * * non-null mask + */ + bool needs_mask = mask.has_value(); + auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s]( + const std::vector& inputs) { + auto& q_tensor = inputs[0]; + auto& k_tensor = inputs[1]; + auto& v_tensor = inputs[2]; + auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s); + + auto tile_if_needs_repeat = + [n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array { + if (n_q_heads == n_kv_heads) + return arr; + int n_repeats = n_q_heads / n_kv_heads; + constexpr const int heads_axis = + 1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] -> + // [Batch, Heads, Sequence, Hidden] + auto ret = repeat(arr, n_repeats, heads_axis, s); + return ret; + }; + auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s); + auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s); + + // dim check on k, v; repeat if untiled, since naive matmul will have + // dim mismatch for GQA (MQA could make use of broadcast) + auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s); + auto s_tensor = matmul(q_scaled, k_transposed, s); + if (needs_mask) { + auto mask_tensor = inputs[3]; + s_tensor = add(s_tensor, mask_tensor, s); + } + auto p = astype( + softmax(astype(s_tensor, float32, s), std::vector{-1}, s), + final_type, + s); + auto out_tensor = matmul(p, v_tensor_tiled, s); + return std::vector{out_tensor}; + }; + + auto stream = to_stream(s); + + // current implementation use case: batch size 1, query sequence length 1, no + // mask. Likewise, requires head_dim == 128 + constexpr const int supported_head_dim = 128; + const size_t query_head_dim = q.shape(-1); + const size_t query_sequence_length = q.shape(2); + bool implementation_supports_use_case = batch_dim == 1 && + query_sequence_length == 1 && !mask.has_value() && + query_head_dim == supported_head_dim; + + if (stream.device == Device::gpu && implementation_supports_use_case) { + auto out = array( + out_shape, + final_type, + std::make_unique( + stream, fallback, scale, false), + {q, k, v}); + return out; + } + + if (mask.has_value()) { + return fallback({q, k, v, mask.value()})[0]; + } else { + return fallback({q, k, v})[0]; + } +} + +bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { + const ScaledDotProductAttention& a_other = + static_cast(other); + return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_; +} + } // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h index 48ac90a5a..74b04886c 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/utils.h" namespace mlx::core::fast { @@ -15,4 +17,13 @@ array rope( int offset, StreamOrDevice s /* = {} */); +/** Computes: O = softmax(Q @ K.T) @ V **/ +array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::optional& mask = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 2b854960b..b581b09d9 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -65,4 +65,34 @@ class RoPE : public Custom { int offset_; }; +class ScaledDotProductAttention : public Custom { + public: + explicit ScaledDotProductAttention( + Stream stream, + std::function(std::vector)> fallback, + const float scale, + const bool needs_mask) + : 
Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + outputs[0] = fallback_(inputs)[0]; + }; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override { + eval_gpu(inputs, outputs[0]); + }; + + void eval_gpu(const std::vector& inputs, array& out); + bool is_equivalent(const Primitive& other) const override; + + DEFINE_PRINT(ScaledDotProductAttention) + + private: + std::function(std::vector)> fallback_; + float scale_; + bool needs_mask_; +}; + } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 115ea37ec..1a667e8e4 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -56,4 +56,44 @@ void init_extensions(py::module_& parent_module) { Returns: array: The output array. )pbdoc"); + + m.def( + "scaled_dot_product_attention", + [](const array& q, + const array& k, + const array& v, + const float scale, + const std::optional& mask, + const StreamOrDevice& s) { + return fast::scaled_dot_product_attention(q, k, v, scale, mask, s); + }, + "q"_a, + "k"_a, + "v"_a, + py::kw_only(), + "scale"_a, + "mask"_a = none, + "stream"_a = none, + R"pbdoc( + scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array + + A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V. + Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + + This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. + + Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). + Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. + + Args: + q (array): Input query array. + k (array): Input keys array. + v (array): Input values array. + scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) + mask (array, optional): An additive mask to apply to the query-key scores. + + Returns: + array: The output array. 
+
+      )pbdoc");
 }
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
new file mode 100644
index 000000000..8c8a599f4
--- /dev/null
+++ b/python/tests/test_fast_sdpa.py
@@ -0,0 +1,103 @@
+import math
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+
+# SDPA for MHA (n_heads == n_kv_heads)
+def mlx_primitives_sdpa(q, k, v, scale):
+    p = (q * scale) @ k.transpose(0, 1, 3, 2)
+    scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
+    return scores @ v
+
+
+# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
+def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
+
+    n_repeats = q.shape[1] // k.shape[1]
+
+    # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
+    n_heads = q.shape[1]
+    B = q.shape[0]
+    L = k.shape[2]
+
+    def repeat(a):
+        a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
+        return a.reshape([B, n_heads, L, -1])
+
+    k, v = map(repeat, (k, v))
+
+    return mlx_primitives_sdpa(q, k, v, scale)
+
+
+class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
+    @property
+    def dtypes(self):
+        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
+
+    def test_fast_inference_sdpa(self):
+
+        # Not yet supported:
+        # * K pre-transposed in kernel, V pre-transposed in kernel
+        np.random.seed(0)
+        L = 43
+        R = 1
+        Dk = 128
+        scale = float(1.0 / np.sqrt(128.0))
+        q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
+        k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
+        v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
+
+        q_mlx = mx.array(q_npy)
+        k_mlx = mx.array(k_npy)
+        v_mlx = mx.array(v_npy)
+
+        reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
+
+        o_mlx = mx.fast.scaled_dot_product_attention(
+            q_mlx, k_mlx, v_mlx, scale=scale, mask=None
+        )
+
+        self.assertListEqual(list(reference.shape), list(o_mlx.shape))
+        self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
+
+        B = 1
+        H = 32
+        for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
+            for DO_GQA in [0, 1]:
+                for DTYPE in [np.float32, np.half]:
+                    n_kv_heads = 8 if DO_GQA else 32
+                    q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
+                    k_npy = np.random.normal(
+                        0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
+                    ).astype(DTYPE)
+                    v_npy = np.random.normal(
+                        0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
+                    ).astype(DTYPE)
+
+                    q_mlx = mx.array(q_npy)
+                    k_mlx = mx.array(k_npy)
+                    v_mlx = mx.array(v_npy)
+
+                    reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
+                    o_mlx = mx.fast.scaled_dot_product_attention(
+                        q_mlx, k_mlx, v_mlx, scale=scale
+                    )
+
+                    self.assertListEqual(list(reference.shape), list(o_mlx.shape))
+                    rtol = 1e-5
+                    atol = 1e-1
+
+                    if SEQUENCE_LENGTH > 500:
+                        rtol = 1e-2
+
+                    if DTYPE == np.half:
+                        rtol = 1e-2
+
+                    self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
+
+
+if __name__ == "__main__":
+    unittest.main(failfast=True)
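
A minimal usage sketch of the new `mx.fast.scaled_dot_product_attention` binding, for
reference. The shapes mirror the fast path exercised by the tests above (batch 1, a single
query token, head_dim 128, grouped KV heads); the tensor sizes and values here are
illustrative only and are not taken from the patch:

    import math
    import mlx.core as mx

    # One decode step: 32 query heads attending over a cache of 512 keys/values
    # held by 8 KV heads (GQA); K and V are not pre-tiled to match the query heads.
    q = mx.random.normal([1, 32, 1, 128])
    k = mx.random.normal([1, 8, 512, 128])
    v = mx.random.normal([1, 8, 512, 128])

    scale = 1.0 / math.sqrt(128)
    o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
    # o holds one head_dim-sized vector per query head: shape [1, 32, 1, 128]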
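
For reference, the per-tile combination performed by the `fast_inference_sdpa_reduce_tiles_*`
kernels can be written compactly in NumPy. This is only a sketch of the math for a single
query row; the array names are ours, and each partial output is assumed to be normalized by
its own tile's softmax, with the tile's log-sum-exp and row max recorded, as in the partials
kernel above:

    import numpy as np

    def combine_tiles(o_partials, lse, maxes):
        # o_partials: (n_tiles, d) -- each row is softmax(scores_tile) @ V_tile
        # lse[i] = log(sum_j exp(s_ij - maxes[i])); maxes[i] = max of tile i's scaled scores
        true_max = maxes.max()
        # per-tile weight = sum_j exp(s_ij - true_max), written as exp(lse) * exp(max - true_max)
        w = np.exp(lse) * np.exp(maxes - true_max)
        # rescale each tile's normalized partial output and sum: recovers softmax over all tiles
        return (o_partials * (w / w.sum())[:, None]).sum(axis=0)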