mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
[WIP]: Loading and Matmuls added
This commit is contained in:
parent
ad5b58b34e
commit
168a3a464a
@ -19,6 +19,16 @@ constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
|
||||
template <typename T>
|
||||
struct TransformScale {
|
||||
T scale;
|
||||
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
|
||||
|
||||
METAL_FUNC T apply(T x) const {
|
||||
return scale * x;
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
@ -55,11 +65,151 @@ template <
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
|
||||
for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D;
|
||||
i += WM * WN * 32) {
|
||||
int r = i / params->D;
|
||||
int c = i % params->D;
|
||||
constexpr int padQ = 0; // 16 / sizeof(T);
|
||||
constexpr int padK = 0; // 16 / sizeof(T);
|
||||
constexpr int padV = 0; // 16 / sizeof(T);
|
||||
|
||||
O[params->O_strides[2] * r + c] = T(0);
|
||||
// using QBlockSrcShape = CShape<BQ, BD>;
|
||||
// using KBlockSrcShape = CShape<BK, BD>;
|
||||
// using VBlockSrcShape = CShape<BK, BD>;
|
||||
|
||||
constexpr int LDQ_tgp = BD + padQ;
|
||||
constexpr int LDK_tgp = BK + padK;
|
||||
constexpr int LDV_tgp = BD + padV;
|
||||
|
||||
// using QBlockDstStrides = CShape<LDQ_tgp, 1>;
|
||||
// using KBlockDstStrides = CShape<1, LDK_tgp>;
|
||||
// using QBlockDstStrides = CShape<LDV_tgp, 1>;
|
||||
|
||||
threadgroup T Qs[BQ * (BD + padQ)];
|
||||
threadgroup T Ks[(BK + padK) * BD];
|
||||
threadgroup T Vs[BK * (BD + padV)];
|
||||
|
||||
using QBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BQ,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDQ_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 1,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
using KBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ 1,
|
||||
/* short kDstStrCol = */ LDK_tgp,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
using VBlockLoader = BlockLoaderT<
|
||||
/* typename T = */ T,
|
||||
/* short BROWS = */ BK,
|
||||
/* short BCOLS = */ BD,
|
||||
/* short kDstStrRow = */ LDV_tgp,
|
||||
/* short kDstStrCol = */ 1,
|
||||
/* short reduction_dim = */ 0,
|
||||
/* short tgp_size = */ WM * WN * 32>;
|
||||
|
||||
QBlockLoader loader_q(
|
||||
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
|
||||
KBlockLoader loader_k(
|
||||
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
|
||||
VBlockLoader loader_v(
|
||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||
|
||||
// MMAFrag size
|
||||
constexpr short kFragSize = 8;
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
constexpr int kNWarps = WM * WN;
|
||||
static_assert(
|
||||
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
|
||||
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||
|
||||
constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
|
||||
constexpr int TK =
|
||||
BK / kFragSize; // KV sequence frags (all warps load the same frags)
|
||||
constexpr int TD =
|
||||
BD / kFragSize; // HeadDim frags (all warps load the same frags)
|
||||
|
||||
static_assert(TQ == 1, "Check TQ");
|
||||
|
||||
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
||||
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
||||
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
|
||||
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
||||
|
||||
Otile.clear();
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
short sm = simd_coord.y;
|
||||
short sn = simd_coord.x;
|
||||
short tm = kFragSize * TQ * simd_group_id;
|
||||
|
||||
short Qs_offset = (tm + sm) * LDQ_tgp + sn;
|
||||
short Ks_offset = sm * LDK_tgp + sn;
|
||||
short Vs_offset = sm * LDV_tgp + sn;
|
||||
|
||||
constexpr int Qs_tile_stride = kFragSize;
|
||||
constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
loader_q.load_unsafe();
|
||||
loader_q.apply_inplace_op(ts);
|
||||
|
||||
for (int kb = 0; kb < params->NK; kb++) {
|
||||
// Load Q and K blocks and apply scale
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_k.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do S = Q @ K.T
|
||||
Stile.clear();
|
||||
|
||||
for (short dd = 0; dd < TD; dd++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
|
||||
&Qs[Qs_offset + dd * Qs_tile_stride]);
|
||||
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
|
||||
&Ks[Ks_offset + dd * Ks_tile_stride]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Stile, Qtile, Ktile, Stile);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Load V blocks
|
||||
loader_v.load_unsafe();
|
||||
|
||||
// Do softmax
|
||||
|
||||
// Do O = S @ V
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||
|
||||
// Prepare for next iteration
|
||||
// loader_q.next();
|
||||
loader_k.next();
|
||||
loader_v.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Store results
|
||||
O += (tm + sm) * params->O_strides[2] + sn;
|
||||
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
// #include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
||||
@ -21,11 +21,12 @@
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||
instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \
|
||||
instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
|
||||
instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
|
||||
|
||||
|
||||
instantiate_attn_shapes_helper(float16, half);
|
||||
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_attn_shapes_helper(float32, float);
|
||||
// clang-format on
|
||||
|
@ -133,5 +133,132 @@ struct BlockLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <int R, int C>
|
||||
struct CShape {
|
||||
STEEL_CONST int kRows = R;
|
||||
STEEL_CONST int kCols = C;
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short kDstStrRow,
|
||||
short kDstStrCol,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoaderT {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoaderT(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Apply operation to threadgroup without bound checking */
|
||||
template <typename UnaryOp>
|
||||
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] =
|
||||
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
||||
|
@ -358,12 +358,12 @@ METAL_FUNC void tile_matmad(
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
for (short k = 0; k < K; ++k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
|
@ -385,9 +385,9 @@ struct BlockMMA {
|
||||
STEEL_CONST short TN_stride = kFragSize * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
STEEL_CONST short TM = BM / (kFragSize * WM);
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
STEEL_CONST short TN = BN / (kFragSize * WN);
|
||||
|
||||
// Threadgroup A strides
|
||||
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
||||
|
@ -26,10 +26,10 @@ void sdpa_full_self_attention_metal(
|
||||
array& o) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
int wm = 4;
|
||||
int wn = 1;
|
||||
|
||||
int bq = 16;
|
||||
int bq = 32;
|
||||
int bk = 16;
|
||||
int bd = 64;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user