Rough INIT

This commit is contained in:
Jagrit Digani 2024-11-14 16:27:13 -08:00
parent 0c5eea226b
commit ad5b58b34e
10 changed files with 1354 additions and 100 deletions

View File

@ -68,6 +68,22 @@ set(STEEL_HEADERS
steel/utils/type_traits.h
steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)
@ -93,6 +109,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
build_kernel(gemv_masked steel/utils.h)
endif()

View File

@ -0,0 +1,296 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,65 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
K += tidl.z * params->K_strides[0] + // Batch
tidl.y * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
tidl.y * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Seqeunce
for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D;
i += WM * WN * 32) {
int r = i / params->D;
int c = i % params->D;
O[params->O_strides[2] * r + c] = T(0);
}
}

View File

@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on

View File

@ -0,0 +1,137 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,668 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename RInt, typename CInt>
struct Shape2D {
RInt r;
CInt c;
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
};
template <typename Shape, typename Layout>
struct Layout2D {
Shape shape;
Layout layout;
};
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
thread frag_type& B,
thread frag_type& C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread mat_type& A,
thread mat_type& B,
thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D,
thread MMATile<U, M, K>& A,
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
template <typename UnaryEpilogue>
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// Attn param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct AttnParams {
int B; ///< Batch Size
int H; ///< Heads
int L; ///< Sequence Length
int D; ///< Head Dim
int gqa_factor; ///< Group Query factor
float scale; ///< Attention scale
int NQ; ///< Number of query blocks
int NK; ///< Number of key/value blocks
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@ -7,6 +7,9 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"
@ -19,122 +22,57 @@ void sdpa_full_self_attention_metal(
const array& q,
const array& k,
const array& v,
const float alpha,
array& out) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
const float scale,
array& o) {
using namespace mlx::steel;
constexpr const int bm = 16;
constexpr const int bn = 16;
const int bk = q.shape(-1); // already forced to be 64 or 128
int wm = 2;
int wn = 2;
if (bk != 64 && bk != 128) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
}
int bq = 16;
int bk = 16;
int bd = 64;
constexpr const int wm = 2;
constexpr const int wn = 2;
std::string delimiter = "_";
kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
if (q.dtype() == float32) {
kname_self_attention << "itype" + delimiter + "float";
} else if (q.dtype() == float16) {
kname_self_attention << "itype" + delimiter + "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::ostringstream kname;
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
<< "_bd" << bd << "_wm" << wm << "_wn" << wn;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_self_attention.str());
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
uint hidden_dim = q.shape(-1);
uint qseq = q.shape(-2);
uint qheads = q.shape(-3);
int B = q.shape(0);
int H = q.shape(1);
int L = q.shape(2);
int D = q.shape(3);
int gqa_factor = q.shape(1) / k.shape(1);
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
int NQ = (L + bq - 1) / bq;
int NK = (L + bk - 1) / bk;
const int M = q.shape(-2);
const int N = M;
const int K = q.shape(-1);
const size_t batch_size_out = q.shape(0) * q.shape(1);
AttnParams params{
/* int B = */ B,
/* int H = */ H,
/* int L = */ L,
/* int D = */ D,
/* int gqa_factor = */ gqa_factor,
/* float scale = */ scale,
const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
const int dk = q.shape(-1);
const int ldq = dk;
const int ldk = dk;
const int ldv = dk;
const int lds = bn;
const int ldo = dk;
/* int NQ = */ NQ,
/* int NK = */ NK,
int tn = 1;
int tm = (M + bm - 1) / bm;
const int batch_stride_q = dk * query_sequence_length;
const int batch_stride_k = dk * query_sequence_length;
const int batch_stride_v = dk * query_sequence_length;
const int batch_stride_o = dk * query_sequence_length;
const int swizzle_log = 0;
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
const int batch_ndim = int(batch_shape.size());
MLXFastAttentionParams params{
(int)M,
(int)N,
(int)K,
ldq,
ldk,
ldv,
lds,
ldo,
tn,
tm,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_o,
swizzle_log,
gemm_n_iterations_aligned,
gemm_k_iterations_aligned,
gemm_sv_m_block_iterations,
batch_ndim,
alpha};
const std::vector<size_t> batch_strides = {
(size_t)batch_stride_q,
(size_t)batch_stride_k,
(size_t)batch_stride_v,
(size_t)batch_stride_o};
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
/* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_output_array(o, 3);
compute_encoder.set_bytes(params, 4);
compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7);
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
MTL::Size grid_dims = MTL::Size(NQ, H, B);
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

View File

@ -600,7 +600,7 @@ array scaled_dot_product_attention(
* * dtype is not fp32 or fp16
*/
int threshold = 1e6;
int threshold = 1024; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
}