Fused attention for single query (#1497)
This commit is contained in:
parent
9dd72cd421
commit
50d8bed468
49 benchmarks/python/sdpa_vector_bench.py Normal file
@@ -0,0 +1,49 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 1024
H = 32
H_k = 32 // 4
D = 128


def attention(q, k, v):
    B, Hq, L, D = q.shape
    _, Hk, S, _ = k.shape
    q = q.reshape(B, Hk, Hq // Hk, L, D)
    k = k[:, :, None, :, :]
    v = v[:, :, None, :, :]
    s = q @ k.transpose(0, 1, 2, 4, 3)
    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
    o = p @ v
    return o.reshape(B, Hq, L, D)


def sdpa(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
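Beyond timing, the fused path can be checked numerically against the unfused reference. A minimal sketch, assuming the definitions above (attention, sdpa, and the module constants H, H_k, L, D) are in scope:

import mlx.core as mx

q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
# Both paths use scale=1.0, so the outputs should agree to float tolerance.
print(mx.allclose(attention(q, k, v), sdpa(q, k, v), atol=1e-4))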
@@ -30,8 +30,9 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
-             steel/defines.h steel/gemm/transforms.h steel/utils.h)
+build_kernel(
+  scaled_dot_product_attention scaled_dot_product_attention_params.h
+  sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)

 set(STEEL_HEADERS
   steel/defines.h
@@ -1,11 +1,11 @@
 #include <metal_simdgroup>
 #include <metal_stdlib>

+#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+#include "mlx/backend/metal/kernels/sdpa_vector.h"
 #include "mlx/backend/metal/kernels/steel/defines.h"
 #include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"
 #include "mlx/backend/metal/kernels/utils.h"

-#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 using namespace metal;

 using namespace mlx::steel;
@@ -886,6 +886,9 @@ template <
   }
 }

+// clang-format off
+
+// SDPA full instantiations
 #define instantiate_fast_inference_self_attention_kernel( \
     itype, otype, bm, bn, bk, wm, wn) \
   template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
@@ -922,548 +925,28 @@ instantiate_fast_inference_self_attention_kernel(
 instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
 instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);

-template <
-    typename T,
-    typename T2,
-    typename T4,
-    uint16_t TILE_SIZE_CONST,
-    uint16_t NSIMDGROUPS>
-[[kernel]] void fast_inference_sdpa_compute_partials_template(
-    const device T* Q [[buffer(0)]],
-    const device T* K [[buffer(1)]],
-    const device T* V [[buffer(2)]],
-    const device uint64_t& L [[buffer(3)]],
-    const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
-    device float* O_partials [[buffer(5)]],
-    device float* p_lse [[buffer(6)]],
-    device float* p_maxes [[buffer(7)]],
-    threadgroup T* threadgroup_block [[threadgroup(0)]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint3 tid [[threadgroup_position_in_grid]]) {
-  constexpr const size_t DK = 128;
-  constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
-  constexpr const size_t THREADS_PER_SIMDGROUP = 32;
-  constexpr const uint iter_offset = NSIMDGROUPS * 4;
-  const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
-  uint kv_head_offset_factor = tid.x;
-  if (is_gqa) {
-    int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
-    kv_head_offset_factor = tid.x / q_kv_head_ratio;
-  }
-  constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
-  constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
-      TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
-  constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
-  constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
-      SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
-      NSIMDGROUPS;
-
-  threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
-#pragma clang loop unroll(full)
-  for (uint i = 0; i < 8; i++) {
-    smemFlush
-        [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
-         i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-  // TODO: multiple query sequence length for speculative decoding
-  const uint tgroup_query_head_offset =
-      tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
-
-  const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
-  const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
-  const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
-
-  const device T* baseK =
-      K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
-  const device T* baseQ = Q + tgroup_query_head_offset;
-
-  device T4* simdgroupQueryData = (device T4*)baseQ;
-
-  constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
-  float threadAccum[ACCUM_PER_GROUP];
-
-#pragma clang loop unroll(full)
-  for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
-       threadAccumIndex++) {
-    threadAccum[threadAccumIndex] = -INFINITY;
-  }
-
-  uint KROW_ACCUM_INDEX = 0;
-
-  const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
-  const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
-  const bool LAST_TILE_ALIGNED =
-      (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
-
-  T4 thread_data_x4;
-  T4 thread_data_y4;
-  if (!LAST_TILE || LAST_TILE_ALIGNED) {
-    thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
-#pragma clang loop unroll(full)
-    for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
-         KROW += NSIMDGROUPS) {
-      const uint KROW_OFFSET = KROW * DK;
-      const device T* baseKRow = baseK + KROW_OFFSET;
-      device T4* keysData = (device T4*)baseKRow;
-      thread_data_y4 = *(keysData + simd_lane_id);
-      T kq_scalar = dot(thread_data_x4, thread_data_y4);
-      threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
-      KROW_ACCUM_INDEX++;
-    }
-  } else {
-    thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
-    const uint START_ROW = tid.y * TILE_SIZE_CONST;
-    const device T* baseKThisHead =
-        K + tgroup_k_batch_offset + tgroup_k_head_offset;
-
-    for (size_t KROW = START_ROW + simd_group_id; KROW < L;
-         KROW += NSIMDGROUPS) {
-      const uint KROW_OFFSET = KROW * DK;
-      const device T* baseKRow = baseKThisHead + KROW_OFFSET;
-      device T4* keysData = (device T4*)baseKRow;
-      thread_data_y4 = *(keysData + simd_lane_id);
-      T kq_scalar = dot(thread_data_x4, thread_data_y4);
-      threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
-      KROW_ACCUM_INDEX++;
-    }
-  }
-  threadgroup float* smemP = (threadgroup float*)threadgroup_block;
-
-#pragma clang loop unroll(full)
-  for (size_t i = 0; i < P_VEC4; i++) {
-    thread_data_x4 =
-        T4(threadAccum[4 * i],
-           threadAccum[4 * i + 1],
-           threadAccum[4 * i + 2],
-           threadAccum[4 * i + 3]);
-    simdgroup_barrier(mem_flags::mem_none);
-    thread_data_y4 = simd_sum(thread_data_x4);
-    if (simd_lane_id == 0) {
-      const uint base_smem_p_offset = i * iter_offset + simd_group_id;
-      smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
-      smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
-      smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
-      smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
-    }
-  }
-
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  float groupMax;
-  float lse = 0.f;
-
-  constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
-  constexpr const size_t ACCUM_ARRAY_LENGTH =
-      TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
-  float4 pvals[ACCUM_ARRAY_LENGTH];
-
-#pragma clang loop unroll(full)
-  for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
-       accum_array_iter++) {
-    pvals[accum_array_iter] = float4(-INFINITY);
-  }
-
-  if (TILE_SIZE_CONST == 64) {
-    threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
-    float2 vals = smemPtrFlt2[simd_lane_id];
-    vals *= params.INV_ALPHA;
-    float maxval = max(vals.x, vals.y);
-    simdgroup_barrier(mem_flags::mem_none);
-    groupMax = simd_max(maxval);
-
-    float2 expf_shifted = exp(vals - groupMax);
-    float sumExpLocal = expf_shifted.x + expf_shifted.y;
-    simdgroup_barrier(mem_flags::mem_none);
-    float tgroupExpSum = simd_sum(sumExpLocal);
-
-    lse = log(tgroupExpSum);
-    float2 local_p_hat = expf_shifted / tgroupExpSum;
-    pvals[0].x = local_p_hat.x;
-    pvals[0].y = local_p_hat.y;
-    smemPtrFlt2[simd_lane_id] = float2(0.f);
-  }
-  constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
-  constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
-
-  if (TILE_SIZE_LARGER_THAN_64) {
-    float maxval = -INFINITY;
-    threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
-#pragma clang loop unroll(full)
-    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
-      float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
-      vals *= params.INV_ALPHA;
-      pvals[i] = vals;
-      maxval = fmax3(vals.x, vals.y, maxval);
-      maxval = fmax3(vals.z, vals.w, maxval);
-    }
-    simdgroup_barrier(mem_flags::mem_none);
-    groupMax = simd_max(maxval);
-
-    float sumExpLocal = 0.f;
-#pragma clang loop unroll(full)
-    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
-      pvals[i] = exp(pvals[i] - groupMax);
-      sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
-    }
-    simdgroup_barrier(mem_flags::mem_none);
-    float tgroupExpSum = simd_sum(sumExpLocal);
-    lse = log(tgroupExpSum);
-#pragma clang loop unroll(full)
-    for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
-      pvals[i] = pvals[i] / tgroupExpSum;
-      smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
-    }
-  }
-
-  threadgroup T* smemV = (threadgroup T*)threadgroup_block;
-
-  const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
-  const size_t v_head_offset = kv_head_offset_factor * L * DK;
-
-  const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
-  const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
-  device T* baseV = (device T*)V + v_offset;
-
-  threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
-
-  if (!LAST_TILE || LAST_TILE_ALIGNED) {
-#pragma clang loop unroll(full)
-    for (size_t col = 0; col < MATRIX_COLS; col++) {
-      uint matrix_load_loop_iter = 0;
-      constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
-
-      for (size_t tile_start = simd_group_id;
-           tile_start < TILE_SIZE_CONST_DIV_8;
-           tile_start += NSIMDGROUPS) {
-        simdgroup_matrix<T, 8, 8> tmp;
-        ulong simdgroup_matrix_offset =
-            matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
-            simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
-        ulong2 matrixOrigin =
-            ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
-        simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
-        const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
-        const ulong elemsPerRowSmem = TILE_SIZE_CONST;
-        simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
-        matrix_load_loop_iter++;
-      };
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      if (TILE_SIZE_CONST == 64) {
-        T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
-        uint loop_iter = 0;
-        threadgroup float* oPartialSmem =
-            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
-
-#pragma clang loop unroll(full)
-        for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
-             row += NSIMDGROUPS) {
-          threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
-          threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
-          T2 v_local = *(smemV2 + simd_lane_id);
-
-          T val = dot(local_p_hat, v_local);
-          simdgroup_barrier(mem_flags::mem_none);
-
-          T row_sum = simd_sum(val);
-          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
-              float(row_sum);
-          loop_iter++;
-        }
-      }
-
-      if (TILE_SIZE_CONST > 64) {
-        constexpr const size_t TILE_SIZE_CONST_DIV_128 =
-            (TILE_SIZE_CONST + 1) / 128;
-        threadgroup float* oPartialSmem =
-            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
-        uint loop_iter = 0;
-        for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
-             row += NSIMDGROUPS) {
-          threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
-
-          T row_sum = 0.f;
-          for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
-            threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
-            T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
-            T4 p_local = T4(pvals[i]);
-            T val = dot(p_local, v_local);
-            row_sum += val;
-          }
-          simdgroup_barrier(mem_flags::mem_none);
-          row_sum = simd_sum(row_sum);
-          oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
-              float(row_sum);
-          loop_iter++;
-        }
-      }
-    }
-  } else {
-    const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
-    const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
-    const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
-    constexpr const int ROWS_PER_ITER = 8;
-#pragma clang loop unroll(full)
-    for (size_t col = 0; col < MATRIX_COLS; col++) {
-      uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
-      int32_t tile_start;
-      for (tile_start =
-               START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
-           tile_start < MAX_START_ROW;
-           tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
-        simdgroup_matrix<T, 8, 8> tmp;
-        ulong2 matrixOrigin =
-            ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
-        simdgroup_load(
-            tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
-        const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
-        constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
-        simdgroup_store(
-            tmp,
-            smemV,
-            elemsPerRowSmem,
-            matrixOriginSmem,
-            /* transpose */ false);
-        smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
-      };
-
-      tile_start =
-          ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
-
-      const int32_t INT_L = int32_t(L);
-      for (int row_index = tile_start + simd_group_id; row_index < INT_L;
-           row_index += NSIMDGROUPS) {
-        if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
-          const uint elems_per_row_gmem = DK;
-          const uint col_index_v_gmem =
-              col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
-          const uint row_index_v_gmem = row_index;
-
-          const uint elems_per_row_smem = TILE_SIZE_CONST;
-          const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
-          const uint row_index_v_smem = simd_lane_id;
-
-          const uint scalar_offset_gmem =
-              row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
-          const uint scalar_offset_smem =
-              row_index_v_smem * elems_per_row_smem + col_index_v_smem;
-          T vdata = T(*(baseVThisHead + scalar_offset_gmem));
-          smemV[scalar_offset_smem] = vdata;
-          smem_col_index += NSIMDGROUPS;
-        }
-      }
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      if (TILE_SIZE_CONST == 64) {
-        T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
-        threadgroup float* oPartialSmem =
-            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
-        for (size_t smem_row_index = simd_group_id;
-             smem_row_index < ROWS_PER_ITER;
-             smem_row_index += NSIMDGROUPS) {
-          threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
-          threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
-          T2 v_local = *(smemV2 + simd_lane_id);
-          T val = dot(local_p_hat, v_local);
-          simdgroup_barrier(mem_flags::mem_none);
-          T row_sum = simd_sum(val);
-          oPartialSmem[smem_row_index] = float(row_sum);
-        }
-      }
-
-      if (TILE_SIZE_CONST > 64) {
-        threadgroup float* oPartialSmem =
-            smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
-        uint loop_count = 0;
-        for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
-             row_index += NSIMDGROUPS) {
-          T row_sum = 0.f;
-          for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
-               tile_iters++) {
-            threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
-            threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
-            T4 v_local =
-                *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
-            T4 p_local = T4(pvals[tile_iters]);
-            row_sum += dot(p_local, v_local);
-          }
-          simdgroup_barrier(mem_flags::mem_none);
-          row_sum = simd_sum(row_sum);
-          oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
-              float(row_sum);
-          loop_count++;
-        }
-      }
-    }
-  }
-
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  if (simd_group_id == 0) {
-    threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
-    float4 vals = *(oPartialVec4 + simd_lane_id);
-    device float* oPartialGmem =
-        O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
-    device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
-    oPartialGmemVec4[simd_lane_id] = vals;
-  }
-
-  if (simd_group_id == 0 && simd_lane_id == 0) {
-    const uint tileIndex = tid.y;
-    const uint gmem_partial_scalar_offset =
-        tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
-        tileIndex;
-    p_lse[gmem_partial_scalar_offset] = lse;
-    p_maxes[gmem_partial_scalar_offset] = groupMax;
-  }
-}
-
-#define instantiate_fast_inference_sdpa_to_partials_kernel( \
-    itype, itype2, itype4, tile_size, nsimdgroups) \
-  template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
-                       "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
-  fast_inference_sdpa_compute_partials_template< \
-      itype, \
-      itype2, \
-      itype4, \
-      tile_size, \
-      nsimdgroups>( \
-      const device itype* Q [[buffer(0)]], \
-      const device itype* K [[buffer(1)]], \
-      const device itype* V [[buffer(2)]], \
-      const device uint64_t& L [[buffer(3)]], \
-      const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
-      device float* O_partials [[buffer(5)]], \
-      device float* p_lse [[buffer(6)]], \
-      device float* p_maxes [[buffer(7)]], \
-      threadgroup itype* threadgroup_block [[threadgroup(0)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[threadgroup_position_in_grid]]);
-
-// clang-format off
-#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
-    itype, itype2, itype4, tile_size) \
-  instantiate_fast_inference_sdpa_to_partials_kernel( \
-      itype, itype2, itype4, tile_size, 4) \
-  instantiate_fast_inference_sdpa_to_partials_kernel( \
-      itype, itype2, itype4, tile_size, 8) // clang-format on
-
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    float,
-    float2,
-    float4,
-    64);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    float,
-    float2,
-    float4,
-    128);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    float,
-    float2,
-    float4,
-    256);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    float,
-    float2,
-    float4,
-    512);
-
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    half,
-    half2,
-    half4,
-    64);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    half,
-    half2,
-    half4,
-    128);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    half,
-    half2,
-    half4,
-    256);
-instantiate_fast_inference_sdpa_to_partials_shapes_helper(
-    half,
-    half2,
-    half4,
-    512);
-
-template <typename T>
-void fast_inference_sdpa_reduce_tiles_template(
-    const device float* O_partials [[buffer(0)]],
-    const device float* p_lse [[buffer(1)]],
-    const device float* p_maxes [[buffer(2)]],
-    const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
-    device T* O [[buffer(4)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
-  constexpr const int DK = 128;
-  const ulong offset_rows =
-      tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
-  const device float* p_lse_row = p_lse + offset_rows;
-  const device float* p_rowmax_row = p_maxes + offset_rows;
-  // reserve some number of registers. this constitutes an assumption on max
-  // value of KV TILES.
-  constexpr const uint8_t reserve = 128;
-  float p_lse_regs[reserve];
-  float p_rowmax_regs[reserve];
-  float weights[reserve];
-
-  float true_max = -INFINITY;
-  for (size_t i = 0; i < params.KV_TILES; i++) {
-    p_lse_regs[i] = float(*(p_lse_row + i));
-    p_rowmax_regs[i] = float(*(p_rowmax_row + i));
-    true_max = fmax(p_rowmax_regs[i], true_max);
-    weights[i] = exp(p_lse_regs[i]);
-  }
-
-  float denom = 0.f;
-  for (size_t i = 0; i < params.KV_TILES; i++) {
-    weights[i] *= exp(p_rowmax_regs[i] - true_max);
-    denom += weights[i];
-  }
-
-  const device float* O_partials_with_offset = O_partials +
-      tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
-      tid.x * DK * params.KV_TILES;
-
-  float o_value = 0.f;
-  for (size_t i = 0; i < params.KV_TILES; i++) {
-    float val = *(O_partials_with_offset + i * DK + lid.x);
-    o_value += val * weights[i] / denom;
-  }
-  device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
-  O_gmem[lid.x] = T(o_value);
-  return;
-}
-
-kernel void fast_inference_sdpa_reduce_tiles_float(
-    const device float* O_partials [[buffer(0)]],
-    const device float* p_lse [[buffer(1)]],
-    const device float* p_maxes [[buffer(2)]],
-    const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
-    device float* O [[buffer(4)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
-  fast_inference_sdpa_reduce_tiles_template<float>(
-      O_partials, p_lse, p_maxes, params, O, tid, lid);
-}
-
-kernel void fast_inference_sdpa_reduce_tiles_half(
-    const device float* O_partials [[buffer(0)]],
-    const device float* p_lse [[buffer(1)]],
-    const device float* p_maxes [[buffer(2)]],
-    const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
-    device half* O [[buffer(4)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
-  fast_inference_sdpa_reduce_tiles_template<half>(
-      O_partials, p_lse, p_maxes, params, O, tid, lid);
-}
+// SDPA vector instantiations
+#define instantiate_sdpa_vector(type, head_dim) \
+  template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
+  [[kernel]] void sdpa_vector<type, head_dim>( \
+      const device type* queries [[buffer(0)]], \
+      const device type* keys [[buffer(1)]], \
+      const device type* values [[buffer(2)]], \
+      device type* out [[buffer(3)]], \
+      const constant int& gqa_factor, \
+      const constant int& N, \
+      const constant size_t& k_stride, \
+      const constant float& scale, \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+#define instantiate_sdpa_vector_heads(type) \
+  instantiate_sdpa_vector(type, 64) \
+  instantiate_sdpa_vector(type, 96) \
+  instantiate_sdpa_vector(type, 128)
+
+instantiate_sdpa_vector_heads(float)
+instantiate_sdpa_vector_heads(bfloat16_t)
+instantiate_sdpa_vector_heads(float16_t)
+// clang-format on
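For reference, the removed two-pass scheme combines the per-tile partial outputs O_i with a log-sum-exp reduction. In the notation of fast_inference_sdpa_reduce_tiles_template above, where lse_i is the per-tile shifted log-sum-exp and m_i the per-tile row max:

$$m = \max_i m_i, \qquad w_i = e^{\mathrm{lse}_i + m_i - m}, \qquad O = \frac{\sum_i w_i\, O_i}{\sum_i w_i}$$

which is exactly weights[i] = exp(p_lse_regs[i]) * exp(p_rowmax_regs[i] - true_max), denom = sum of weights, and o_value accumulated as val * weights[i] / denom in the code.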
115 mlx/backend/metal/kernels/sdpa_vector.h Normal file
@@ -0,0 +1,115 @@
// Copyright © 2024 Apple Inc.

#include <metal_simdgroup>

using namespace metal;

template <typename T, int D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant float& scale,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;

  const int stride = BN * D;

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the query and 0 the output accumulator
  for (int i = 0; i < elem_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    // Read the key
    for (int i = 0; i < elem_per_thread; i++) {
      k[i] = keys[i];
    }

    // Compute the i-th score
    U score = 0;
    for (int i = 0; i < elem_per_thread; i++) {
      score += q[i] * k[i];
    }
    score = simd_sum(score);

    // Update the accumulators
    U new_max = max(max_score, score);
    U factor = fast::exp(max_score - new_max);
    U exp_score = fast::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;

    // Update the output accumulator
    for (int i = 0; i < elem_per_thread; i++) {
      o[i] = o[i] * factor + exp_score * values[i];
    }

    // Move the pointers to the next kv
    keys += stride;
    values += stride;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
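The kernel above is a single-pass (online softmax) formulation: each new score rescales the running max, normalizer, and output accumulator, so no score matrix is ever materialized. A scalar NumPy sketch of the per-key update it performs (illustrative only; the kernel additionally splits keys across the 32 simdgroups and the head dimension across the 32 lanes):

import numpy as np

def sdpa_vector_reference(q, K, V, scale=1.0):
    # q: (D,); K, V: (N, D). Mirrors the per-key loop in sdpa_vector.
    q = scale * q.astype(np.float32)
    max_score, sum_exp = -np.inf, 0.0
    o = np.zeros(V.shape[1], np.float32)
    for k, v in zip(K.astype(np.float32), V.astype(np.float32)):
        score = q @ k
        new_max = max(max_score, score)
        factor = np.exp(max_score - new_max)  # rescale old accumulators
        exp_score = np.exp(score - new_max)
        max_score = new_max
        sum_exp = sum_exp * factor + exp_score
        o = o * factor + exp_score * v
    return o / sum_exp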
@@ -1,20 +1,13 @@
-//
-// scaled_dot_product_attention.cpp
-// mlx
 // Copyright © 2024 Apple Inc.

 #include <algorithm>
 #include <cassert>
 #include <numeric>
 #include <sstream>

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

 namespace mlx::core::fast {
@@ -26,8 +19,7 @@ void sdpa_full_self_attention_metal(
     const array& k,
     const array& v,
     const float alpha,
-    array& out,
-    std::vector<array>& temporaries) {
+    array& out) {
   std::ostringstream kname_self_attention;
   kname_self_attention << "steel_gemm_attention_";
@@ -148,130 +140,58 @@ void sdpa_full_self_attention_metal(
   MTL::Size group_dims = MTL::Size(32, wm, wn);

   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
-
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
-  return;
 }

-void sdpa_metal(
+void sdpa_vector(
     const Stream& s,
     metal::Device& d,
     const array& q,
     const array& k,
     const array& v,
-    const array& p_lse,
-    const array& p_rowmaxes,
-    const array& o_partial,
-    const uint heads,
-    const uint tile_size,
-    const uint n_tiles,
-    const float alpha,
     array& out,
-    std::vector<array>& temporaries) {
-  std::ostringstream kname_partials;
-  kname_partials << "fast_inference_sdpa_compute_partials_";
+    float scale) {
+  // Set the kernel name
+  std::string kname;
+  kname.reserve(64);
+  kname += "sdpa_vector_";
+  kname += get_type_string(q.dtype());
+  kname += "_";
+  kname += std::to_string(q.shape(-1));

-  std::ostringstream kname_reduce;
-  std::string delimiter = "_";
-  kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
-
-  for (const auto& arr : {k, v, out}) {
-    if (arr.dtype() != q.dtype()) {
-      throw std::runtime_error(
-          "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
-    }
-  }
-
-  if (q.dtype() == float32) {
-    kname_partials << "float" + delimiter;
-    kname_reduce << "float";
-  } else if (q.dtype() == float16) {
-    kname_partials << "half" + delimiter;
-    kname_reduce << "half";
-  } else {
-    throw std::runtime_error(
-        "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
-  }
-
-  std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
-
-  uint nsimd = 8;
-  std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
-
-  // maximum number of splits == 128 at the moment (reserved tile registers in
-  // reduction kernel). this is arbitrary and could be changed in the shader.
-
-  std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
-  kname_partials << kname_suffix;
+  // Compute the necessary sizes
+  int gqa_factor = q.shape(1) / k.shape(1);
+  int N = k.shape(2);
+  int B = q.shape(0) * q.shape(1);
+  size_t stride = k.strides()[1];
+  MTL::Size group_dims(1024, 1, 1);
+  MTL::Size grid_dims(1, B, 1);

+  // Get the kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname_partials.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder->setComputePipelineState(kernel);

-  constexpr const uint batch = 1;
-  MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
-  MTL::Size group_dims = MTL::Size(32, nsimd, 1);
-
-  const uint64_t KV_sequence_length = k.shape(-2);
-  const uint query_sequence_length = q.shape(-2);
-  const uint n_q_heads = q.shape(1);
-  const uint n_kv_heads = k.shape(1);
-
-  MLXScaledDotProductAttentionParams params{
-      query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
-
-  compute_encoder.set_input_array(q, 0);
+  // Set its arguments
+  compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
   compute_encoder.set_input_array(k, 1);
   compute_encoder.set_input_array(v, 2);
-  compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
-  compute_encoder->setBytes(
-      &params, sizeof(MLXScaledDotProductAttentionParams), 4);
-  compute_encoder.set_input_array(o_partial, 5);
-  compute_encoder.set_input_array(p_lse, 6);
-  compute_encoder.set_input_array(p_rowmaxes, 7);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
+  compute_encoder->setBytes(&N, sizeof(int), 5);
+  compute_encoder->setBytes(&stride, sizeof(size_t), 6);
+  compute_encoder->setBytes(&scale, sizeof(float), 7);

-  constexpr const uint tgroupMemorySize = 32768;
-  compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
+  // Launch
   compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
-
-  {
-    auto kernel_accum = d.get_kernel(kname_reduce.str());
-    compute_encoder->setComputePipelineState(kernel_accum);
-    compute_encoder.set_input_array(o_partial, 0);
-    compute_encoder.set_input_array(p_lse, 1);
-    compute_encoder.set_input_array(p_rowmaxes, 2);
-    compute_encoder->setBytes(
-        &params, sizeof(MLXScaledDotProductAttentionParams), 3);
-    compute_encoder.set_output_array(out, 4);
-
-    MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
-    MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
-
-    compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
-
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
-    return;
-  }
 }

 } // namespace

 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
-  assert(inputs.size() >= 3);
-  if (!issubdtype(out.dtype(), floating)) {
-    throw std::runtime_error(
-        "[ScaledDotProductAttention] Does not yet support non-floating point types.");
-  }
+  assert(inputs.size() == 3);

-  if (inputs.size() == 4) {
-    out = fallback_(inputs)[0];
-    return;
-  }
-
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);

@@ -279,84 +199,75 @@ void ScaledDotProductAttention::eval_gpu(
   auto& k_pre = inputs[1];
   auto& v_pre = inputs[2];
+  auto& o = out;

-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
-  // Keep a vector with copies to be cleared in the completed buffer to release
-  // the arrays
-  std::vector<array> temporaries;
-  auto check_transpose = [&temporaries, &s](const array& arr) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (stx == arr.shape(-1) && sty == 1) {
-      return arr;
-    } else {
+  std::vector<array> copies;
+
+  // Define some copy functions to ensure the layout of the inputs is as
+  // expected.
+  auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
+    if (!predicate(arr)) {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
       copy_gpu(arr, arr_copy, CopyType::General, s);
-      temporaries.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
+      copies.push_back(arr_copy);
       return arr_copy;
     } else {
       return arr;
     }
   };

-  auto q = check_transpose(q_pre);
-  auto k = check_transpose(k_pre);
-  auto v = check_transpose(v_pre);
+  // Checks if arr is fully row contiguous
+  auto is_contiguous = [](const array& arr) {
+    return arr.flags().row_contiguous;
+  };

-  const int heads = q.shape(-3);
+  // Returns true if the array is row contiguous except the sequence length
+  // dimension that can be sliced but with step=1.
+  auto is_contiguous_except_seq_len = [](const array& arr) {
+    auto& strides = arr.strides();
+    auto& shape = arr.shape();
+    return strides[3] == 1 && strides[2] == shape[3] &&
+        strides[0] == strides[1] * shape[1];
+  };

-  uint query_sequence_length = q.shape(-2);
-  if (query_sequence_length >= 16) {
-    return sdpa_full_self_attention_metal(
-        s, d, q, k, v, scale_, out, temporaries);
-  }
-  int tile_size = 64;
-  const int kv_seq_len = k.shape(-2);
-  if (kv_seq_len > 8000) {
-    tile_size = 128;
-  }
-  if (kv_seq_len > 16000) {
-    tile_size = 256;
-  }
-  if (kv_seq_len > 32000) {
-    tile_size = 512;
-  }
+  // Checks that the last two dims are row contiguous.
+  auto is_matrix_contiguous = [](const array& arr) {
+    auto& strides = arr.strides();
+    auto& shape = arr.shape();
+    return strides[3] == 1 && strides[2] == shape[3];
+  };

+  // We are in vector mode ie single query
+  if (q_pre.shape(2) == 1) {
+    auto q = copy_unless(is_contiguous, q_pre);
+    auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
+    auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
+
+    // Donate the query if possible
+    if (q.is_donatable()) {
+      o.move_shared_buffer(q);
+    } else {
+      o.set_data(allocator::malloc_or_wait(o.nbytes()));
+    }
+
+    sdpa_vector(s, d, q, k, v, o, scale_);
+  }

-  const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
-
-  array o_partials(
-      {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
-      float32,
-      nullptr,
-      {});
-  o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
-
-  array p_lse(
-      {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
-  array p_rowmaxes(
-      {q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
-  p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
-  p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
-
-  temporaries.push_back(p_lse);
-  temporaries.push_back(p_rowmaxes);
-  temporaries.push_back(o_partials);
-
-  return sdpa_metal(
-      s,
-      d,
-      q,
-      k,
-      v,
-      p_lse,
-      p_rowmaxes,
-      o_partials,
-      heads,
-      tile_size,
-      n_tiles,
-      scale_,
-      out,
-      temporaries);
+  // Full attention mode
+  else {
+    auto q = copy_unless(is_matrix_contiguous, q_pre);
+    auto k = copy_unless(is_matrix_contiguous, k_pre);
+    auto v = copy_unless(is_matrix_contiguous, v_pre);
+    o.set_data(allocator::malloc_or_wait(o.nbytes()));
+
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
+  }
+
+  if (!copies.empty()) {
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
+          copies.clear();
+        });
+  }
 }

 } // namespace mlx::core::fast
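To summarize the host-side change: the vector path builds the kernel name from the input dtype and head dimension and launches one 32x32-thread threadgroup per (batch, head) pair. A hedged Python sketch of the sizing logic in sdpa_vector() above (function and parameter names are illustrative, not part of the source):

def vector_kernel_launch(q_shape, k_shape, k_strides, dtype="float16"):
    gqa_factor = q_shape[1] // k_shape[1]  # query heads per KV head
    N = k_shape[2]                         # number of keys
    B = q_shape[0] * q_shape[1]            # batch * query heads
    kname = f"sdpa_vector_{dtype}_{q_shape[-1]}"
    group_dims = (1024, 1, 1)              # 32 lanes x 32 simdgroups (BD x BN)
    grid_dims = (1, B, 1)                  # one threadgroup per (batch, head)
    return kname, grid_dims, group_dims, gqa_factor, N, k_strides[1]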
34 mlx/fast.cpp
@@ -618,40 +618,38 @@ array scaled_dot_product_attention(
   };

   auto stream = to_stream(s);
   const size_t value_head_dim = v.shape(-1);
   const size_t query_head_dim = q.shape(-1);
-  const bool supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
-
-  const bool supported_head_dim_self_attn =
-      query_head_dim == 64 || query_head_dim == 128;
   const size_t query_sequence_length = q.shape(2);
-  const bool supports_full_self_attention = query_sequence_length >= 16 &&
-      !mask.has_value() && supported_head_dim_self_attn &&
-      n_q_heads == n_kv_heads && final_type != bfloat16 &&
-      stream.device == Device::gpu;
-
-  // fast decoding gpu shader
-  bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 &&
-      !mask.has_value() && supported_head_dim && final_type != bfloat16 &&
-      stream.device == Device::gpu;
-  bool implementation_supports_use_case =
-      supports_sdpa || supports_full_self_attention;
-
-  // sdpa gpu shader is disabled except for memory efficient opt-in
-  const int seq_for_threshold = queries.shape(2);
-  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
-  implementation_supports_use_case &= use_memory_efficient_impl;
+
+  bool implementation_supports_use_case = query_head_dim == value_head_dim;
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
+  const bool sdpa_full_supported_head_dim =
+      query_head_dim == 64 || query_head_dim == 128;
+
+  const bool supports_sdpa_full = query_sequence_length >= threshold &&
+      !mask.has_value() && sdpa_full_supported_head_dim &&
+      n_q_heads == n_kv_heads && final_type != bfloat16 &&
+      stream.device == Device::gpu;
+
+  const bool supports_sdpa_vector = query_sequence_length == 1 &&
+      !mask.has_value() && sdpa_vector_supported_head_dim &&
+      stream.device == Device::gpu;
+
+  implementation_supports_use_case &=
+      supports_sdpa_full || supports_sdpa_vector;

   if (implementation_supports_use_case) {
     auto out_shape =
         std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
-    auto out = array(
+    return array(
         std::move(out_shape),
         final_type,
         std::make_shared<ScaledDotProductAttention>(
             stream, fallback, scale, false),
         {q, k, v});
-    return out;
   }

   if (mask.has_value()) {
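Restated as a hedged Python predicate, the new dispatch logic above reads roughly as follows (threshold is the memory-efficient opt-in sequence threshold from the surrounding code; the implementation also requires query_head_dim == value_head_dim, folded into the first conjunct here):

def choose_sdpa_kernel(q_seq_len, q_head_dim, v_head_dim, n_q_heads,
                       n_kv_heads, has_mask, is_bfloat16, on_gpu, threshold):
    dims_match = q_head_dim == v_head_dim
    full = (q_seq_len >= threshold and not has_mask
            and q_head_dim in (64, 128) and n_q_heads == n_kv_heads
            and not is_bfloat16 and on_gpu)
    # The vector path supports GQA and bfloat16, hence fewer conditions.
    vector = (q_seq_len == 1 and not has_mask
              and q_head_dim in (64, 96, 128) and on_gpu)
    if dims_match and (full or vector):
        return "vector" if vector else "full"
    return "fallback"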