mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-08 20:31:13 +08:00
Rough INIT
This commit is contained in:
parent
0c5eea226b
commit
ad5b58b34e
@ -68,6 +68,22 @@ set(STEEL_HEADERS
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h)
|
||||
|
||||
set(STEEL_ATTN_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils/type_traits.h
|
||||
steel/utils/integral_constant.h
|
||||
steel/attn/attn.h
|
||||
steel/attn/loader.h
|
||||
steel/attn/mma.h
|
||||
steel/attn/params.h
|
||||
steel/attn/transforms.h
|
||||
steel/attn/kernels/steel_attention.h)
|
||||
|
||||
if(NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
@ -93,6 +109,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
|
||||
build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
|
||||
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
endif()
|
||||
|
||||
|
296
mlx/backend/metal/kernels/steel/attn/attn.h
Normal file
296
mlx/backend/metal/kernels/steel/attn/attn.h
Normal file
@ -0,0 +1,296 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel class
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned>
|
||||
struct LoopAlignment {};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = typename AccumHelper<T>::accum_type,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct GEMMKernel {
|
||||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
||||
STEEL_CONST short tgp_mem_size_a =
|
||||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
||||
STEEL_CONST short tgp_mem_size_b =
|
||||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
||||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
||||
|
||||
STEEL_CONST short tgp_size = WM * WN * 32;
|
||||
|
||||
using loader_a_t = BlockLoader<
|
||||
T,
|
||||
transpose_a ? BK : BM,
|
||||
transpose_a ? BM : BK,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
!transpose_a,
|
||||
tgp_size>;
|
||||
using loader_b_t = BlockLoader<
|
||||
T,
|
||||
transpose_b ? BN : BK,
|
||||
transpose_b ? BK : BN,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
transpose_b,
|
||||
tgp_size>;
|
||||
using mma_t = BlockMMA<
|
||||
T,
|
||||
U,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
||||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
/* Main kernel function */
|
||||
template <bool M_aligned, bool N_aligned, bool K_aligned_>
|
||||
static METAL_FUNC void gemm_loop(
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
const int gemm_k_iterations,
|
||||
thread loader_a_t& loader_a,
|
||||
thread loader_b_t& loader_b,
|
||||
thread mma_t& mma_op,
|
||||
thread const short& tgp_bm,
|
||||
thread const short& tgp_bn,
|
||||
thread const short& lbk,
|
||||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
if (!K_aligned_) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
short2 tile_dims_A_last =
|
||||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
}
|
||||
|
||||
/* Main kernel function */
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* D [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_loop<true, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_loop<false, true, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_loop<true, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
gemm_loop<false, false, K_aligned>(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
@ -0,0 +1,65 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant bool has_batch [[function_constant(10)]];
|
||||
|
||||
constant bool use_out_source [[function_constant(100)]];
|
||||
constant bool do_axpby [[function_constant(110)]];
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
int BQ,
|
||||
int BK,
|
||||
int BD,
|
||||
int WM,
|
||||
int WN,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
|
||||
const device T* Q [[buffer(0)]],
|
||||
const device T* K [[buffer(1)]],
|
||||
const device T* V [[buffer(2)]],
|
||||
device T* O [[buffer(3)]],
|
||||
const constant AttnParams* params [[buffer(4)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
|
||||
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||
|
||||
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||
tidl.y * params->Q_strides[1] + // Head
|
||||
tidl.x * BQ * params->Q_strides[2]; // Seqeunce
|
||||
|
||||
K += tidl.z * params->K_strides[0] + // Batch
|
||||
tidl.y * params->K_strides[1]; // Head
|
||||
|
||||
V += tidl.z * params->V_strides[0] + // Batch
|
||||
tidl.y * params->V_strides[1]; // Head
|
||||
|
||||
O += tidl.z * params->O_strides[0] + // Batch
|
||||
tidl.y * params->O_strides[1] + // Head
|
||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||
|
||||
for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D;
|
||||
i += WM * WN * 32) {
|
||||
int r = i / params->D;
|
||||
int c = i % params->D;
|
||||
|
||||
O[params->O_strides[2] * r + c] = T(0);
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
|
||||
|
||||
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
|
||||
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
|
||||
const device dtype* Q [[buffer(0)]], \
|
||||
const device dtype* K [[buffer(1)]], \
|
||||
const device dtype* V [[buffer(2)]], \
|
||||
device dtype* O [[buffer(3)]],\
|
||||
const constant AttnParams* params [[buffer(4)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
#define instantiate_attn_shapes_helper(iname, itype) \
|
||||
instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \
|
||||
|
||||
|
||||
instantiate_attn_shapes_helper(float16, half);
|
||||
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_attn_shapes_helper(float32, float);
|
||||
// clang-format on
|
137
mlx/backend/metal/kernels/steel/attn/loader.h
Normal file
137
mlx/backend/metal/kernels/steel/attn/loader.h
Normal file
@ -0,0 +1,137 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Loading helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
short BROWS,
|
||||
short BCOLS,
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short alignment = 1,
|
||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
||||
short TCOLS = BCOLS / n_reads,
|
||||
short TROWS = tgp_size / TCOLS>
|
||||
struct BlockLoader {
|
||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
||||
STEEL_CONST short vec_size = n_reads;
|
||||
|
||||
// Leading dimension for src
|
||||
const int src_ld;
|
||||
const int tile_stride;
|
||||
|
||||
// Thread location indices
|
||||
const short thread_idx;
|
||||
const short bi;
|
||||
const short bj;
|
||||
|
||||
// threadgroup and device memory
|
||||
threadgroup T* dst;
|
||||
const device T* src;
|
||||
|
||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
||||
uint8_t v[sizeof(T) * vec_size];
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockLoader(
|
||||
const device T* src_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
dst(dst_ + bi * dst_ld + bj),
|
||||
src(src_ + bi * src_ld + bj) {}
|
||||
|
||||
/* Apply operation to threadgroup without bound checking */
|
||||
template <typename UnaryOp>
|
||||
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void load_unsafe() const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
||||
*((const device ReadVector*)(&src[i * src_ld]));
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
// Make sure tmp_idx only contains valid indices
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
||||
}
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
668
mlx/backend/metal/kernels/steel/attn/mma.h
Normal file
668
mlx/backend/metal/kernels/steel/attn/mma.h
Normal file
@ -0,0 +1,668 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_simdgroup_matrix>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MMA helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename RInt, typename CInt>
|
||||
struct Shape2D {
|
||||
RInt r;
|
||||
CInt c;
|
||||
|
||||
Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {}
|
||||
};
|
||||
|
||||
template <typename Shape, typename Layout>
|
||||
struct Layout2D {
|
||||
Shape shape;
|
||||
Layout layout;
|
||||
};
|
||||
|
||||
template <typename T, int kFragRows_, int kFragCols_>
|
||||
struct BaseMMAFrag {
|
||||
static_assert(
|
||||
kFragRows_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
static_assert(
|
||||
kFragCols_ == 8,
|
||||
"Only 8 x 8 fragment matrices are currently supported");
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BaseMMAFrag<T, 8, 8> {
|
||||
STEEL_CONST int kFragRows = 8;
|
||||
STEEL_CONST int kFragCols = 8;
|
||||
|
||||
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
||||
|
||||
STEEL_CONST int kElemRows = 1;
|
||||
STEEL_CONST int kElemCols = 2;
|
||||
|
||||
static_assert(
|
||||
kElemRows * kElemCols == kElemsPerFrag,
|
||||
"MMAFrag shape is not consistent with MMAFrag size");
|
||||
|
||||
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
||||
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
const short qid = simd_lane_id / 4;
|
||||
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
||||
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
return short2{fn, fm};
|
||||
}
|
||||
|
||||
template <typename SrcPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename SrcPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void load_safe(
|
||||
thread frag_type& dst,
|
||||
SrcPtrType src,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[i * kElemCols + j] =
|
||||
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
|
||||
} else {
|
||||
dst[i * kElemCols + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstPtrType, typename StrX, typename StrY>
|
||||
METAL_FUNC static constexpr void
|
||||
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename DstPtrType,
|
||||
typename StrX,
|
||||
typename StrY,
|
||||
typename LimX,
|
||||
typename LimY,
|
||||
typename OffX,
|
||||
typename OffY>
|
||||
METAL_FUNC static constexpr void store_safe(
|
||||
const thread frag_type& src,
|
||||
DstPtrType dst,
|
||||
StrX str_x,
|
||||
StrY str_y,
|
||||
LimX lim_x,
|
||||
LimY lim_y,
|
||||
OffX off_x = Int<0>{},
|
||||
OffY off_y = Int<0>{}) {
|
||||
using U = pointer_element_t<DstPtrType>;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kElemRows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kElemCols; j++) {
|
||||
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
||||
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
||||
static_cast<U>(src[i * kElemCols + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
thread frag_type& B,
|
||||
thread frag_type& C) {
|
||||
mat_type D_mat;
|
||||
mat_type A_mat;
|
||||
mat_type B_mat;
|
||||
mat_type C_mat;
|
||||
|
||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
||||
|
||||
mma(D_mat, A_mat, B_mat, C_mat);
|
||||
|
||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||
}
|
||||
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread mat_type& D,
|
||||
thread mat_type& A,
|
||||
thread mat_type& B,
|
||||
thread mat_type& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int kTileRows_,
|
||||
int kTileCols_,
|
||||
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
|
||||
struct MMATile {
|
||||
using MMAFrag_t = MMAFrag_;
|
||||
using elem_type = T;
|
||||
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
|
||||
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
|
||||
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
|
||||
|
||||
STEEL_CONST int kTileRows = kTileRows_;
|
||||
STEEL_CONST int kTileCols = kTileCols_;
|
||||
|
||||
STEEL_CONST int kRows = kTileRows * kFragRows;
|
||||
STEEL_CONST int kCols = kTileCols * kFragCols;
|
||||
|
||||
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
|
||||
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
|
||||
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
||||
|
||||
METAL_FUNC MMATile() thread {}
|
||||
|
||||
METAL_FUNC constexpr void clear() {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kNumFrags; ++i) {
|
||||
val_frags[i] = frag_type(0);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC constexpr const thread frag_type& frag_at(
|
||||
const short i,
|
||||
const short j) const {
|
||||
return val_frags[i * kTileCols + j];
|
||||
}
|
||||
|
||||
METAL_FUNC mat_type mat_at(const short i, const short j) {
|
||||
mat_type val_mat;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
|
||||
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
|
||||
}
|
||||
return val_mat;
|
||||
}
|
||||
|
||||
METAL_FUNC thread elem_type* elems() {
|
||||
return reinterpret_cast<thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
METAL_FUNC const thread elem_type* elems() const {
|
||||
return reinterpret_cast<const thread elem_type*>(val_frags);
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void load(const threadgroup U* src) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
src[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y, int str_x, int str_y>
|
||||
METAL_FUNC void store(threadgroup U* dst) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(
|
||||
dst[(i * kFragRows) * w_x * str_x +
|
||||
(j * kFragCols) * w_y * str_y]),
|
||||
Int<str_x>{},
|
||||
Int<str_y>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void load(const device U* src, const int ld) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load(
|
||||
frag_at(i, j),
|
||||
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void store(device U* dst, const int ld) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store(
|
||||
frag_at(i, j),
|
||||
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
|
||||
ld,
|
||||
Int<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::load_safe(
|
||||
frag_at(i, j),
|
||||
src,
|
||||
ld,
|
||||
Int<1>{},
|
||||
src_tile_dims.y,
|
||||
src_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int w_x, int w_y>
|
||||
METAL_FUNC void
|
||||
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kTileRows; ++i) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kTileCols; ++j) {
|
||||
MMAFrag_t::store_safe(
|
||||
frag_at(i, j),
|
||||
dst,
|
||||
ld,
|
||||
Int<1>{},
|
||||
dst_tile_dims.y,
|
||||
dst_tile_dims.x,
|
||||
(i * kFragRows) * w_x,
|
||||
(j * kFragCols) * w_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
METAL_FUNC void tile_matmad(
|
||||
thread MMATile<T, M, N>& D,
|
||||
thread MMATile<U, M, K>& A,
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
B.frag_at(k, n_serp),
|
||||
C.frag_at(m, n_serp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
short lda_tgp,
|
||||
short ldb_tgp,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformNone<U, AccumType>>
|
||||
struct BlockMMA {
|
||||
// MMAFrag size
|
||||
STEEL_CONST short kFragSize = 8;
|
||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
||||
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TM_stride = kFragSize * WM;
|
||||
// Warp tile simdgroup matrix strides along M
|
||||
STEEL_CONST short TN_stride = kFragSize * WN;
|
||||
|
||||
// Warp tile size along M
|
||||
STEEL_CONST short TM = BM / TM_stride;
|
||||
// Warp tile size along N
|
||||
STEEL_CONST short TN = BN / TN_stride;
|
||||
|
||||
// Threadgroup A strides
|
||||
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
||||
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
||||
|
||||
// Threadgroup B strides
|
||||
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
||||
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
||||
|
||||
// Threadgroup strides along K
|
||||
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
||||
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
||||
|
||||
// Simdgroup matrices
|
||||
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
||||
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
||||
|
||||
// Offsets within threadgroup
|
||||
short sm;
|
||||
short sn;
|
||||
|
||||
short As_offset;
|
||||
short Bs_offset;
|
||||
|
||||
/* Constructor */
|
||||
METAL_FUNC BlockMMA(
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
||||
// Determine thread position in simdgroup matrix
|
||||
short tm = kFragSize * (simd_group_id / WN);
|
||||
short tn = kFragSize * (simd_group_id % WN);
|
||||
|
||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
||||
sm = simd_coord.y;
|
||||
sn = simd_coord.x;
|
||||
|
||||
// Determine thread and simdgroup offset
|
||||
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
||||
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
||||
|
||||
sm += tm;
|
||||
sn += tn;
|
||||
}
|
||||
|
||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
||||
// Adjust for simdgroup and thread location
|
||||
As += As_offset;
|
||||
Bs += Bs_offset;
|
||||
|
||||
// Iterate over BK in blocks of kFragSize
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short kk = 0; kk < BK; kk += kFragSize) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
tile_matmad(Ctile, Atile, Btile, Ctile);
|
||||
|
||||
// Progress to next simdgroup tile
|
||||
As += tile_stride_a;
|
||||
Bs += tile_stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
// Adjust for simdgroup and thread location
|
||||
D += sm * ldd + sn;
|
||||
|
||||
Ctile.template store<U, WM, WN>(D, ldd);
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
||||
}
|
||||
|
||||
// Adjust for simdgroup and thread location
|
||||
D += sm * ldd + sn;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename UnaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
||||
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
|
||||
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue_safe(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
// Read C
|
||||
U c_elems[kelems] = {0};
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
||||
c_elems[k] = C[offset_c + k * fdc];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
D += (sm)*ldd + sn;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_result_safe(
|
||||
device U* D,
|
||||
const int ldd,
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const Epilogue& epilogue_op) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm)*ldc + (sn)*fdc;
|
||||
D += (sm)*ldd + sn;
|
||||
dst_tile_dims -= short2(sn, sm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TM; i++) {
|
||||
if (i * TM_stride < dst_tile_dims.y) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = Ctile.frag_at(i, j);
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < kelems; k++) {
|
||||
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
||||
D[offset_d + k] =
|
||||
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
31
mlx/backend/metal/kernels/steel/attn/params.h
Normal file
31
mlx/backend/metal/kernels/steel/attn/params.h
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Attn param classes
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
struct AttnParams {
|
||||
int B; ///< Batch Size
|
||||
int H; ///< Heads
|
||||
int L; ///< Sequence Length
|
||||
int D; ///< Head Dim
|
||||
|
||||
int gqa_factor; ///< Group Query factor
|
||||
float scale; ///< Attention scale
|
||||
|
||||
int NQ; ///< Number of query blocks
|
||||
int NK; ///< Number of key/value blocks
|
||||
|
||||
size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
|
||||
size_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
|
||||
size_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
|
||||
size_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
71
mlx/backend/metal/kernels/steel/attn/transforms.h
Normal file
71
mlx/backend/metal/kernels/steel/attn/transforms.h
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Transforms and Epilogues
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace mlx {
|
||||
namespace steel {
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformNone {
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAdd {
|
||||
TransformAdd(const float, const float) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
||||
return static_cast<OutT>(x) + c;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT>
|
||||
struct TransformAxpby {
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
||||
TransformAxpby(const float alpha_, const float beta_)
|
||||
: alpha(alpha_), beta(beta_) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
||||
return static_cast<OutT>(x * alpha + (beta * c));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct AccumHelper {
|
||||
typedef float accum_type;
|
||||
};
|
||||
|
||||
struct BlockSwizzle {
|
||||
static METAL_FUNC int2
|
||||
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
||||
const int tid_x = (tid.x) >> swizzle_log;
|
||||
const int tid_y =
|
||||
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
||||
return int2(tid_x, tid_y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
} // namespace mlx
|
@ -7,6 +7,9 @@
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -19,122 +22,57 @@ void sdpa_full_self_attention_metal(
|
||||
const array& q,
|
||||
const array& k,
|
||||
const array& v,
|
||||
const float alpha,
|
||||
array& out) {
|
||||
std::ostringstream kname_self_attention;
|
||||
kname_self_attention << "steel_gemm_attention_";
|
||||
const float scale,
|
||||
array& o) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
constexpr const int bm = 16;
|
||||
constexpr const int bn = 16;
|
||||
const int bk = q.shape(-1); // already forced to be 64 or 128
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
|
||||
if (bk != 64 && bk != 128) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128");
|
||||
}
|
||||
int bq = 16;
|
||||
int bk = 16;
|
||||
int bd = 64;
|
||||
|
||||
constexpr const int wm = 2;
|
||||
constexpr const int wn = 2;
|
||||
|
||||
std::string delimiter = "_";
|
||||
|
||||
kname_self_attention << "bm_" + std::to_string(bm) + delimiter;
|
||||
kname_self_attention << "bn_" + std::to_string(bn) + delimiter;
|
||||
kname_self_attention << "bk_" + std::to_string(bk) + delimiter;
|
||||
|
||||
for (const auto& arr : {k, v, out}) {
|
||||
if (arr.dtype() != q.dtype()) {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
|
||||
}
|
||||
}
|
||||
|
||||
if (q.dtype() == float32) {
|
||||
kname_self_attention << "itype" + delimiter + "float";
|
||||
} else if (q.dtype() == float16) {
|
||||
kname_self_attention << "itype" + delimiter + "half";
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
|
||||
}
|
||||
std::ostringstream kname;
|
||||
kname << "steel_attention_" << type_to_name(q) << "_bq" << bq << "_bk" << bk
|
||||
<< "_bd" << bd << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_self_attention.str());
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
uint hidden_dim = q.shape(-1);
|
||||
uint qseq = q.shape(-2);
|
||||
uint qheads = q.shape(-3);
|
||||
int B = q.shape(0);
|
||||
int H = q.shape(1);
|
||||
int L = q.shape(2);
|
||||
int D = q.shape(3);
|
||||
int gqa_factor = q.shape(1) / k.shape(1);
|
||||
|
||||
const uint64_t KV_sequence_length = k.shape(-2);
|
||||
const uint query_sequence_length = q.shape(-2);
|
||||
const uint n_q_heads = q.shape(1);
|
||||
const uint n_kv_heads = k.shape(1);
|
||||
int NQ = (L + bq - 1) / bq;
|
||||
int NK = (L + bk - 1) / bk;
|
||||
|
||||
const int M = q.shape(-2);
|
||||
const int N = M;
|
||||
const int K = q.shape(-1);
|
||||
const size_t batch_size_out = q.shape(0) * q.shape(1);
|
||||
AttnParams params{
|
||||
/* int B = */ B,
|
||||
/* int H = */ H,
|
||||
/* int L = */ L,
|
||||
/* int D = */ D,
|
||||
/* int gqa_factor = */ gqa_factor,
|
||||
/* float scale = */ scale,
|
||||
|
||||
const std::vector<int> batch_shape = {q.shape(0) * q.shape(1)};
|
||||
const int dk = q.shape(-1);
|
||||
const int ldq = dk;
|
||||
const int ldk = dk;
|
||||
const int ldv = dk;
|
||||
const int lds = bn;
|
||||
const int ldo = dk;
|
||||
/* int NQ = */ NQ,
|
||||
/* int NK = */ NK,
|
||||
|
||||
int tn = 1;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
const int batch_stride_q = dk * query_sequence_length;
|
||||
const int batch_stride_k = dk * query_sequence_length;
|
||||
const int batch_stride_v = dk * query_sequence_length;
|
||||
const int batch_stride_o = dk * query_sequence_length;
|
||||
const int swizzle_log = 0;
|
||||
const int gemm_n_iterations_aligned = (N + bn - 1) / bn;
|
||||
const int gemm_k_iterations_aligned = (K + bk - 1) / bk;
|
||||
const int gemm_sv_m_block_iterations = (M + bm - 1) / bm;
|
||||
const int batch_ndim = int(batch_shape.size());
|
||||
|
||||
MLXFastAttentionParams params{
|
||||
(int)M,
|
||||
(int)N,
|
||||
(int)K,
|
||||
ldq,
|
||||
ldk,
|
||||
ldv,
|
||||
lds,
|
||||
ldo,
|
||||
tn,
|
||||
tm,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o,
|
||||
swizzle_log,
|
||||
gemm_n_iterations_aligned,
|
||||
gemm_k_iterations_aligned,
|
||||
gemm_sv_m_block_iterations,
|
||||
batch_ndim,
|
||||
alpha};
|
||||
|
||||
const std::vector<size_t> batch_strides = {
|
||||
(size_t)batch_stride_q,
|
||||
(size_t)batch_stride_k,
|
||||
(size_t)batch_stride_v,
|
||||
(size_t)batch_stride_o};
|
||||
/* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||
/* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||
/* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||
/* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
|
||||
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_output_array(o, 3);
|
||||
compute_encoder.set_bytes(params, 4);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out);
|
||||
MTL::Size grid_dims = MTL::Size(NQ, H, B);
|
||||
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
@ -600,7 +600,7 @@ array scaled_dot_product_attention(
|
||||
* * dtype is not fp32 or fp16
|
||||
*/
|
||||
|
||||
int threshold = 1e6;
|
||||
int threshold = 1024; // TODO: Fix after dev
|
||||
if (memory_efficient_threshold.has_value()) {
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user