diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index bde148fe8..289c1a067 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,7 +13,7 @@ MLX was developed with contributions from the following individuals:
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
-
+- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`.
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index fe764c494..b8d3b26fe 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -29,6 +29,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 23979b8ac..b265babbe 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -25,6 +25,7 @@ set(
"random"
"rope"
"scan"
+ "scaled_dot_product_attention"
"softmax"
"sort"
"ternary"
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
new file mode 100644
index 000000000..fb9f0a111
--- /dev/null
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -0,0 +1,451 @@
+#include <metal_stdlib>
+#include <metal_simdgroup_matrix>
+
+#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+using namespace metal;
+
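+// Stage 1: for a single decode-step query, compute per-tile attention partials
+// (partial outputs, per-tile log-sum-exp, and per-tile row maxima). A second
+// kernel (fast_inference_sdpa_reduce_tiles_*) combines the tiles into the output.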
+template <typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
+[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
+ const device T *K [[buffer(1)]],
+ const device T *V [[buffer(2)]],
+ const device uint64_t& L [[buffer(3)]],
+ const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
+ device float* O_partials [[buffer(5)]],
+ device float* p_lse [[buffer(6)]],
+ device float* p_maxes [[buffer(7)]],
+ threadgroup T* threadgroup_block [[threadgroup(0)]],
+ uint simd_lane_id [[thread_index_in_simdgroup]],
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
+ uint3 tid [[threadgroup_position_in_grid]]) {
+ constexpr const size_t DK = 128;
+ constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
+ constexpr const size_t THREADS_PER_SIMDGROUP = 32;
+ constexpr const uint iter_offset = NSIMDGROUPS * 4;
+ const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
+ uint kv_head_offset_factor = tid.x;
+ if(is_gqa) {
+ int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
+ kv_head_offset_factor = tid.x / q_kv_head_ratio;
+ }
+ constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
+ constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
+ constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
+ constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
+
+ threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
+ #pragma clang loop unroll(full)
+ for(uint i = 0; i < 8; i++) {
+ smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ // TODO: multiple query sequence length for speculative decoding
+ const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
+
+ const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
+ const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
+ const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
+
+ const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
+ const device T* baseQ = Q + tgroup_query_head_offset;
+
+ device T4* simdgroupQueryData = (device T4*)baseQ;
+
+ constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
+ float threadAccum[ACCUM_PER_GROUP];
+
+ #pragma clang loop unroll(full)
+ for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
+ threadAccum[threadAccumIndex] = -INFINITY;
+ }
+
+ uint KROW_ACCUM_INDEX = 0;
+
+ const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
+ const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
+ const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
+
+ T4 thread_data_x4;
+ T4 thread_data_y4;
+ if(!LAST_TILE || LAST_TILE_ALIGNED) {
+ thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
+ #pragma clang loop unroll(full)
+ for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
+ const uint KROW_OFFSET = KROW * DK;
+ const device T* baseKRow = baseK + KROW_OFFSET;
+ device T4* keysData = (device T4*)baseKRow;
+ thread_data_y4 = *(keysData + simd_lane_id);
+ T kq_scalar = dot(thread_data_x4, thread_data_y4);
+ threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
+ KROW_ACCUM_INDEX++;
+ }
+ } else {
+ thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
+ const uint START_ROW = tid.y * TILE_SIZE_CONST;
+ const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
+
+ for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
+ const uint KROW_OFFSET = KROW * DK;
+ const device T* baseKRow = baseKThisHead + KROW_OFFSET;
+ device T4* keysData = (device T4*)baseKRow;
+ thread_data_y4 = *(keysData + simd_lane_id);
+ T kq_scalar = dot(thread_data_x4, thread_data_y4);
+ threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
+ KROW_ACCUM_INDEX++;
+ }
+ }
+ threadgroup float* smemP = (threadgroup float*)threadgroup_block;
+
+ #pragma clang loop unroll(full)
+ for(size_t i = 0; i < P_VEC4; i++) {
+ thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
+ simdgroup_barrier(mem_flags::mem_none);
+ thread_data_y4 = simd_sum(thread_data_x4);
+ if(simd_lane_id == 0) {
+ const uint base_smem_p_offset = i * iter_offset + simd_group_id;
+ smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
+ smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
+ smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
+ smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float groupMax;
+ float lse = 0.f;
+
+ constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
+ constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
+ float4 pvals[ACCUM_ARRAY_LENGTH];
+
+ #pragma clang loop unroll(full)
+ for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
+ pvals[accum_array_iter] = float4(-INFINITY);
+ }
+
+ if (TILE_SIZE_CONST == 64) {
+ threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
+ float2 vals = smemPtrFlt2[simd_lane_id];
+ vals *= params.INV_ALPHA;
+ float maxval = max(vals.x, vals.y);
+ simdgroup_barrier(mem_flags::mem_none);
+ groupMax = simd_max(maxval);
+
+ float2 expf_shifted = exp(vals - groupMax);
+ float sumExpLocal = expf_shifted.x + expf_shifted.y;
+ simdgroup_barrier(mem_flags::mem_none);
+ float tgroupExpSum = simd_sum(sumExpLocal);
+
+ lse = log(tgroupExpSum);
+ float2 local_p_hat = expf_shifted / tgroupExpSum;
+ pvals[0].x = local_p_hat.x;
+ pvals[0].y = local_p_hat.y;
+ smemPtrFlt2[simd_lane_id] = float2(0.f);
+ }
+ constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
+ constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
+
+ if (TILE_SIZE_LARGER_THAN_64) {
+ float maxval = -INFINITY;
+ threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
+ #pragma clang loop unroll(full)
+ for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+ float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
+ vals *= params.INV_ALPHA;
+ pvals[i] = vals;
+ maxval = fmax3(vals.x, vals.y, maxval);
+ maxval = fmax3(vals.z, vals.w, maxval);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ groupMax = simd_max(maxval);
+
+ float sumExpLocal = 0.f;
+ #pragma clang loop unroll(full)
+ for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+ pvals[i] = exp(pvals[i] - groupMax);
+ sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ float tgroupExpSum = simd_sum(sumExpLocal);
+ lse = log(tgroupExpSum);
+ #pragma clang loop unroll(full)
+ for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
+ pvals[i] = pvals[i] / tgroupExpSum;
+ smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
+ }
+ }
+
+ threadgroup T* smemV = (threadgroup T*)threadgroup_block;
+
+ const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
+ const size_t v_head_offset = kv_head_offset_factor * L * DK;
+
+ const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
+ const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
+ device T* baseV = (device T*)V + v_offset;
+
+ threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
+
+ if (!LAST_TILE || LAST_TILE_ALIGNED) {
+ #pragma clang loop unroll(full)
+ for(size_t col = 0; col < MATRIX_COLS; col++) {
+ uint matrix_load_loop_iter = 0;
+ constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
+
+ for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
+        simdgroup_matrix<T, 8, 8> tmp;
+ ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+ ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
+ simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
+ const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
+ const ulong elemsPerRowSmem = TILE_SIZE_CONST;
+ simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
+ matrix_load_loop_iter++;
+ };
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (TILE_SIZE_CONST == 64) {
+ T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
+ uint loop_iter = 0;
+ threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+
+ #pragma clang loop unroll(full)
+ for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+ threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
+ threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
+ T2 v_local = *(smemV2 + simd_lane_id);
+
+ T val = dot(local_p_hat, v_local);
+ simdgroup_barrier(mem_flags::mem_none);
+
+ T row_sum = simd_sum(val);
+ oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+ loop_iter++;
+ }
+ }
+
+ if (TILE_SIZE_CONST > 64) {
+ constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
+ threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+ uint loop_iter = 0;
+ for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
+ threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
+
+ T row_sum = 0.f;
+ for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
+ threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
+ T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
+ T4 p_local = T4(pvals[i]);
+ T val = dot(p_local, v_local);
+ row_sum += val;
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ row_sum = simd_sum(row_sum);
+ oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
+ loop_iter++;
+ }
+ }
+ }
+ } else {
+ const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
+ const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
+ const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
+ constexpr const int ROWS_PER_ITER = 8;
+ #pragma clang loop unroll(full)
+ for(size_t col = 0; col < MATRIX_COLS; col++) {
+ uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
+ int32_t tile_start;
+ for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
+        simdgroup_matrix<T, 8, 8> tmp;
+ ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
+ simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
+ const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
+ constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
+ simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
+ smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
+ };
+
+ tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
+
+ const int32_t INT_L = int32_t(L);
+ for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
+ if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
+ const uint elems_per_row_gmem = DK;
+ const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
+ const uint row_index_v_gmem = row_index;
+
+ const uint elems_per_row_smem = TILE_SIZE_CONST;
+ const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
+ const uint row_index_v_smem = simd_lane_id;
+
+ const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
+ const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
+ T vdata = T(*(baseVThisHead + scalar_offset_gmem));
+ smemV[scalar_offset_smem] = vdata;
+ smem_col_index += NSIMDGROUPS;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (TILE_SIZE_CONST == 64) {
+ T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
+ threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+ for(size_t smem_row_index = simd_group_id;
+ smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
+ threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
+ threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
+ T2 v_local = *(smemV2 + simd_lane_id);
+ T val = dot(local_p_hat, v_local);
+ simdgroup_barrier(mem_flags::mem_none);
+ T row_sum = simd_sum(val);
+ oPartialSmem[smem_row_index] = float(row_sum);
+ }
+ }
+
+ if (TILE_SIZE_CONST > 64) {
+ threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
+ uint loop_count = 0;
+ for(size_t row_index = simd_group_id;
+ row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
+ T row_sum = 0.f;
+ for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
+ threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
+ threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
+ T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
+ T4 p_local = T4(pvals[tile_iters]);
+ row_sum += dot(p_local, v_local);
+
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ row_sum = simd_sum(row_sum);
+ oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
+ loop_count++;
+ }
+ }
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if(simd_group_id == 0) {
+ threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
+ float4 vals = *(oPartialVec4 + simd_lane_id);
+ device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
+ device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
+ oPartialGmemVec4[simd_lane_id] = vals;
+ }
+
+ if(simd_group_id == 0 && simd_lane_id == 0) {
+ const uint tileIndex = tid.y;
+ const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
+ p_lse[gmem_partial_scalar_offset] = lse;
+ p_maxes[gmem_partial_scalar_offset] = groupMax;
+ }
+}
+
+#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
+template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
+[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
+ const device itype *Q [[buffer(0)]], \
+ const device itype *K [[buffer(1)]], \
+ const device itype *V [[buffer(2)]], \
+ const device uint64_t& L [[buffer(3)]], \
+ const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
+ device float* O_partials [[buffer(5)]], \
+ device float* p_lse [[buffer(6)]], \
+ device float* p_maxes [[buffer(7)]], \
+ threadgroup itype *threadgroup_block [[threadgroup(0)]], \
+ uint simd_lane_id [[thread_index_in_simdgroup]], \
+ uint simd_group_id [[simdgroup_index_in_threadgroup]], \
+ uint3 tid [[threadgroup_position_in_grid]]);
+
+
+#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
+ instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
+ instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
+
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
+
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
+instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
+
+
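+// Stage 2: combine the per-tile partial outputs. Each tile's contribution is
+// weighted by exp(lse_i + rowmax_i - global_max) and the sum is normalized by
+// the total denominator, recovering the full-softmax weighting.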
+template <typename T>
+void fast_inference_sdpa_reduce_tiles_template(
+ const device float *O_partials [[buffer(0)]],
+ const device float *p_lse[[buffer(1)]],
+ const device float *p_maxes [[buffer(2)]],
+ const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
+ device T* O [[buffer(4)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint3 lid [[thread_position_in_threadgroup]]) {
+
+ constexpr const int DK = 128;
+ const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
+ const device float* p_lse_row = p_lse + offset_rows;
+ const device float* p_rowmax_row = p_maxes + offset_rows;
+ // reserve some number of registers. this constitutes an assumption on max value of KV TILES.
+ constexpr const uint8_t reserve = 128;
+ float p_lse_regs[reserve];
+ float p_rowmax_regs[reserve];
+ float weights[reserve];
+
+ float true_max = -INFINITY;
+ for(size_t i = 0; i < params.KV_TILES; i++) {
+ p_lse_regs[i] = float(*(p_lse_row + i));
+ p_rowmax_regs[i] = float(*(p_rowmax_row + i));
+ true_max = fmax(p_rowmax_regs[i], true_max);
+ weights[i] = exp(p_lse_regs[i]);
+ }
+
+ float denom = 0.f;
+ for(size_t i = 0; i < params.KV_TILES; i++) {
+ weights[i] *= exp(p_rowmax_regs[i]-true_max);
+ denom += weights[i];
+ }
+
+ const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
+
+ float o_value = 0.f;
+ for(size_t i = 0; i < params.KV_TILES; i++) {
+ float val = *(O_partials_with_offset + i * DK + lid.x);
+ o_value += val * weights[i] / denom;
+ }
+ device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
+ O_gmem[lid.x] = T(o_value);
+ return;
+}
+
+
+kernel void fast_inference_sdpa_reduce_tiles_float(
+ const device float *O_partials [[buffer(0)]],
+ const device float *p_lse[[buffer(1)]],
+ const device float *p_maxes [[buffer(2)]],
+ const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
+ device float* O [[buffer(4)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint3 lid [[thread_position_in_threadgroup]])
+{
+  fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
+ O, tid, lid);
+}
+
+kernel void fast_inference_sdpa_reduce_tiles_half(
+ const device float *O_partials [[buffer(0)]],
+ const device float *p_lse[[buffer(1)]],
+ const device float *p_maxes [[buffer(2)]],
+ const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
+ device half* O [[buffer(4)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint3 lid [[thread_position_in_threadgroup]])
+{
+  fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
+ O, tid, lid);
+}
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
new file mode 100644
index 000000000..09b9defb8
--- /dev/null
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
@@ -0,0 +1,14 @@
+//
+// scaled_dot_product_attention_params.h
+// mlx
+
+#pragma once
+
+struct MLXScaledDotProductAttentionParams {
+ // Associated dimensions & transposition information
+ const uint QUERY_SEQUENCE_LENGTH = 1;
+ const uint N_Q_HEADS = 32;
+ const uint N_KV_HEADS = 32;
+ const uint KV_TILES = 1;
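+  // Scale applied to the raw Q.K^T scores before softmax (typically
+  // 1/sqrt(head_dim)); the default value corresponds to head_dim == 128.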
+ const float INV_ALPHA = 0.08838834764831843f;
+};
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
new file mode 100644
index 000000000..28f2f162d
--- /dev/null
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -0,0 +1,222 @@
+//
+// scaled_dot_product_attention.cpp
+// mlx
+
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <sstream>
+
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+#include "mlx/backend/metal/metal.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/fast_primitives.h"
+#include "mlx/primitives.h"
+#include "mlx/utils.h"
+
+namespace mlx::core::fast {
+
+namespace {
+
+void sdpa_metal(
+ const Stream& s,
+ metal::Device& d,
+ const array& q,
+ const array& k,
+ const array& v,
+ const array& p_lse,
+ const array& p_rowmaxes,
+ const array& o_partial,
+ const uint heads,
+ const uint tile_size,
+ const uint n_tiles,
+ const float alpha,
+ array& out,
+    std::vector<array>& temporaries) {
+ std::ostringstream kname_partials;
+
+ kname_partials << "fast_inference_sdpa_compute_partials_";
+
+ std::ostringstream kname_reduce;
+ std::string delimiter = "_";
+ kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
+
+ for (const auto& arr : {k, v, out}) {
+ if (arr.dtype() != q.dtype()) {
+ throw std::runtime_error(
+ "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
+ }
+ }
+
+ if (q.dtype() == float32) {
+ kname_partials << "float" + delimiter;
+ kname_reduce << "float";
+ } else if (q.dtype() == float16) {
+ kname_partials << "half" + delimiter;
+ kname_reduce << "half";
+ } else {
+ throw std::runtime_error(
+ "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
+ }
+
+ std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
+
+ uint nsimd = 8;
+ std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
+
+  // The maximum number of splits is currently 128 (the reduction kernel
+  // reserves that many per-tile registers); this is arbitrary and could be
+  // changed in the shader.
+
+ std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
+ kname_partials << kname_suffix;
+ auto compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = d.get_kernel(kname_partials.str());
+ compute_encoder->setComputePipelineState(kernel);
+
+ constexpr const uint batch = 1;
+ MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
+ MTL::Size group_dims = MTL::Size(32, nsimd, 1);
+
+ const uint64_t KV_sequence_length = k.shape(-2);
+ const uint query_sequence_length = q.shape(-2);
+ const uint n_q_heads = q.shape(1);
+ const uint n_kv_heads = k.shape(1);
+
+ MLXScaledDotProductAttentionParams params{
+ query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
+
+ set_array_buffer(compute_encoder, q, 0);
+ set_array_buffer(compute_encoder, k, 1);
+ set_array_buffer(compute_encoder, v, 2);
+ compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
+ compute_encoder->setBytes(
+      &params, sizeof(MLXScaledDotProductAttentionParams), 4);
+ set_array_buffer(compute_encoder, o_partial, 5);
+ set_array_buffer(compute_encoder, p_lse, 6);
+ set_array_buffer(compute_encoder, p_rowmaxes, 7);
+
+ constexpr const uint tgroupMemorySize = 32768;
+ compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
+ compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+ {
+ auto kernel_accum = d.get_kernel(kname_reduce.str());
+ compute_encoder->setComputePipelineState(kernel_accum);
+ set_array_buffer(compute_encoder, o_partial, 0);
+ set_array_buffer(compute_encoder, p_lse, 1);
+ set_array_buffer(compute_encoder, p_rowmaxes, 2);
+ compute_encoder->setBytes(
+        &params, sizeof(MLXScaledDotProductAttentionParams), 3);
+ set_array_buffer(compute_encoder, out, 4);
+
+ MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
+ MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
+
+ compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
+
+ d.get_command_buffer(s.index)->addCompletedHandler(
+ [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
+ return;
+ }
+}
+} // namespace
+
+void ScaledDotProductAttention::eval_gpu(
+    const std::vector<array>& inputs,
+ array& out) {
+ assert(inputs.size() >= 3);
+ if (!is_floating_point(out.dtype())) {
+ throw std::runtime_error(
+ "[ScaledDotProductAttention] Does not yet support non-floating point types.");
+ }
+
+ if (inputs.size() == 4) {
+ out = fallback_(inputs)[0];
+ return;
+ }
+
+ out.set_data(allocator::malloc_or_wait(out.nbytes()));
+ auto& s = stream();
+ auto& d = metal::device(s.device);
+
+ auto& q_pre = inputs[0];
+ auto& k_pre = inputs[1];
+ auto& v_pre = inputs[2];
+ auto& o = out;
+ /////////////////////////////////////////////////////////////////////////////
+ // Init checks and prep
+
+ // Keep a vector with copies to be cleared in the completed buffer to release
+ // the arrays
+  std::vector<array> temporaries;
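+  // The kernel expects row-contiguous (non-transposed) q/k/v in the last two
+  // dimensions; otherwise make a contiguous copy and keep it alive as a temporary.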
+ auto check_transpose = [&temporaries, &s](const array& arr) {
+ auto stx = arr.strides()[arr.ndim() - 2];
+ auto sty = arr.strides()[arr.ndim() - 1];
+ if (stx == arr.shape(-1) && sty == 1) {
+ return arr;
+ } else {
+ array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+ copy_gpu(arr, arr_copy, CopyType::General, s);
+ temporaries.push_back(arr_copy);
+ size_t stx = arr.shape(-1);
+ return arr_copy;
+ }
+ };
+
+ auto q = check_transpose(q_pre);
+ auto k = check_transpose(k_pre);
+ auto v = check_transpose(v_pre);
+
+ const int heads = q.shape(-3);
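+  // Pick a KV tile size so the number of tiles stays moderate for long
+  // contexts (the reduction kernel reserves registers for at most 128 tiles).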
+ int tile_size = 64;
+ const int kv_seq_len = k.shape(-2);
+ if (kv_seq_len > 8000) {
+ tile_size = 128;
+ }
+ if (kv_seq_len > 16000) {
+ tile_size = 256;
+ }
+ if (kv_seq_len > 32000) {
+ tile_size = 512;
+ }
+
+ const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
+
+ array o_partials(
+ {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
+ float32,
+ nullptr,
+ {});
+ o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
+
+ array p_lse(
+ {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
+ array p_rowmaxes(
+ {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
+ p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
+ p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
+
+ temporaries.push_back(p_lse);
+ temporaries.push_back(p_rowmaxes);
+ temporaries.push_back(o_partials);
+
+ return sdpa_metal(
+ s,
+ d,
+ q,
+ k,
+ v,
+ p_lse,
+ p_rowmaxes,
+ o_partials,
+ heads,
+ tile_size,
+ n_tiles,
+ scale_,
+ out,
+ temporaries);
+}
+
+} // namespace mlx::core::fast
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 4234eeb1c..bfe569041 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -99,6 +99,7 @@ NO_GPU(Transpose)
namespace fast {
NO_GPU_MULTI(RoPE)
+NO_GPU(ScaledDotProductAttention)
} // namespace fast
} // namespace mlx::core
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index ee28138f1..232de2f31 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -127,4 +127,147 @@ bool RoPE::is_equivalent(const Primitive& other) const {
offset_ == a_other.offset_);
}
+/** Computes: O = softmax(Q @ K.T * scale) @ V **/
+array scaled_dot_product_attention(
+ const array& queries,
+ const array& keys,
+ const array& values,
+ const float scale,
+    const std::optional<array>& mask,
+ StreamOrDevice s) {
+ for (const auto& tensor : {queries, keys, values}) {
+ if (tensor.ndim() != 4) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] input with shape "
+ << tensor.shape() << " expected to be rank 4";
+ throw std::invalid_argument(msg.str());
+ }
+ }
+
+ const size_t batch_dim = queries.shape(0);
+ for (const auto& tensor : {keys, values}) {
+ if (tensor.shape(0) != batch_dim) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] mismatching batch dimension for input with shape "
+ << tensor.shape() << ".";
+ throw std::invalid_argument(msg.str());
+ }
+ }
+
+ // Q, K must have matching last dims (d_k aka 'head_dim');
+ if (queries.shape(-1) != keys.shape(-1)) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
+ << queries.shape() << " for keys shape " << keys.shape() << ".";
+ throw std::invalid_argument(msg.str());
+ }
+
+ // K, V must have matching number of heads (n_kv_heads);
+ size_t n_q_heads = queries.shape(-3);
+ size_t n_kv_heads = keys.shape(-3);
+
+ if (keys.shape(-3) != values.shape(-3)) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads "
+ << keys.shape(-3) << " for values with n_heads " << values.shape(-3)
+ << ".";
+ throw std::invalid_argument(msg.str());
+ }
+
+ // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
+ if (n_q_heads % n_kv_heads != 0) {
+ std::ostringstream msg;
+ msg << "[scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads "
+ << n_q_heads << " for n_kv_heads " << n_kv_heads << ".";
+ throw std::invalid_argument(msg.str());
+ }
+
+ auto final_type = result_type({queries, keys, values});
+
+ auto q = astype(queries, final_type, s);
+ auto k = astype(keys, final_type, s);
+ auto v = astype(values, final_type, s);
+
+ auto out_shape =
+      std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
+
+  /* Generic fallback composed of MLX primitives, used whenever the Metal
+   * kernel does not apply:
+   *   * CPU execution
+   *   * batch size > 1
+   *   * query sequence length > 1
+   *   * non-null mask
+   */
+ bool needs_mask = mask.has_value();
+ auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
+                      const std::vector<array>& inputs) {
+ auto& q_tensor = inputs[0];
+ auto& k_tensor = inputs[1];
+ auto& v_tensor = inputs[2];
+ auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
+
+ auto tile_if_needs_repeat =
+ [n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
+ if (n_q_heads == n_kv_heads)
+ return arr;
+ int n_repeats = n_q_heads / n_kv_heads;
+ constexpr const int heads_axis =
+ 1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
+ // [Batch, Heads, Sequence, Hidden]
+ auto ret = repeat(arr, n_repeats, heads_axis, s);
+ return ret;
+ };
+ auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
+ auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
+
+ // dim check on k, v; repeat if untiled, since naive matmul will have
+ // dim mismatch for GQA (MQA could make use of broadcast)
+ auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
+ auto s_tensor = matmul(q_scaled, k_transposed, s);
+ if (needs_mask) {
+ auto mask_tensor = inputs[3];
+ s_tensor = add(s_tensor, mask_tensor, s);
+ }
+ auto p = astype(
+        softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
+ final_type,
+ s);
+ auto out_tensor = matmul(p, v_tensor_tiled, s);
+    return std::vector<array>{out_tensor};
+ };
+
+ auto stream = to_stream(s);
+
+  // The fast Metal path currently supports batch size 1, query sequence
+  // length 1, no mask, and head_dim == 128.
+ constexpr const int supported_head_dim = 128;
+ const size_t query_head_dim = q.shape(-1);
+ const size_t query_sequence_length = q.shape(2);
+ bool implementation_supports_use_case = batch_dim == 1 &&
+ query_sequence_length == 1 && !mask.has_value() &&
+ query_head_dim == supported_head_dim;
+
+ if (stream.device == Device::gpu && implementation_supports_use_case) {
+ auto out = array(
+ out_shape,
+ final_type,
+        std::make_unique<ScaledDotProductAttention>(
+ stream, fallback, scale, false),
+ {q, k, v});
+ return out;
+ }
+
+ if (mask.has_value()) {
+ return fallback({q, k, v, mask.value()})[0];
+ } else {
+ return fallback({q, k, v})[0];
+ }
+}
+
+bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
+ const ScaledDotProductAttention& a_other =
+      static_cast<const ScaledDotProductAttention&>(other);
+ return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
+}
+
} // namespace mlx::core::fast
diff --git a/mlx/fast.h b/mlx/fast.h
index 48ac90a5a..74b04886c 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -2,6 +2,8 @@
#pragma once
+#include <optional>
+
#include "mlx/utils.h"
namespace mlx::core::fast {
@@ -15,4 +17,13 @@ array rope(
int offset,
StreamOrDevice s /* = {} */);
+/** Computes: O = softmax(Q @ K.T * scale) @ V **/
+array scaled_dot_product_attention(
+ const array& queries,
+ const array& keys,
+ const array& values,
+ const float scale,
+    const std::optional<array>& mask = std::nullopt,
+ StreamOrDevice s = {});
+
} // namespace mlx::core::fast
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 2b854960b..b581b09d9 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -65,4 +65,34 @@ class RoPE : public Custom {
int offset_;
};
+class ScaledDotProductAttention : public Custom {
+ public:
+ explicit ScaledDotProductAttention(
+ Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+ const float scale,
+ const bool needs_mask)
+ : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+ override {
+ outputs[0] = fallback_(inputs)[0];
+ };
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+ override {
+ eval_gpu(inputs, outputs[0]);
+ };
+
+  void eval_gpu(const std::vector<array>& inputs, array& out);
+ bool is_equivalent(const Primitive& other) const override;
+
+ DEFINE_PRINT(ScaledDotProductAttention)
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+ float scale_;
+ bool needs_mask_;
+};
+
} // namespace mlx::core::fast
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 115ea37ec..1a667e8e4 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -56,4 +56,44 @@ void init_extensions(py::module_& parent_module) {
Returns:
array: The output array.
)pbdoc");
+
+ m.def(
+ "scaled_dot_product_attention",
+ [](const array& q,
+ const array& k,
+ const array& v,
+ const float scale,
+         const std::optional<array>& mask,
+ const StreamOrDevice& s) {
+ return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
+ },
+ "q"_a,
+ "k"_a,
+ "v"_a,
+ py::kw_only(),
+ "scale"_a,
+ "mask"_a = none,
+ "stream"_a = none,
+ R"pbdoc(
+ scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array
+
+        A fast implementation of multi-head attention: O = softmax(Q @ K.T * scale, dim=-1) @ V.
+ Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150).
+
+ This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations.
+
+ Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32).
+ Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array.
+
+ Args:
+ q (array): Input query array.
+ k (array): Input keys array.
+ v (array): Input values array.
+            scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
+ mask (array, optional): An additive mask to apply to the query-key scores.
+
+ Returns:
+ array: The output array.
+
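+        Example (a minimal sketch; the shapes are illustrative for the
+        single-token decode case):
+
+        .. code-block:: python
+
+            import math
+            import mlx.core as mx
+
+            # (batch, n_heads, seq_len, head_dim)
+            q = mx.random.normal(shape=(1, 32, 1, 128))
+            k = mx.random.normal(shape=(1, 32, 40, 128))
+            v = mx.random.normal(shape=(1, 32, 40, 128))
+            scale = 1.0 / math.sqrt(q.shape[-1])
+            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+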
+ )pbdoc");
}
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
new file mode 100644
index 000000000..8c8a599f4
--- /dev/null
+++ b/python/tests/test_fast_sdpa.py
@@ -0,0 +1,103 @@
+import math
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+
+# SDPA for MHA (n_heads == n_kv_heads)
+def mlx_primitives_sdpa(q, k, v, scale):
+ p = (q * scale) @ k.transpose(0, 1, 3, 2)
+ scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
+ return scores @ v
+
+
+# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
+def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
+
+ n_repeats = q.shape[1] // k.shape[1]
+
+ # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
+ n_heads = q.shape[1]
+ B = q.shape[0]
+ L = k.shape[2]
+
+ def repeat(a):
+ a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
+ return a.reshape([B, n_heads, L, -1])
+
+ k, v = map(repeat, (k, v))
+
+ return mlx_primitives_sdpa(q, k, v, scale)
+
+
+class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
+ @property
+ def dtypes(self):
+ return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
+
+ def test_fast_inference_sdpa(self):
+
+ # Not yet supported:
+ # * K pre-transposed in kernel, V pre-transposed in kernel
+ np.random.seed(0)
+ L = 43
+ R = 1
+ Dk = 128
+ scale = float(1.0 / np.sqrt(128.0))
+ q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
+ k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
+ v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
+
+ q_mlx = mx.array(q_npy)
+ k_mlx = mx.array(k_npy)
+ v_mlx = mx.array(v_npy)
+
+ reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
+
+ o_mlx = mx.fast.scaled_dot_product_attention(
+ q_mlx, k_mlx, v_mlx, scale=scale, mask=None
+ )
+
+ self.assertListEqual(list(reference.shape), list(o_mlx.shape))
+ self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
+
+ B = 1
+ H = 32
+ for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
+ for DO_GQA in [0, 1]:
+ for DTYPE in [np.float32, np.half]:
+ n_kv_heads = 8 if DO_GQA else 32
+ q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
+ k_npy = np.random.normal(
+ 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
+ ).astype(DTYPE)
+ v_npy = np.random.normal(
+ 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
+ ).astype(DTYPE)
+
+ q_mlx = mx.array(q_npy)
+ k_mlx = mx.array(k_npy)
+ v_mlx = mx.array(v_npy)
+
+ reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
+ o_mlx = mx.fast.scaled_dot_product_attention(
+ q_mlx, k_mlx, v_mlx, scale=scale
+ )
+
+ self.assertListEqual(list(reference.shape), list(o_mlx.shape))
+ rtol = 1e-5
+ atol = 1e-1
+
+ if SEQUENCE_LENGTH > 500:
+ rtol = 1e-2
+
+ if DTYPE == np.half:
+ rtol = 1e-2
+
+ self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
+
+
+if __name__ == "__main__":
+ unittest.main(failfast=True)