Fused attention for single query (#1497)

This commit is contained in:
Angelos Katharopoulos 2024-10-18 00:58:52 -07:00 committed by GitHub
parent 9dd72cd421
commit 50d8bed468
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 299 additions and 742 deletions

View File

@ -0,0 +1,49 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 1024
H = 32
H_k = 32 // 4
D = 128
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@ -30,8 +30,9 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS
steel/defines.h

View File

@ -1,11 +1,11 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
using namespace mlx::steel;
@ -886,6 +886,9 @@ template <
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
@ -922,548 +925,28 @@ instantiate_fast_inference_self_attention_kernel(
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
template <
typename T,
typename T2,
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
[[kernel]] void sdpa_vector<type, head_dim>( \
const device type* queries [[buffer(0)]], \
const device type* keys [[buffer(1)]], \
const device type* values [[buffer(2)]], \
device type* out [[buffer(3)]], \
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for (uint i = 0; i < 8; i++) {
smemFlush
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \
instantiate_sdpa_vector(type, 96) \
instantiate_sdpa_vector(type, 128)
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK =
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED =
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// reserve some number of registers. this constitutes an assumption on max
// value of KV TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// clang-format on

View File

@ -0,0 +1,115 @@
// Copyright © 2024 Apple Inc.
#include <metal_simdgroup>
using namespace metal;
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
const int stride = BN * D;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -INFINITY;
U sum_exp_score = 0;
// For each key
for (int i = simd_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
// Move the pointers to the next kv
keys += stride;
values += stride;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}

View File

@ -1,20 +1,13 @@
//
// scaled_dot_product_attention.cpp
// mlx
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
@ -26,8 +19,7 @@ void sdpa_full_self_attention_metal(
const array& k,
const array& v,
const float alpha,
array& out,
std::vector<array>& temporaries) {
array& out) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
@ -148,130 +140,58 @@ void sdpa_full_self_attention_metal(
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
void sdpa_metal(
void sdpa_vector(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
const array& p_lse,
const array& p_rowmaxes,
const array& o_partial,
const uint heads,
const uint tile_size,
const uint n_tiles,
const float alpha,
array& out,
std::vector<array>& temporaries) {
std::ostringstream kname_partials;
float scale) {
// Set the kernel name
std::string kname;
kname.reserve(64);
kname += "sdpa_vector_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname_partials << "fast_inference_sdpa_compute_partials_";
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
std::ostringstream kname_reduce;
std::string delimiter = "_";
kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
if (q.dtype() == float32) {
kname_partials << "float" + delimiter;
kname_reduce << "float";
} else if (q.dtype() == float16) {
kname_partials << "half" + delimiter;
kname_reduce << "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
uint nsimd = 8;
std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
// maximum number of splits == 128 at the moment (reserved tile registers in
// reduction kernel). this is arbitrary and could be changed in the shader.
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
kname_partials << kname_suffix;
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_partials.str());
auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel);
constexpr const uint batch = 1;
MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
MTL::Size group_dims = MTL::Size(32, nsimd, 1);
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
MLXScaledDotProductAttentionParams params{
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
compute_encoder.set_input_array(q, 0);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
compute_encoder.set_input_array(o_partial, 5);
compute_encoder.set_input_array(p_lse, 6);
compute_encoder.set_input_array(p_rowmaxes, 7);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder->setBytes(&stride, sizeof(size_t), 6);
compute_encoder->setBytes(&scale, sizeof(float), 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
compute_encoder->setComputePipelineState(kernel_accum);
compute_encoder.set_input_array(o_partial, 0);
compute_encoder.set_input_array(p_lse, 1);
compute_encoder.set_input_array(p_rowmaxes, 2);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
compute_encoder.set_output_array(out, 4);
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
}
} // namespace
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}
assert(inputs.size() == 3);
if (inputs.size() == 4) {
out = fallback_(inputs)[0];
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@ -279,84 +199,75 @@ void ScaledDotProductAttention::eval_gpu(
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> temporaries;
auto check_transpose = [&temporaries, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return arr;
} else {
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
temporaries.push_back(arr_copy);
size_t stx = arr.shape(-1);
copies.push_back(arr_copy);
return arr_copy;
} else {
return arr;
}
};
auto q = check_transpose(q_pre);
auto k = check_transpose(k_pre);
auto v = check_transpose(v_pre);
// Checks if arr is fully row contiguous
auto is_contiguous = [](const array& arr) {
return arr.flags().row_contiguous;
};
const int heads = q.shape(-3);
// Returns true if the array is row contiguous except the sequence length
// dimension that can be sliced but with step=1.
auto is_contiguous_except_seq_len = [](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return strides[3] == 1 && strides[2] == shape[3] &&
strides[0] == strides[1] * shape[1];
};
uint query_sequence_length = q.shape(-2);
if (query_sequence_length >= 16) {
return sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, out, temporaries);
}
int tile_size = 64;
const int kv_seq_len = k.shape(-2);
if (kv_seq_len > 8000) {
tile_size = 128;
}
if (kv_seq_len > 16000) {
tile_size = 256;
}
if (kv_seq_len > 32000) {
tile_size = 512;
// Checks that the last two dims are row contiguous.
auto is_matrix_contiguous = [](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return strides[3] == 1 && strides[2] == shape[3];
};
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
}
const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
// Full attention mode
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
array o_partials(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
float32,
nullptr,
{});
o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}
array p_lse(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
array p_rowmaxes(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
temporaries.push_back(p_lse);
temporaries.push_back(p_rowmaxes);
temporaries.push_back(o_partials);
return sdpa_metal(
s,
d,
q,
k,
v,
p_lse,
p_rowmaxes,
o_partials,
heads,
tile_size,
n_tiles,
scale_,
out,
temporaries);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
} // namespace mlx::core::fast

View File

@ -618,40 +618,38 @@ array scaled_dot_product_attention(
};
auto stream = to_stream(s);
const size_t value_head_dim = v.shape(-1);
const size_t query_head_dim = q.shape(-1);
const bool supported_head_dim =
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
const bool supported_head_dim_self_attn =
query_head_dim == 64 || query_head_dim == 128;
const size_t query_sequence_length = q.shape(2);
const bool supports_full_self_attention = query_sequence_length >= 16 &&
!mask.has_value() && supported_head_dim_self_attn &&
bool implementation_supports_use_case = query_head_dim == value_head_dim;
const bool sdpa_vector_supported_head_dim =
query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
const bool sdpa_full_supported_head_dim =
query_head_dim == 64 || query_head_dim == 128;
const bool supports_sdpa_full = query_sequence_length >= threshold &&
!mask.has_value() && sdpa_full_supported_head_dim &&
n_q_heads == n_kv_heads && final_type != bfloat16 &&
stream.device == Device::gpu;
// fast decoding gpu shader
bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 &&
!mask.has_value() && supported_head_dim && final_type != bfloat16 &&
const bool supports_sdpa_vector = query_sequence_length == 1 &&
!mask.has_value() && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu;
bool implementation_supports_use_case =
supports_sdpa || supports_full_self_attention;
// sdpa gpu shader is disabled except for memory efficient opt-in
const int seq_for_threshold = queries.shape(2);
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
implementation_supports_use_case &= use_memory_efficient_impl;
implementation_supports_use_case &=
supports_sdpa_full || supports_sdpa_vector;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
auto out = array(
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
{q, k, v});
return out;
}
if (mask.has_value()) {