[WIP]: Loading and Matmuls added

Jagrit Digani 2024-11-19 10:24:52 -08:00
parent ad5b58b34e
commit 168a3a464a
6 changed files with 295 additions and 17 deletions

View File

@@ -19,6 +19,16 @@ constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
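// Editor's note: TransformScale is applied to the Q tile during the
// threadgroup load (see loader_q.apply_inplace_op(ts) below), folding the
// attention scale params->scale into Q up front.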
// clang-format off
template <
typename T,
@@ -55,11 +65,151 @@ template <
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
- for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D;
- i += WM * WN * 32) {
- int r = i / params->D;
- int c = i % params->D;
- O[params->O_strides[2] * r + c] = T(0);
- }
constexpr int padQ = 0; // 16 / sizeof(T);
constexpr int padK = 0; // 16 / sizeof(T);
constexpr int padV = 0; // 16 / sizeof(T);
// using QBlockSrcShape = CShape<BQ, BD>;
// using KBlockSrcShape = CShape<BK, BD>;
// using VBlockSrcShape = CShape<BK, BD>;
constexpr int LDQ_tgp = BD + padQ;
constexpr int LDK_tgp = BK + padK;
constexpr int LDV_tgp = BD + padV;
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
// using KBlockDstStrides = CShape<1, LDK_tgp>;
// using VBlockDstStrides = CShape<LDV_tgp, 1>;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale));
// MMAFrag size
constexpr short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
constexpr int TK =
BK / kFragSize; // KV sequence frags (all warps load the same frags)
constexpr int TD =
BD / kFragSize; // HeadDim frags (all warps load the same frags)
static_assert(TQ == 1, "Check TQ");
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
short sm = simd_coord.y;
short sn = simd_coord.x;
short tm = kFragSize * TQ * simd_group_id;
short Qs_offset = (tm + sm) * LDQ_tgp + sn;
short Ks_offset = sm * LDK_tgp + sn;
short Vs_offset = sm * LDV_tgp + sn;
constexpr int Qs_tile_stride = kFragSize;
constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_q.load_unsafe();
loader_q.apply_inplace_op(ts);
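// Editor's note: Q is staged into threadgroup memory and scaled once here,
// outside the KV loop; each iteration below streams one K block and one
// V block.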
for (int kb = 0; kb < params->NK; kb++) {
// Load K block for this iteration
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_k.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T
Stile.clear();
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
loader_v.load_unsafe();
// Do softmax
// Do O = S @ V
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
// loader_q.next();
loader_k.next();
loader_v.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
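For orientation, here is a minimal single-threaded C++ sketch of what one threadgroup computes at this WIP stage: O = (scale * Q) @ K^T @ V, accumulated over BK-wide key/value blocks, with the softmax between the two matmuls still a placeholder (the "Do softmax" comment above). Strides, batching, and the fragment-level tiling are ignored; the function name and flat row-major layout are illustrative only, and seq_len is assumed to be a multiple of BK.

#include <vector>

// Illustrative reference (not MLX code): one (batch, head, Q-block) tile.
// Q is [BQ, D]; K and V are [seq_len, D]; O is [BQ, D], all row-major.
void attention_block_reference(
    const float* Q, const float* K, const float* V, float* O,
    int BQ, int seq_len, int D, int BK, float scale) {
  std::vector<float> S(BQ * BK); // S = (scale * Q) @ K_block^T
  for (int i = 0; i < BQ * D; ++i) {
    O[i] = 0.0f;
  }
  for (int kb = 0; kb < seq_len; kb += BK) { // params->NK iterations
    for (int r = 0; r < BQ; ++r) {
      for (int c = 0; c < BK; ++c) {
        float acc = 0.0f;
        for (int d = 0; d < D; ++d) {
          acc += (scale * Q[r * D + d]) * K[(kb + c) * D + d];
        }
        S[r * BK + c] = acc;
      }
    }
    // The kernel's "Do softmax" step would rescale the rows of S here;
    // it is not implemented yet in this commit.
    for (int r = 0; r < BQ; ++r) { // O += S @ V_block
      for (int c = 0; c < BK; ++c) {
        for (int d = 0; d < D; ++d) {
          O[r * D + d] += S[r * BK + c] * V[(kb + c) * D + d];
        }
      }
    }
  }
}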

View File

@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
// #include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
@@ -21,11 +21,12 @@
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
- instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \
+ instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
+ instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
instantiate_attn_shapes_helper(float16, half);
- instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+ // instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on
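Read against the kernel's template parameters, each helper line pins one block configuration. Assuming the numeric arguments map to (BQ, BK, BD, WM, WN) in that order (they match the host-side dispatch values in the last file of this commit), one new float16 instantiation works out to:

// instantiate_attn(float16, half, 32, 16, 64, 4, 1)
//   itype = half: float16 storage
//   BQ = 32: queries per threadgroup block
//   BK = 16: keys/values consumed per loop iteration
//   BD = 64: head dimension
//   WM = 4, WN = 1: a 4x1 grid of simdgroups, i.e. 128 threads per threadgroup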

View File

@@ -133,5 +133,132 @@ struct BlockLoader {
}
};
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx
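As a sanity check on the loader's index arithmetic, here is a small host-side C++ sketch (standard C++, not Metal; names are illustrative) that reproduces the n_reads/TCOLS/TROWS derivation for the Q-loader configuration used in the kernel (BROWS = BQ = 32, BCOLS = BD = 64, tgp_size = WM * WN * 32 = 128) and verifies that the 128 threads cover the tile exactly once:

#include <cassert>
#include <vector>

int main() {
  constexpr int BROWS = 32, BCOLS = 64, tgp_size = 128;
  constexpr int n_reads = (BCOLS * BROWS) / tgp_size; // 16 elements per thread
  constexpr int TCOLS = BCOLS / n_reads; // 4 threads span one row
  constexpr int TROWS = tgp_size / TCOLS; // 32 rows covered per step
  std::vector<int> hits(BROWS * BCOLS, 0);
  for (int t = 0; t < tgp_size; ++t) {
    int bi = t / TCOLS; // this thread's starting row
    int bj = n_reads * (t % TCOLS); // starting column (vec_size = n_reads)
    for (int i = 0; i < BROWS; i += TROWS) {
      for (int j = 0; j < n_reads; ++j) {
        hits[(bi + i) * BCOLS + (bj + j)] += 1;
      }
    }
  }
  for (int h : hits) {
    assert(h == 1); // every tile element is written exactly once
  }
  return 0;
}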

View File

@@ -358,12 +358,12 @@ METAL_FUNC void tile_matmad(
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
- for (short m = 0; m < M; ++m) {
+ for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL
- for (short n = 0; n < N; ++n) {
- short n_serp = (m % 2) ? (N - 1 - n) : n;
- STEEL_PRAGMA_UNROLL
- for (short k = 0; k < K; ++k) {
+ for (short m = 0; m < M; ++m) {
+ STEEL_PRAGMA_UNROLL
+ for (short n = 0; n < N; ++n) {
+ short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
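The reorder above moves k to the outermost loop while keeping the serpentine column order: odd m rows walk n backwards, so the last accumulator fragment touched in one row is the first touched in the next, which presumably helps fragment reuse. A tiny standalone C++ illustration of the visit order (illustrative only):

#include <cstdio>

// Print the (m, n_serp) visit order for M = 2, N = 4: even rows go left to
// right, odd rows right to left, so row transitions land on the same column.
int main() {
  const int M = 2, N = 4;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int n_serp = (m % 2) ? (N - 1 - n) : n;
      std::printf("(m=%d, n=%d)\n", m, n_serp);
    }
  }
  return 0;
}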

View File

@@ -385,9 +385,9 @@ struct BlockMMA {
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
- STEEL_CONST short TM = BM / TM_stride;
+ STEEL_CONST short TM = BM / (kFragSize * WM);
// Warp tile size along N
- STEEL_CONST short TN = BN / TN_stride;
+ STEEL_CONST short TN = BN / (kFragSize * WN);
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
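The rewritten TM and TN are algebraically the same as before, since TN_stride above is kFragSize * WN (and TM_stride, not shown in this hunk, is presumably kFragSize * WM); the change just inlines the strides. Worked through with the fragment size and the new dispatch in the next file (assumed block sizes, for illustration only):

constexpr short kFragSize = 8, WM = 4, WN = 1;
constexpr short BM = 32, BN = 32; // illustrative block sizes
constexpr short TM = BM / (kFragSize * WM); // 32 / 32 = 1 fragment row per warp
constexpr short TN = BN / (kFragSize * WN); // 32 / 8  = 4 fragment cols per warp
static_assert(TM == 1 && TN == 4, "worked example");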

View File

@@ -26,10 +26,10 @@ void sdpa_full_self_attention_metal(
array& o) {
using namespace mlx::steel;
- int wm = 2;
- int wn = 2;
+ int wm = 4;
+ int wn = 1;
- int bq = 16;
+ int bq = 32;
int bk = 16;
int bd = 64;
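The dispatch change swaps the 2x2 simdgroup grid for a 4x1 grid and doubles the query block. A quick standalone check (a sketch mirroring the kernel's constants) that the new values satisfy the kernel's static_assert and give TQ = 1:

constexpr int wm = 4, wn = 1, bq = 32;
constexpr int kNWarps = wm * wn; // 4 simdgroups, 128 threads per threadgroup
constexpr int kFragSize = 8; // 8x8 simdgroup matrices
static_assert(
    bq >= kNWarps * kFragSize && bq % (kNWarps * kFragSize) == 0,
    "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");
static_assert(bq / (kNWarps * kFragSize) == 1, "TQ == 1");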