Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 01:41:17 +08:00
Matrix Attention kernel (#1610)
* Rough INIT
* [WIP]: Loading and Matmuls added
* [WIP]: Reductions and min working aligned kernel at headdim = 64
* [WIP] Added headdim 80 for testing
* [WIP] Update dispatch params for testing
* [WIP] Add support for unaligned seq lengths - still looks messy
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Update sdpa_benchmarks
* Enable gqa support
* Update benchmark and switch off 128 headdim
* Update headdim 128 tuning
* Remove older fast attention code. Write out O strided
* Disable hd=128 until further optimizations
* Enable bf16
* Fix data size bug
* Enable attn build outside of jit
This commit is contained in:
parent
c79f6a4a8c
commit
02bec0bb6d
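For context, the fused kernel built in this commit is reached through mx.fast.scaled_dot_product_attention, which the updated benchmark below exercises against an unfused reference. A minimal usage sketch (shapes are illustrative only, including a GQA layout with fewer key/value heads than query heads):

import math
import mlx.core as mx

B, L, D = 1, 1024, 64          # batch, sequence length, head dim (illustrative)
n_q_heads, n_kv_heads = 32, 8  # GQA: more query heads than key/value heads

q = mx.random.uniform(shape=(B, n_q_heads, L, D))
k = mx.random.uniform(shape=(B, n_kv_heads, L, D))
v = mx.random.uniform(shape=(B, n_kv_heads, L, D))

o = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
mx.eval(o)  # output has shape (B, n_q_heads, L, D)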
@@ -1,62 +1,189 @@
# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
from time_utils import time_fn
import numpy as np

MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 40
N_iter_func = 8


def time_self_attention_primitives():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)

        def sdpa_primitives(qs, ks, vs, alpha):
            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
            o = p @ vs
            return o

        time_fn(sdpa_primitives, q, k, v, scale)


def bench(f, *args):
    for i in range(N_warmup):
        f(*args)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(*args)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def time_self_attention_sdpa():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)

        def sdpa_fused(qs, ks, vs, alpha):
            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
            return o

        time_fn(sdpa_fused, q, k, v, scale)


def mlx_sdpa_fused_inner(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)


def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)
    if f32softmax:
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
    else:
        scores = mx.softmax(scores, axis=-1)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out
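As a concrete shape illustration of the grouped-query branch above (numbers assumed for the example): with n_q_heads = 32 and n_kv_heads = 8, n_repeats = 4, so q is reshaped to (B, 8, 4, L, D) while k and v become (B, 8, 1, L, D); the repeat dimension then broadcasts through the two matmuls, and the output is reshaped back to (B, 32, L, D).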
|
||||
|
||||
|
||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
||||
shape_q = (
|
||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
||||
)
|
||||
shape_kv = (
|
||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
||||
)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
|
||||
scale = math.sqrt(1.0 / head_dim)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
|
||||
if transpose:
|
||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
||||
|
||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
||||
print(
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
)
|
||||
|
||||
return time_mlx_fused, time_mlx_unfused
|
||||
|
||||
|
||||
def get_gflop_count(B, M, N, K):
|
||||
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
dtypes = ("float16", "float32")[:1]
|
||||
transposes = (False,)
|
||||
|
||||
# fmt: off
|
||||
shapes_64 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 32, 32, 64, 32, 32),
|
||||
( 1, 64, 64, 64, 32, 32),
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 32),
|
||||
( 1, 2048, 2048, 64, 32, 32),
|
||||
( 1, 4096, 4096, 64, 32, 32),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 32),
|
||||
( 1, 2048, 2048, 80, 32, 32),
|
||||
( 1, 4096, 4096, 80, 32, 32),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 32),
|
||||
( 1, 2048, 2048, 128, 32, 32),
|
||||
( 1, 4096, 4096, 128, 32, 32),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
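The unfused reference above materializes the full score matrix before the softmax; the fused Metal kernel in this commit instead streams over blocks of keys and values with a running softmax. A NumPy sketch of that streaming recurrence for a single head (illustrative only, not the MLX implementation):

import numpy as np

def streaming_sdpa(q, k, v, block=64):
    # q: (Lq, D), k and v: (Lk, D); plain float32 reference
    scale = 1.0 / np.sqrt(q.shape[-1])
    row_max = np.full(q.shape[0], -np.inf)
    row_sum = np.zeros(q.shape[0])
    out = np.zeros((q.shape[0], v.shape[-1]))

    for start in range(0, k.shape[0], block):
        kb = k[start : start + block]
        vb = v[start : start + block]
        s = (q * scale) @ kb.T                     # scores for this KV block
        new_max = np.maximum(row_max, s.max(axis=-1))
        p = np.exp(s - new_max[:, None])           # exp(S - running row max)
        factor = np.exp(row_max - new_max)         # correction for earlier blocks
        row_sum = row_sum * factor + p.sum(axis=-1)
        out = out * factor[:, None] + p @ vb
        row_max = new_max

    return out / row_sum[:, None]                  # final 1 / l normalization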
|
||||
|
@@ -44,9 +44,7 @@ build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention scaled_dot_product_attention_params.h
|
||||
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
|
||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
@@ -68,6 +66,24 @@ set(STEEL_HEADERS
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
|
||||
set(STEEL_ATTN_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h
|
||||
steel/attn/attn.h
|
||||
steel/attn/loader.h
|
||||
steel/attn/mma.h
|
||||
steel/attn/params.h
|
||||
steel/attn/transforms.h
|
||||
steel/attn/kernels/steel_attention.h)
|
||||
|
||||
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
|
||||
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
|
@@ -1,930 +1,11 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoaderFA {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoaderFA(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out unneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
METAL_FUNC void next(short n) {
|
||||
src += n * tile_stride;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMAFA {
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = 8 * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = 8 * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Strides of A, B along reduction axis
|
||||
STEEL_CONST short simd_stride_a = {
|
||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
||||
STEEL_CONST short simd_stride_b = {
|
||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
||||
|
||||
// Jump between elements
|
||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
||||
|
||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
||||
|
||||
// Simdgroup matrices
|
||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
||||
|
||||
// Offsets within threadgroup
|
||||
const short tm;
|
||||
const short tn;
|
||||
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
ushort sid;
|
||||
ushort slid;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMAFA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short qid = simd_lane_id / 4;
|
||||
slid = simd_lane_id;
|
||||
sid = simd_group_id;
|
||||
|
||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset =
|
||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
||||
Bs_offset =
|
||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of 8
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += 8) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup A as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
Asimd[i].thread_elements()[0] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
||||
Asimd[i].thread_elements()[1] =
|
||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load elements from threadgroup B as simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
Bsimd[j].thread_elements()[0] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
||||
Bsimd[j].thread_elements()[1] =
|
||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Multiply and accumulate into result simdgroup matrices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
||||
|
||||
simdgroup_multiply_accumulate(
|
||||
results[i * TN + j_serp],
|
||||
Asimd[i],
|
||||
Bsimd[j_serp],
|
||||
results[i * TN + j_serp]);
|
||||
}
|
||||
}
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
|
||||
// Loop over all simdgroup tiles
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
short row = sm + tm + i * TM_stride;
|
||||
float scale_value = Corrections[row];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = results[i * TN + j].thread_elements();
|
||||
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
accum[0] *= scale_value;
|
||||
accum[1] *= scale_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out C
|
||||
C[offset] = outs[0];
|
||||
C[offset + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_to_tgp_memory(
|
||||
threadgroup U* C,
|
||||
const int ldc,
|
||||
short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {
|
||||
epilogue_op.apply(accum[0], C[offset_c]),
|
||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
||||
|
||||
// Write out D
|
||||
D[offset_d] = outs[0];
|
||||
D[offset_d + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void clear_results() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_q,
|
||||
bool transpose_k,
|
||||
bool transpose_v,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct FastAttentionKernel {
|
||||
STEEL_CONST short tgp_padding = 16 / sizeof(T);
|
||||
STEEL_CONST short float_padding = 16 / sizeof(float);
|
||||
STEEL_CONST short tgp_mem_size_q =
|
||||
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_k =
|
||||
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_v =
|
||||
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
|
||||
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
|
||||
|
||||
// maxes, rowsums, rescale
|
||||
STEEL_CONST short tgp_mem_size_corrections =
|
||||
4 * (BM * sizeof(float) + float_padding);
|
||||
|
||||
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
|
||||
|
||||
STEEL_CONST short tgp_mem_size = share_kv_smem
|
||||
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
|
||||
tgp_mem_size_corrections
|
||||
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
|
||||
tgp_mem_size_corrections + tgp_mem_size_v;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
static_assert(transpose_q == false, "Expected Q not transposed.");
|
||||
static_assert(transpose_k == true, "Expected K transposed.");
|
||||
static_assert(transpose_v == false, "Expected V not transposed.");
|
||||
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
|
||||
|
||||
using loader_q_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_q ? BK : BM,
|
||||
transpose_q ? BM : BK,
|
||||
transpose_q ? BM + tgp_padding : BK + tgp_padding,
|
||||
!transpose_q,
|
||||
tgp_size>;
|
||||
|
||||
using loader_k_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_k ? BN : BK,
|
||||
transpose_k ? BK : BN,
|
||||
transpose_k ? BK + tgp_padding : BN + tgp_padding,
|
||||
transpose_k,
|
||||
tgp_size>;
|
||||
|
||||
using loader_v_t = BlockLoaderFA<
|
||||
T,
|
||||
transpose_v ? BK : BN,
|
||||
transpose_v ? BN : BK,
|
||||
transpose_v ? BN + tgp_padding : BK + tgp_padding,
|
||||
transpose_v,
|
||||
tgp_size>;
|
||||
|
||||
using mma_qk_t = BlockMMAFA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_q,
|
||||
transpose_k,
|
||||
transpose_q ? BM + tgp_padding : BK + tgp_padding,
|
||||
transpose_k ? BK + tgp_padding : BN + tgp_padding,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
using mma_sv_t = BlockMMAFA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BK,
|
||||
BN,
|
||||
WM,
|
||||
WN,
|
||||
false,
|
||||
transpose_v,
|
||||
BN + tgp_padding,
|
||||
BK + tgp_padding,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_k_t& loader_b,
|
||||
thread mma_qk_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
(void)tgp_bm;
|
||||
|
||||
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
// not valid for gemm_k_iterations > 1 (so, BK == d_k)
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void initialize_corrections(
|
||||
threadgroup float* C,
|
||||
uint simd_lane_id,
|
||||
uint simd_group_id) {
|
||||
if (simd_group_id == 0) {
|
||||
threadgroup float* maxes = C;
|
||||
threadgroup float* sums = C + (BM + float_padding);
|
||||
threadgroup float* o_rescale = sums + (BM + float_padding);
|
||||
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
|
||||
|
||||
if (simd_lane_id < BM) {
|
||||
maxes[simd_lane_id] = -INFINITY; // m_i
|
||||
sums[simd_lane_id] = 0.f; // l_i
|
||||
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
|
||||
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static METAL_FUNC void rescale_ss(
|
||||
threadgroup T* Ss,
|
||||
threadgroup float* Corrections,
|
||||
uint simd_group_id,
|
||||
uint simd_lane_id,
|
||||
short2 local_blocks,
|
||||
float alpha) {
|
||||
if (simd_group_id == 0) {
|
||||
short row_offset = BM + float_padding;
|
||||
threadgroup float* maxes = Corrections;
|
||||
threadgroup float* sums = Corrections + row_offset;
|
||||
threadgroup float* o_rescale = sums + row_offset;
|
||||
threadgroup float* output_scales = o_rescale + row_offset;
|
||||
|
||||
if (simd_lane_id < uint(local_blocks.y)) {
|
||||
float m_i_old = maxes[simd_lane_id];
|
||||
float l_i_old = sums[simd_lane_id];
|
||||
|
||||
float m_i_new = m_i_old;
|
||||
float l_i_new = l_i_old;
|
||||
|
||||
short offset = simd_lane_id * (BN + tgp_padding);
|
||||
|
||||
float m_ij = -INFINITY;
|
||||
|
||||
for (short j = 0; j < local_blocks.x; j++) {
|
||||
float val = alpha * float(Ss[offset + j]);
|
||||
m_ij = max(m_ij, val);
|
||||
}
|
||||
|
||||
m_i_new = max(m_ij, m_i_new);
|
||||
|
||||
float rowsum = 0.f; // lij
|
||||
|
||||
for (short j = 0; j < local_blocks.x; j++) {
|
||||
float val = alpha * float(Ss[offset + j]);
|
||||
float P_i_j = exp(val - m_ij);
|
||||
rowsum += P_i_j;
|
||||
P_i_j = P_i_j * exp(m_ij - m_i_new);
|
||||
Ss[offset + j] = T(P_i_j);
|
||||
}
|
||||
|
||||
l_i_new =
|
||||
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
|
||||
maxes[simd_lane_id] = m_i_new;
|
||||
sums[simd_lane_id] = l_i_new;
|
||||
float rescale = l_i_old * exp(m_i_old - m_i_new);
|
||||
o_rescale[simd_lane_id] = rescale;
|
||||
output_scales[simd_lane_id] = 1.0 / l_i_new;
|
||||
}
|
||||
}
|
||||
}
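// Summarizing the two helpers above: the four correction rows (maxes m_i,
// sums l_i, o_rescale, output_rescale) carry the streaming-softmax state
// between KV blocks. Per row i, over one block of columns j, rescale_ss
// computes
//   m_i_new = max(m_i_old, max_j(alpha * S_ij))
//   P_ij    = exp(alpha * S_ij - m_i_new)
//   l_i_new = exp(m_i_old - m_i_new) * l_i_old + sum_j P_ij
// and stores o_rescale = l_i_old * exp(m_i_old - m_i_new) and
// output_rescale = 1 / l_i_new, so that run() below forms
//   O_i <- (o_rescale * O_i + sum_j P_ij * V_j) * output_rescale.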
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device U* O [[buffer(3)]],
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]],
|
||||
threadgroup T* Qs [[threadgroup(0)]],
|
||||
threadgroup T* Ks [[threadgroup(1)]],
|
||||
threadgroup T* Ss [[threadgroup(2)]],
|
||||
threadgroup T* Vs [[threadgroup(3)]],
|
||||
threadgroup float* Corrections [[threadgroup(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in Q, O; and head in K, V.
|
||||
const int c_row = tid_y * BM;
|
||||
|
||||
Q += transpose_q ? c_row : c_row * params->ldq;
|
||||
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
loader_q.load_safe(tile_dims_Q);
|
||||
|
||||
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
|
||||
|
||||
O += c_row * params->ldo;
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
|
||||
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
|
||||
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
|
||||
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
|
||||
n_block++) {
|
||||
short c_col = BN;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
short gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
{ // Loop over K - unaligned case
|
||||
|
||||
if (tgp_bm == BM && tgp_bn_qk == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
} else if (tgp_bn_qk == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
Qs,
|
||||
Ks,
|
||||
gemm_k_iterations,
|
||||
loader_k,
|
||||
mma_qk_op,
|
||||
tgp_bm,
|
||||
tgp_bn_qk);
|
||||
}
|
||||
}
|
||||
|
||||
mma_qk_op.store_result_to_tgp_memory(
|
||||
Ss, BN + tgp_padding, short2(BN, BM));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
rescale_ss(
|
||||
Ss,
|
||||
Corrections,
|
||||
simd_group_id,
|
||||
simd_lane_id,
|
||||
short2(tgp_bn_qk, tgp_bm),
|
||||
params->alpha);
|
||||
|
||||
loader_v.load_safe(short2(BK, tgp_bn_qk));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
|
||||
mma_softmax_sv_op.rescale_output(o_scales);
|
||||
|
||||
mma_softmax_sv_op.mma(Ss, Vs);
|
||||
|
||||
threadgroup float* final_output_scales =
|
||||
Corrections + 3 * (BM + float_padding);
|
||||
|
||||
mma_softmax_sv_op.rescale_output(final_output_scales);
|
||||
|
||||
loader_v.next();
|
||||
loader_k.next(BN);
|
||||
|
||||
mma_qk_op.clear_results();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_q,
|
||||
bool transpose_k,
|
||||
bool transpose_v,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device T* O [[buffer(3)]],
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using attention_kernel = FastAttentionKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_q,
|
||||
transpose_k,
|
||||
transpose_v,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* Q_bstrides = batch_strides;
|
||||
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
|
||||
|
||||
Q += batch_offsets.x;
|
||||
K += batch_offsets.y;
|
||||
V += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
Q += params->batch_stride_q * tid.z;
|
||||
K += params->batch_stride_k * tid.z;
|
||||
V += params->batch_stride_v * tid.z;
|
||||
}
|
||||
|
||||
// same shape as input
|
||||
O += params->batch_stride_o * tid.z;
|
||||
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
|
||||
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
|
||||
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
|
||||
|
||||
if (attention_kernel::share_kv_smem) {
|
||||
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
|
||||
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
|
||||
attention_kernel::run(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
O,
|
||||
params,
|
||||
Qs,
|
||||
Ks,
|
||||
Ss,
|
||||
Vs,
|
||||
Corrections,
|
||||
simd_lane_id,
|
||||
simd_group_id,
|
||||
tid,
|
||||
lid);
|
||||
} else {
|
||||
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
|
||||
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
|
||||
attention_kernel::run(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
O,
|
||||
params,
|
||||
Qs,
|
||||
Ks,
|
||||
Ss,
|
||||
Vs,
|
||||
Corrections,
|
||||
simd_lane_id,
|
||||
simd_group_id,
|
||||
tid,
|
||||
lid);
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
|
||||
// SDPA full instantiations
|
||||
#define instantiate_fast_inference_self_attention_kernel( \
|
||||
itype, otype, bm, bn, bk, wm, wn) \
|
||||
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
|
||||
"_itype_" #itype)]] [[kernel]] void \
|
||||
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
|
||||
const device itype* Q [[buffer(0)]], \
|
||||
const device itype* K [[buffer(1)]], \
|
||||
const device itype* V [[buffer(2)]], \
|
||||
device otype* O [[buffer(3)]], \
|
||||
const constant MLXFastAttentionParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
instantiate_fast_inference_self_attention_kernel(
|
||||
float,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
64,
|
||||
2,
|
||||
2);
|
||||
instantiate_fast_inference_self_attention_kernel(
|
||||
float,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
128,
|
||||
2,
|
||||
2);
|
||||
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
|
||||
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
|
||||
|
||||
// SDPA vector instantiations
|
||||
#define instantiate_sdpa_vector(type, head_dim) \
|
||||
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
|
||||
|
@@ -1,42 +0,0 @@
|
||||
//
|
||||
// scaled_dot_product_attention_params.h
|
||||
// mlx
|
||||
|
||||
#pragma once
|
||||
|
||||
struct MLXFastAttentionParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int ldq; // ldq == ldo
|
||||
const int ldk;
|
||||
const int ldv;
|
||||
const int lds;
|
||||
const int ldo;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_q;
|
||||
const int batch_stride_k;
|
||||
const int batch_stride_v;
|
||||
const int batch_stride_o;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_n_iterations_aligned;
|
||||
const int gemm_k_iterations_aligned;
|
||||
const int gemm_sv_m_block_iterations;
|
||||
|
||||
const int batch_ndim;
|
||||
const float alpha;
|
||||
};
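As a worked example of how these fields relate (illustrative numbers, not taken from the host code): for a single head with query length M = 1024, key length N = 1024, head dimension K = 64 and a 16 x 16 threadgroup tile, tiles_m = tiles_n = 1024 / 16 = 64, and alpha is the softmax scale 1 / sqrt(K) = 0.125.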
|
||||
|
||||
struct MLXScaledDotProductAttentionParams {
|
||||
// Associated dimensions & transposition information
|
||||
const uint QUERY_SEQUENCE_LENGTH = 1;
|
||||
const uint N_Q_HEADS = 32;
|
||||
const uint N_KV_HEADS = 32;
|
||||
const uint KV_TILES = 1;
|
||||
const float INV_ALPHA = 0.08838834764831843f;
|
||||
};
|
296
mlx/backend/metal/kernels/steel/attn/attn.h
Normal file
@@ -0,0 +1,296 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel class
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct GEMMKernel {
|
||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
STEEL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
transpose_a ? BK : BM,
|
||||
transpose_a ? BM : BK,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
!transpose_a,
|
||||
tgp_size>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
transpose_b ? BN : BK,
|
||||
transpose_b ? BK : BN,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
transpose_b,
|
||||
tgp_size>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
thread mma_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
thread const short& lbk,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned_) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* D [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
349
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant bool align_Q [[function_constant(200)]];
|
||||
constant bool align_K [[function_constant(201)]];
|
||||
|
||||
template <typename T>
|
||||
struct TransformScale {
|
||||
T scale;
|
||||
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
|
||||
|
||||
METAL_FUNC T apply(T x) const {
|
||||
return scale * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
struct SumOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct MulOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct SubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpSubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return fast::exp(x - y);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
int BQ,
|
||||
int BK,
|
||||
int BD,
|
||||
int WM,
|
||||
int WN,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device T* O [[buffer(3)]],
|
||||
const constant AttnParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
// Move to correct block
|
||||
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Sequence
|
||||
|
||||
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
kv_head_idx * params->K_strides[1]; // Head
|
||||
|
||||
V += tidl.z * params->V_strides[0] + // Batch
|
||||
kv_head_idx * params->V_strides[1]; // Head
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Sequence
|
||||
|
||||
// Prepare threadgroup memory
|
||||
constexpr short padQ = 0; // 16 / sizeof(T);
|
||||
constexpr short padK = 0; // 16 / sizeof(T);
|
||||
constexpr short padV = 0; // 16 / sizeof(T);
|
||||
|
||||
constexpr short LDQ_tgp = BD + padQ;
|
||||
constexpr short LDK_tgp = BK + padK;
|
||||
constexpr short LDV_tgp = BD + padV;
|
||||
|
||||
threadgroup T Qs[BQ * (BD + padQ)];
|
||||
threadgroup T Ks[(BK + padK) * BD];
|
||||
threadgroup T Vs[BK * (BD + padV)];
|
||||
|
||||
// Prepare block loaders
|
||||
using QBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BQ,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDQ_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 1,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
// K is loaded in transposed layout
|
||||
using KBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ 1,
|
||||
/* short kDstStrCol = */ LDK_tgp,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
using VBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDV_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
QBlockLoader loader_q(
|
||||
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
|
||||
KBlockLoader loader_k(
|
||||
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
|
||||
VBlockLoader loader_v(
|
||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||
|
||||
// Prepare MMA tiles
|
||||
constexpr short kFragSize = 8; // MMAFrag size
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
constexpr int kNWarps = WM * WN;
|
||||
static_assert(
|
||||
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
||||
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||
|
||||
// Q seq frags per warp
|
||||
constexpr int TQ = BQ / (kNWarps * kFragSize);
|
||||
// KV sequence frags (all warps load the same frags)
|
||||
constexpr int TK = BK / kFragSize;
|
||||
// HeadDim frags (all warps load the same frags)
|
||||
constexpr int TD = BD / kFragSize;
|
||||
|
||||
static_assert(TQ == 1, "Check TQ");
|
||||
|
||||
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
||||
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
||||
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
|
||||
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
||||
|
||||
Otile.clear();
|
||||
|
||||
// Prepare mma tile offsets
|
||||
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
const short sm = simd_coord.y;
|
||||
const short sn = simd_coord.x;
|
||||
const short tm = kFragSize * TQ * simd_group_id;
|
||||
|
||||
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||
const short Ks_offset = sm * LDK_tgp + sn;
|
||||
const short Vs_offset = sm * LDV_tgp + sn;
|
||||
|
||||
constexpr short Qs_tile_stride = kFragSize;
|
||||
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load Q block and apply the scale
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
|
||||
} else {
|
||||
loader_q.load_unsafe();
|
||||
}
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
// Init row reduction variables
|
||||
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
|
||||
|
||||
AccumType max_score[kRowsPT];
|
||||
AccumType sum_score[kRowsPT] = {0};
|
||||
|
||||
// Init to -Inf
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
}
|
||||
|
||||
// Loop over KV seq length
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load K block and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_k.load_unsafe();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do S = Q @ K.T
|
||||
Stile.clear();
|
||||
|
||||
for (short dd = 0; dd < TD; dd++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
|
||||
&Qs[Qs_offset + dd * Qs_tile_stride]);
|
||||
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
|
||||
&Ks[Ks_offset + dd * Ks_tile_stride]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||
}
|
||||
|
||||
// Mask out of length sequence
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
const short lim = params->kL - params->NK_aligned * BK;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < stile_t::kTileCols; j++) {
|
||||
short col_pos = sn + (j * stile_t::kFragCols);
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
|
||||
if ((col_pos + jj) >= lim) {
|
||||
Stile.frag_at(i, j)[jj] = neg_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load V blocks
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
|
||||
} else {
|
||||
loader_v.load_unsafe();
|
||||
}
|
||||
|
||||
// Do softmax
|
||||
|
||||
// Temp variables
|
||||
AccumType new_max[kRowsPT];
|
||||
AccumType factor[kRowsPT];
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
new_max[i] = max_score[i];
|
||||
}
|
||||
|
||||
// Row max
|
||||
Stile.template row_reduce<MaxOp>(new_max);
|
||||
|
||||
// exp(Si - rowmax(Si))
|
||||
Stile.template row_bin_op<ExpSubOp>(new_max);
|
||||
|
||||
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||
}
|
||||
|
||||
// Save max for next iteration
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = new_max[i];
|
||||
}
|
||||
|
||||
// Row Sum
|
||||
AccumType sum_score_tmp[kRowsPT] = {0};
|
||||
Stile.template row_reduce<SumOp>(sum_score_tmp);
|
||||
|
||||
// Update norm
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
|
||||
}
|
||||
|
||||
// Update O
|
||||
Otile.template row_bin_op<MulOp>(factor);
|
||||
|
||||
// Load V into registers
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Do O = S @ V
|
||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_k.next();
|
||||
loader_v.next();
|
||||
}
|
||||
|
||||
// Normalize output
|
||||
Otile.template row_bin_op<DivOp>(sum_score);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results
|
||||
O += (tm + sm) * params->O_strides[2] + sn;
|
||||
|
||||
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
|
||||
auto dst_tile_dims =
|
||||
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
|
||||
} else {
|
||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||
}
|
||||
}
|
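The loop above implements an online (streaming) softmax: each K/V block updates a per-row running maximum and normalizer, and the output accumulated from earlier blocks is rescaled by exp(old_max - new_max) before the new block's contribution is added, so the full qL x kL score matrix is never materialized. A rough NumPy sketch of the same update (illustrative only: the function name and shapes are made up, and the softmax scale is assumed to already be folded into q, as TransformScale does above):

import numpy as np

def streaming_attention(q, k, v, block=32):
    # q: (Lq, D), k and v: (Lk, D); q is assumed to be pre-scaled.
    Lq, D = q.shape
    o = np.zeros((Lq, D), dtype=np.float32)
    row_max = np.full(Lq, -np.inf, dtype=np.float32)
    row_sum = np.zeros(Lq, dtype=np.float32)

    for start in range(0, k.shape[0], block):
        s = q @ k[start:start + block].T               # S = Q @ K_block.T
        new_max = np.maximum(row_max, s.max(axis=-1))  # running row max
        p = np.exp(s - new_max[:, None])               # exp(S - rowmax)
        factor = np.exp(row_max - new_max)             # rescale earlier blocks
        row_sum = row_sum * factor + p.sum(axis=-1)    # running normalizer
        o = o * factor[:, None] + p @ v[start:start + block]
        row_max = new_max

    return o / row_sum[:, None]                        # final normalization

Per query row, the only state carried between blocks is the running max, the running sum, and the unnormalized output row, which is exactly what max_score, sum_score, and Otile hold in the kernel.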
@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"

#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
      const device dtype* Q [[buffer(0)]], \
      const device dtype* K [[buffer(1)]], \
      const device dtype* V [[buffer(2)]], \
      device dtype* O [[buffer(3)]], \
      const constant AttnParams* params [[buffer(4)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_attn_shapes_helper(iname, itype) \
  instantiate_attn(iname, itype, 32, 16, 128, 4, 1)  \
  instantiate_attn(iname, itype, 32, 32, 80, 4, 1)   \
  instantiate_attn(iname, itype, 32, 32, 64, 4, 1)

instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);

instantiate_attn_shapes_helper(float32, float);
// clang-format on
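For orientation, each expansion of instantiate_attn above registers one pipeline whose host name encodes the tile configuration. A small hypothetical helper (not part of the change) that reproduces the naming scheme:

def steel_attention_kernel_name(tname, bq, bk, bd, wm, wn):
    # Mirrors the host_name string built by the instantiate_attn macro above.
    return f"steel_attention_{tname}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}"

# The float16, head-dim 64 variant instantiated above is therefore named
# "steel_attention_float16_bq32_bk32_bd64_wm4_wn1".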
264  mlx/backend/metal/kernels/steel/attn/loader.h  Normal file
@ -0,0 +1,264 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

template <int R, int C>
struct CShape {
  STEEL_CONST int kRows = R;
  STEEL_CONST int kCols = C;
};

template <
    typename T,
    short BROWS,
    short BCOLS,
    short kDstStrRow,
    short kDstStrCol,
    short reduction_dim,
    short tgp_size,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor */
  METAL_FUNC BlockLoaderT(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] =
            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * kDstStrRow + j * kDstStrCol] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx
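Both loader variants split a BROWS x BCOLS tile across the threadgroup so that every thread copies vec_size consecutive elements of one row per step of the load loop. A hypothetical Python sketch of the constructor's index arithmetic (illustrative names and sizes, not MLX code):

def block_loader_coords(tgp_size, BROWS, BCOLS):
    # Mirrors the template arithmetic above: each thread owns a vec_size-wide
    # slice of one row, and consecutive thread ids sweep the columns first.
    n_reads = (BCOLS * BROWS) // tgp_size    # vec_size
    TCOLS = BCOLS // n_reads                 # threads needed to cover a row
    TROWS = tgp_size // TCOLS                # rows covered per loop step
    coords = {}
    for thread_idx in range(tgp_size):
        bi = thread_idx // TCOLS             # row this thread starts at
        bj = n_reads * (thread_idx % TCOLS)  # first column it reads
        coords[thread_idx] = (bi, bj)
    return n_reads, TROWS, coords

# e.g. a 32 x 64 K block loaded by four simdgroups (128 threads):
# n_reads = 16, TCOLS = 4, TROWS = 32, so each thread issues one 16-element
# read and a single pass of the i-loop in load_unsafe covers all 32 rows.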
726  mlx/backend/metal/kernels/steel/attn/mma.h  Normal file
@ -0,0 +1,726 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <typename RInt, typename CInt>
struct Shape2D {
  RInt r;
  CInt c;

  Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};

template <typename Shape, typename Layout>
struct Layout2D {
  Shape shape;
  Layout layout;
};

template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
  static_assert(
      kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};

template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;
  typedef metal::vec<T, kElemRows> row_frag_type;
  typedef metal::vec<T, kElemCols> col_frag_type;

  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread frag_type& A,
      thread frag_type& B,
      thread frag_type& C) {
    mat_type D_mat;
    mat_type A_mat;
    mat_type B_mat;
    mat_type C_mat;

    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread mat_type& A,
      thread mat_type& B,
      thread mat_type& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }

  template <typename Op>
  METAL_FUNC static constexpr void row_reduce(
      thread const frag_type& inp_vals,
      thread T* reduced_vals) {
    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);

    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
  }

  template <typename Op>
  METAL_FUNC static constexpr void row_bin_op(
      thread frag_type& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};

template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;

  typedef typename MMAFrag_t::mat_type mat_type;
  typedef typename MMAFrag_t::frag_type frag_type;

  frag_type val_frags[kNumFrags] = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type val_mat;
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
    }
    return val_mat;
  }

  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  template <typename Op>
  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_reduce<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  template <typename Op>
  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::template row_bin_op<Op>(
            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
      }
    }
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(
                src[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(
                dst[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load_safe(
            frag_at(i, j),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_safe(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }
};

template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  STEEL_PRAGMA_UNROLL
  for (short k = 0; k < K; ++k) {
    STEEL_PRAGMA_UNROLL
    for (short m = 0; m < M; ++m) {
      STEEL_PRAGMA_UNROLL
      for (short n = 0; n < N; ++n) {
        short n_serp = (m % 2) ? (N - 1 - n) : n;
        MMATile<T, M, N>::MMAFrag_t::mma(
            D.frag_at(m, n_serp),
            A.frag_at(m, k),
            B.frag_at(k, n_serp),
            C.frag_at(m, n_serp));
      }
    }
  }
}

template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Simdgroup matrices
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Offsets within threadgroup
  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N

    sm += tm;
    sn += tn;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  /* Apply epilogue */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Apply epilogue */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

        // Read C
        U c_elems[kelems] = {0};

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  /* Store simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
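One non-obvious piece above is row_reduce: in the 8x8 layout returned by get_coord, each lane holds two adjacent elements of one fragment row, and the other lanes that own the same row differ only in bits 0 and 3 of the lane id, so two XOR shuffles finish the reduction. A small illustrative simulation of that pattern (assumed layout, plain Python, not MLX code):

def lane_coord(lane):
    # Mirrors BaseMMAFrag<..., 8, 8>::get_coord: (row, first column) of the
    # two elements this lane holds.
    qid = lane // 4
    fm = (qid & 4) + ((lane // 2) % 4)
    fn = (qid & 2) * 2 + (lane % 2) * 2
    return fm, fn

def row_reduce_max(frag):
    # frag[lane] = (elem0, elem1) held by each of the 32 lanes.
    acc = [max(a, b) for a, b in frag]                  # reduce the two locals
    acc = [max(acc[l], acc[l ^ 1]) for l in range(32)]  # simd_shuffle_xor(., 1)
    acc = [max(acc[l], acc[l ^ 8]) for l in range(32)]  # simd_shuffle_xor(., 8)
    return acc  # lanes l, l^1, l^8, l^9 now all hold their row's maximum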
36  mlx/backend/metal/kernels/steel/attn/params.h  Normal file
@ -0,0 +1,36 @@
// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

struct AttnParams {
  int B; ///< Batch Size
  int H; ///< Heads
  int D; ///< Head Dim

  int qL; ///< Query Sequence Length
  int kL; ///< Key Sequence Length

  int gqa_factor; ///< Group Query factor
  float scale; ///< Attention scale

  int NQ; ///< Number of query blocks
  int NK; ///< Number of key/value blocks

  int NQ_aligned; ///< Number of full query blocks
  int NK_aligned; ///< Number of full key/value blocks

  size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};

} // namespace steel
} // namespace mlx
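The *_aligned counts are what the kernel's align_Q / align_K function constants key off: only the last, partial block takes the bounds-checked load and store paths. A quick worked example with assumed sizes (hypothetical helper, not MLX code):

def attn_block_counts(qL, kL, bq=32, bk=32):
    NQ = (qL + bq - 1) // bq   # query blocks launched
    NK = (kL + bk - 1) // bk   # key/value blocks visited per query block
    NQ_aligned = qL // bq      # fully populated query blocks
    NK_aligned = kL // bk      # fully populated key/value blocks
    return NQ, NK, NQ_aligned, NK_aligned

# qL = kL = 1000 with 32-wide blocks gives NQ = NK = 32 and
# NQ_aligned = NK_aligned = 31; the last block has 1000 - 31 * 32 = 8 valid
# rows/columns, which is exactly the extent passed to load_safe in the kernel.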
71  mlx/backend/metal/kernels/steel/attn/transforms.h  Normal file
@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};

} // namespace steel
} // namespace mlx
@ -385,9 +385,9 @@ struct BlockMMA {
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup A strides
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
@ -6,7 +6,9 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"

#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
@ -19,122 +21,89 @@ void sdpa_full_self_attention_metal(
    const array& q,
    const array& k,
    const array& v,
    const float alpha,
    array& out) {
  std::ostringstream kname_self_attention;
  kname_self_attention << "steel_gemm_attention_";
    const float scale,
    array& o) {
  using namespace mlx::steel;

  constexpr const int bm = 16;
  constexpr const int bn = 16;
  const int bk = q.shape(-1); // already forced to be 64 or 128
  int wm = 4;
  int wn = 1;

  if (bk != 64 && bk != 128) {
    throw std::runtime_error(
        "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
  }
  int bd = q.shape(-1);
  int bq = 32;
  int bk = bd < 128 ? 32 : 16;

  constexpr const int wm = 2;
  constexpr const int wn = 2;
  int B = q.shape(0);
  int H = q.shape(1);
  int D = q.shape(3);
  int gqa_factor = q.shape(1) / k.shape(1);

  std::string delimiter = "_";
  int qL = q.shape(2);
  int kL = k.shape(2);

  kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
  kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
  kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
  const bool align_Q = (qL % bq) == 0;
  const bool align_K = (kL % bk) == 0;

  for (const auto& arr : {k, v, out}) {
    if (arr.dtype() != q.dtype()) {
      throw std::runtime_error(
          "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
    }
  }
  metal::MTLFCList func_consts = {
      {&align_Q, MTL::DataType::DataTypeBool, 200},
      {&align_K, MTL::DataType::DataTypeBool, 201},
  };

  if (q.dtype() == float32) {
    kname_self_attention << "itype" + delimiter + "float";
  } else if (q.dtype() == float16) {
    kname_self_attention << "itype" + delimiter + "half";
  } else {
    throw std::runtime_error(
        "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
  }
  std::ostringstream kname;
  // clang-format off
  kname << "steel_attention_"
        << type_to_name(q)
        << "_bq" << bq
        << "_bk" << bk
        << "_bd" << bd
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  // clang-format off
  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname_self_attention.str());
  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  compute_encoder.set_compute_pipeline_state(kernel);

  uint hidden_dim = q.shape(-1);
  uint qseq = q.shape(-2);
  uint qheads = q.shape(-3);
  const int NQ = (qL + bq - 1) / bq;
  const int NK = (kL + bk - 1) / bk;

  const uint64_t KV_sequence_length = k.shape(-2);
  const uint query_sequence_length = q.shape(-2);
  const uint n_q_heads = q.shape(1);
  const uint n_kv_heads = k.shape(1);
  const int NQ_aligned = qL / bq;
  const int NK_aligned = kL / bk;

  const int M = q.shape(-2);
  const int N = M;
  const int K = q.shape(-1);
  const size_t batch_size_out = q.shape(0) * q.shape(1);
  AttnParams params{
      /* int B = */ B,
      /* int H = */ H,
      /* int D = */ D,

  const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
  const int dk = q.shape(-1);
  const int ldq = dk;
  const int ldk = dk;
  const int ldv = dk;
  const int lds = bn;
  const int ldo = dk;
      /* int qL = */ qL,
      /* int kL = */ kL,

  int tn = 1;
  int tm = (M + bm - 1) / bm;
      /* int gqa_factor = */ gqa_factor,
      /* float scale = */ scale,

  const int batch_stride_q = dk * query_sequence_length;
  const int batch_stride_k = dk * query_sequence_length;
  const int batch_stride_v = dk * query_sequence_length;
  const int batch_stride_o = dk * query_sequence_length;
  const int swizzle_log = 0;
  const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
  const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
  const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
  const int batch_ndim = int(batch_shape.size());
      /* int NQ = */ NQ,
      /* int NK = */ NK,

  MLXFastAttentionParams params{
      (int)M,
      (int)N,
      (int)K,
      ldq,
      ldk,
      ldv,
      lds,
      ldo,
      tn,
      tm,
      batch_stride_q,
      batch_stride_k,
      batch_stride_v,
      batch_stride_o,
      swizzle_log,
      gemm_n_iterations_aligned,
      gemm_k_iterations_aligned,
      gemm_sv_m_block_iterations,
      batch_ndim,
      alpha};
      /* int NQ_aligned = */ NQ_aligned,
      /* int NK_aligned = */ NK_aligned,

  const std::vector<size_t> batch_strides = {
      (size_t)batch_stride_q,
      (size_t)batch_stride_k,
      (size_t)batch_stride_v,
      (size_t)batch_stride_o};
      /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
      /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
      /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
      /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};

  compute_encoder.set_input_array(q, 0);
  compute_encoder.set_input_array(k, 1);
  compute_encoder.set_input_array(v, 2);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_output_array(o, 3);
  compute_encoder.set_bytes(params, 4);
  compute_encoder.set_vector_bytes(batch_shape, 6);
  compute_encoder.set_vector_bytes(batch_strides, 7);

  MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
  MTL::Size grid_dims = MTL::Size(NQ, H, B);
  MTL::Size group_dims = MTL::Size(32, wm, wn);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
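Each threadgroup therefore handles one bq-row block of queries for a single (batch, head) pair, using 32 lanes times wm x wn simdgroups. A small sketch of the launch arithmetic implied by the dispatch above (hypothetical helper, illustrative numbers):

def attention_launch(B, H, qL, D, wm, wn):
    bq = 32                      # query rows per threadgroup
    bk = 32 if D < 128 else 16   # key/value rows per inner block
    NQ = (qL + bq - 1) // bq     # query blocks per (batch, head)
    grid = (NQ, H, B)            # threadgroups dispatched
    group = (32, wm, wn)         # threads per threadgroup
    return grid, group

# e.g. B = 2, H = 32, qL = 4096, D = 64 with a 2 x 2 simdgroup layout:
# grid = (128, 32, 2) threadgroups of 128 threads each.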
@ -356,7 +325,24 @@ void ScaledDotProductAttention::eval_gpu(
  const auto& q = copy_unless(is_matrix_contiguous, q_pre);
  const auto& k = copy_unless(is_matrix_contiguous, k_pre);
  const auto& v = copy_unless(is_matrix_contiguous, v_pre);
  o.set_data(allocator::malloc_or_wait(o.nbytes()));

  size_t str_oD = 1;
  size_t str_oH = o.shape(3);
  size_t str_oL = o.shape(1) * str_oH;
  size_t str_oB = o.shape(2) * str_oL;
  size_t data_size = o.shape(0) * str_oB;

  array::Flags flags{
      /* bool contiguous = */ 1,
      /* bool row_contiguous = */ 0,
      /* bool col_contiguous = */ 0,
  };

  o.set_data(
      allocator::malloc_or_wait(o.nbytes()),
      data_size,
      {str_oB, str_oH, str_oL, str_oD},
      flags);

  sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}
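The strides passed to set_data leave o with logical shape (B, H, L, D) but a (B, L, H, D) memory layout, matching the strided O writes in the kernel and the usual token-major activation layout expected downstream. A quick illustration with assumed shapes (hypothetical helper, not MLX code):

def transposed_output_strides(B, H, L, D):
    # Mirrors the stride computation above: logical (B, H, L, D),
    # laid out in memory as (B, L, H, D).
    str_oD = 1
    str_oH = D
    str_oL = H * D
    str_oB = L * H * D
    return (str_oB, str_oH, str_oL, str_oD)

# (B, H, L, D) = (2, 8, 128, 64) -> strides (65536, 64, 512, 1): stepping
# along L jumps by H * D elements, so per token the heads are contiguous,
# i.e. the output can be viewed as (B, L, H * D) without another copy.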
@ -600,7 +600,7 @@ array scaled_dot_product_attention(
 * * dtype is not fp32 or fp16
 */

  int threshold = 1e6;
  int threshold = 32; // TODO: Fix after dev
  if (memory_efficient_threshold.has_value()) {
    threshold = std::max(1, memory_efficient_threshold.value());
  }
@ -644,11 +644,10 @@ array scaled_dot_product_attention(
  const bool sdpa_vector_supported_head_dim =
      query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
  const bool sdpa_full_supported_head_dim =
      query_head_dim == 64 || query_head_dim == 128;
      query_head_dim == 64 || query_head_dim == 80;

  const bool supports_sdpa_full = query_sequence_length >= threshold &&
      !mask.has_value() && sdpa_full_supported_head_dim &&
      n_q_heads == n_kv_heads && final_type != bfloat16 &&
      stream.device == Device::gpu;

  const bool supports_sdpa_vector = query_sequence_length == 1 &&