Metal shaders for memory efficient self attention on large sequences (#964)

* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
This commit is contained in:
Brian Keene 2024-06-03 12:16:19 -04:00 committed by GitHub
parent 3576b547c5
commit 1865299a30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1244 additions and 9 deletions

View File

@ -0,0 +1,64 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@ -1,9 +1,926 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
using namespace mlx::steel;
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderFA {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoaderFA(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
METAL_FUNC void next(short n) {
src += n * tile_stride;
}
};
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMAFA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
ushort sid;
ushort slid;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMAFA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
slid = simd_lane_id;
sid = simd_group_id;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
short row = sm + tm + i * TM_stride;
float scale_value = Corrections[row];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
accum[0] *= scale_value;
accum[1] *= scale_value;
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_to_tgp_memory(
threadgroup U* C,
const int ldc,
short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
METAL_FUNC void clear_results() {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
}
}
}
};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct FastAttentionKernel {
STEEL_CONST short tgp_padding = 16 / sizeof(T);
STEEL_CONST short float_padding = 16 / sizeof(float);
STEEL_CONST short tgp_mem_size_q =
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_k =
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_v =
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
// maxes, rowsums, rescale
STEEL_CONST short tgp_mem_size_corrections =
4 * (BM * sizeof(float) + float_padding);
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
STEEL_CONST short tgp_mem_size = share_kv_smem
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections + tgp_mem_size_v;
STEEL_CONST short tgp_size = WM * WN * 32;
static_assert(transpose_q == false, "Expected Q not transposed.");
static_assert(transpose_k == true, "Expected K transposed.");
static_assert(transpose_v == false, "Expected V not transposed.");
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
using loader_q_t = BlockLoaderFA<
T,
transpose_q ? BK : BM,
transpose_q ? BM : BK,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
!transpose_q,
tgp_size>;
using loader_k_t = BlockLoaderFA<
T,
transpose_k ? BN : BK,
transpose_k ? BK : BN,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
transpose_k,
tgp_size>;
using loader_v_t = BlockLoaderFA<
T,
transpose_v ? BK : BN,
transpose_v ? BN : BK,
transpose_v ? BN + tgp_padding : BK + tgp_padding,
transpose_v,
tgp_size>;
using mma_qk_t = BlockMMAFA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
AccumType,
Epilogue>;
using mma_sv_t = BlockMMAFA<
T,
U,
BM,
BK,
BN,
WM,
WN,
false,
transpose_v,
BN + tgp_padding,
BK + tgp_padding,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_k_t& loader_b,
thread mma_qk_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
(void)tgp_bm;
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
}
static METAL_FUNC void initialize_corrections(
threadgroup float* C,
uint simd_lane_id,
uint simd_group_id) {
if (simd_group_id == 0) {
threadgroup float* maxes = C;
threadgroup float* sums = C + (BM + float_padding);
threadgroup float* o_rescale = sums + (BM + float_padding);
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
if (simd_lane_id < BM) {
maxes[simd_lane_id] = -INFINITY; // m_i
sums[simd_lane_id] = 0.f; // l_i
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
}
}
}
static METAL_FUNC void rescale_ss(
threadgroup T* Ss,
threadgroup float* Corrections,
uint simd_group_id,
uint simd_lane_id,
short2 local_blocks,
float alpha) {
if (simd_group_id == 0) {
short row_offset = BM + float_padding;
threadgroup float* maxes = Corrections;
threadgroup float* sums = Corrections + row_offset;
threadgroup float* o_rescale = sums + row_offset;
threadgroup float* output_scales = o_rescale + row_offset;
if (simd_lane_id < uint(local_blocks.y)) {
float m_i_old = maxes[simd_lane_id];
float l_i_old = sums[simd_lane_id];
float m_i_new = m_i_old;
float l_i_new = l_i_old;
short offset = simd_lane_id * (BN + tgp_padding);
float m_ij = -INFINITY;
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
m_ij = max(m_ij, val);
}
m_i_new = max(m_ij, m_i_new);
float rowsum = 0.f; // lij
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
float P_i_j = exp(val - m_ij);
rowsum += P_i_j;
P_i_j = P_i_j * exp(m_ij - m_i_new);
Ss[offset + j] = T(P_i_j);
}
l_i_new =
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
maxes[simd_lane_id] = m_i_new;
sums[simd_lane_id] = l_i_new;
float rescale = l_i_old * exp(m_i_old - m_i_new);
o_rescale[simd_lane_id] = rescale;
output_scales[simd_lane_id] = 1.0 / l_i_new;
}
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device U* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
threadgroup T* Qs [[threadgroup(0)]],
threadgroup T* Ks [[threadgroup(1)]],
threadgroup T* Ss [[threadgroup(2)]],
threadgroup T* Vs [[threadgroup(3)]],
threadgroup float* Corrections [[threadgroup(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in Q, O; and head in K, V.
const int c_row = tid_y * BM;
Q += transpose_q ? c_row : c_row * params->ldq;
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_q.load_safe(tile_dims_Q);
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
O += c_row * params->ldo;
// Prepare threadgroup mma operation
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
n_block++) {
short c_col = BN;
// Prepare threadgroup loading operations
short gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
threadgroup_barrier(mem_flags::mem_none);
///////////////////////////////////////////////////////////////////////////////
{ // Loop over K - unaligned case
if (tgp_bm == BM && tgp_bn_qk == BN) {
gemm_loop<true, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bn_qk == BN) {
gemm_loop<false, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else {
gemm_loop<false, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
}
}
mma_qk_op.store_result_to_tgp_memory(
Ss, BN + tgp_padding, short2(BN, BM));
threadgroup_barrier(mem_flags::mem_threadgroup);
rescale_ss(
Ss,
Corrections,
simd_group_id,
simd_lane_id,
short2(tgp_bn_qk, tgp_bm),
params->alpha);
loader_v.load_safe(short2(BK, tgp_bn_qk));
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(o_scales);
mma_softmax_sv_op.mma(Ss, Vs);
threadgroup float* final_output_scales =
Corrections + 3 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(final_output_scales);
loader_v.next();
loader_k.next(BN);
mma_qk_op.clear_results();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
}
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using attention_kernel = FastAttentionKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_v,
MN_aligned,
K_aligned>;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* Q_bstrides = batch_strides;
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
Q += batch_offsets.x;
K += batch_offsets.y;
V += batch_offsets.y;
} else {
Q += params->batch_stride_q * tid.z;
K += params->batch_stride_k * tid.z;
V += params->batch_stride_v * tid.z;
}
// same shape as input
O += params->batch_stride_o * tid.z;
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
if (attention_kernel::share_kv_smem) {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
} else {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
}
}
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
"_itype_" #itype)]] [[kernel]] void \
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
device otype* O [[buffer(3)]], \
const constant MLXFastAttentionParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
64,
2,
2);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
128,
2,
2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
template <
typename T,
typename T2,

View File

@ -4,6 +4,34 @@
#pragma once
struct MLXFastAttentionParams {
const int M;
const int N;
const int K;
const int ldq; // ldq == ldo
const int ldk;
const int ldv;
const int lds;
const int ldo;
const int tiles_n;
const int tiles_m;
const int batch_stride_q;
const int batch_stride_k;
const int batch_stride_v;
const int batch_stride_o;
const int swizzle_log;
const int gemm_n_iterations_aligned;
const int gemm_k_iterations_aligned;
const int gemm_sv_m_block_iterations;
const int batch_ndim;
const float alpha;
};
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;

View File

@ -19,6 +19,140 @@
namespace mlx::core::fast {
namespace {
void sdpa_full_self_attention_metal(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
const float alpha,
array& out,
std::vector<array>& temporaries) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
constexpr const int bm = 16;
constexpr const int bn = 16;
const int bk = q.shape(-1); // already forced to be 64 or 128
if (bk != 64 && bk != 128) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
}
constexpr const int wm = 2;
constexpr const int wn = 2;
std::string delimiter = "_";
kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
if (q.dtype() == float32) {
kname_self_attention << "itype" + delimiter + "float";
} else if (q.dtype() == float16) {
kname_self_attention << "itype" + delimiter + "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str());
compute_encoder->setComputePipelineState(kernel);
uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2);
uint qheads = q.shape(-3);
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
const int M = q.shape(-2);
const int N = M;
const int K = q.shape(-1);
const size_t batch_size_out = q.shape(0) * q.shape(1);
const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
const int dk = q.shape(-1);
const int ldq = dk;
const int ldk = dk;
const int ldv = dk;
const int lds = bn;
const int ldo = dk;
int tn = 1;
int tm = (M + bm - 1) / bm;
const int batch_stride_q = dk * query_sequence_length;
const int batch_stride_k = dk * query_sequence_length;
const int batch_stride_v = dk * query_sequence_length;
const int batch_stride_o = dk * query_sequence_length;
const int swizzle_log = 0;
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
const int batch_ndim = int(batch_shape.size());
MLXFastAttentionParams params{
(int)M,
(int)N,
(int)K,
ldq,
ldk,
ldv,
lds,
ldo,
tn,
tm,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
swizzle_log,
gemm_n_iterations_aligned,
gemm_k_iterations_aligned,
gemm_sv_m_block_iterations,
batch_ndim,
alpha};
const std::vector<size_t> batch_strides = {
(size_t)batch_stride_q,
(size_t)batch_stride_k,
(size_t)batch_stride_v,
(size_t)batch_stride_o};
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(MLXFastAttentionParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
void sdpa_metal(
const Stream& s,
@ -170,6 +304,12 @@ void ScaledDotProductAttention::eval_gpu(
auto v = check_transpose(v_pre);
const int heads = q.shape(-3);
uint query_sequence_length = q.shape(-2);
if (query_sequence_length >= 16) {
return sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, out, temporaries);
}
int tile_size = 64;
const int kv_seq_len = k.shape(-2);
if (kv_seq_len > 8000) {

View File

@ -527,11 +527,13 @@ array scaled_dot_product_attention(
/* generic implementation for use cases that Metal implementation does not
* support. For non-supported cases listed below, use MLX primitives:
* * CPU implementation
* * batch size > 1
* * query sequence length > 1
* * batch size > 1 for decoding or causal attention
* * query sequence length > 1 for decoding
* * query sequence length > 16 && non-null mask (causal attention)
* * non-null mask
* * dtype is not fp32 or fp16
*/
bool needs_mask = mask.has_value();
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
const std::vector<array>& inputs) {
@ -559,15 +561,29 @@ array scaled_dot_product_attention(
};
auto stream = to_stream(s);
constexpr const int supported_head_dim = 128;
const size_t query_head_dim = q.shape(-1);
const bool supported_head_dim =
query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128;
const bool supported_head_dim_self_attn =
query_head_dim == 64 || query_head_dim == 128;
const size_t query_sequence_length = q.shape(2);
bool implementation_supports_use_case = batch_dim == 1 &&
query_sequence_length == 1 && !mask.has_value() &&
query_head_dim == supported_head_dim && final_type != bfloat16 &&
const bool supports_full_self_attention = query_sequence_length >= 16 &&
!mask.has_value() && supported_head_dim_self_attn &&
n_q_heads == n_kv_heads && final_type != bfloat16 &&
stream.device == Device::gpu;
// TODO, update routing conditions post further tuning
// fast decoding gpu shader
bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 &&
!mask.has_value() && supported_head_dim && final_type != bfloat16 &&
stream.device == Device::gpu;
bool implementation_supports_use_case =
supports_sdpa || supports_full_self_attention;
// disabling full self attention until perf is tuned;
// likewise for sdpa
implementation_supports_use_case &= false;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});

View File

@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) {
v (array): Input values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (array, optional): An additive mask to apply to the query-key scores.
Returns:
array: The output array.
)pbdoc");

View File

@ -32,9 +32,80 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
return mlx_primitives_sdpa(q, k, v, scale)
class TestFastSDPA(mlx_tests.MLXTestCase):
class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
np.random.seed(0)
R = 20
L = R
Dk = 64
H = 3
scale = float(1.0 / np.sqrt(Dk))
q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32)
k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)
reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale, mask=None
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
dtypes = [np.float32]
Dk = 64
if self.is_apple_silicon:
dtypes.append(np.half)
for SEQUENCE_LENGTH in [63, 129, 400]:
for DTYPE in dtypes:
B = 2
H = 24
n_kv_heads = H
q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype(
DTYPE
)
k_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)
v_npy = np.random.normal(
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
).astype(DTYPE)
q_mlx = mx.array(q_npy)
k_mlx = mx.array(k_npy)
v_mlx = mx.array(v_npy)
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
rtol = 1e-3
atol = 1e-2
if SEQUENCE_LENGTH > 500:
rtol = 1e-2
if DTYPE == np.half:
rtol = 1e-2
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
class TestFastSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
np.random.seed(0)