mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
Fused GEMM (#1123)
* Basic gemm working * Update addmm * Clear out steel_gemm and steel_addmm kernels * Fuse and clear out gather gemm * Update objc releases
This commit is contained in:
parent
631dfbe673
commit
358e1fd6ab
@ -357,7 +357,6 @@ MTL::Function* Device::get_function_(
|
||||
}
|
||||
|
||||
mtl_func_consts->release();
|
||||
desc->release();
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
@ -526,11 +525,13 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
@ -1,136 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned) \
|
||||
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||
"_K_" #kname)]] [[kernel]] void \
|
||||
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -1,340 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned,
|
||||
typename AccumType = float,
|
||||
typename Epilogue = TransformAdd<T, AccumType>>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device T* C [[buffer(2)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned,
|
||||
AccumType,
|
||||
Epilogue>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
|
||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
A_bstrides,
|
||||
B_bstrides,
|
||||
C_bstrides,
|
||||
params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
C += batch_offsets.z;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (MN_aligned) {
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Loop tail
|
||||
if (!K_aligned) {
|
||||
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
mma_op.mma(As, Bs);
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(
|
||||
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||
|
||||
if (tgp_bm == BM && tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(
|
||||
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, K_aligned>{});
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D,
|
||||
params->ldd,
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned, \
|
||||
ep_name, \
|
||||
epilogue) \
|
||||
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
|
||||
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
|
||||
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
|
||||
addmm< \
|
||||
itype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
mn_aligned, \
|
||||
k_aligned, \
|
||||
float, \
|
||||
epilogue<itype, float>>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
const device itype* C [[buffer(2)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* gemm_params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(5)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -0,0 +1,468 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constant bool has_batch [[function_constant(10)]];
|
||||
|
||||
constant bool use_out_source [[function_constant(100)]];
|
||||
constant bool do_axpby [[function_constant(110)]];
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
|
||||
constant bool do_gather [[function_constant(300)]];
|
||||
|
||||
constant bool gather_bias = do_gather && use_out_source;
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
// Find block
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
// Exit early if out of bounds
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
|
||||
// Handle gather
|
||||
if (do_gather) {
|
||||
// Read indices
|
||||
uint32_t indx_A, indx_B, indx_C;
|
||||
|
||||
if (has_batch) {
|
||||
const constant size_t* indx_A_bstrides = batch_strides;
|
||||
const constant size_t* indx_B_bstrides =
|
||||
batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
indx_A_bstrides,
|
||||
indx_B_bstrides,
|
||||
params->batch_ndim);
|
||||
indx_A = lhs_indices[indx_offsets.x];
|
||||
indx_B = rhs_indices[indx_offsets.y];
|
||||
|
||||
if (use_out_source) {
|
||||
const constant size_t* indx_C_bstrides =
|
||||
indx_B_bstrides + params->batch_ndim;
|
||||
auto indx_offset_C = elem_to_loc(
|
||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||
indx_C = C_indices[indx_offset_C];
|
||||
}
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
|
||||
if (use_out_source) {
|
||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
||||
}
|
||||
}
|
||||
|
||||
// Translate indices to offsets
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
const constant int* batch_shape_A = operand_shape;
|
||||
const constant size_t* batch_strides_A = operand_strides;
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
||||
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
|
||||
if (use_out_source) {
|
||||
int batch_ndim_C = operand_batch_ndim.z;
|
||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
||||
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Handle regular batch
|
||||
else {
|
||||
if (has_batch) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
if (use_out_source) {
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||
}
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
|
||||
if (use_out_source) {
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
// Prepare threadgroup memory
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
if (use_out_source) {
|
||||
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
||||
}
|
||||
|
||||
// Prepare threadgroup mma operation
|
||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
|
||||
// Prepare threadgroup bounds
|
||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||
|
||||
// Prepare iterations
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
// Do unaligned K iterations first
|
||||
if (!align_K) {
|
||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||
const int k_remain = params->K - k_last;
|
||||
const size_t k_jump_a =
|
||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||
const size_t k_jump_b =
|
||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||
|
||||
// Move loader source ahead to end
|
||||
loader_a.src += k_jump_a;
|
||||
loader_b.src += k_jump_b;
|
||||
|
||||
// Load tile
|
||||
const short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||
const short2 tile_dims_B =
|
||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do matmul
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Reset source back to start
|
||||
loader_a.src -= k_jump_a;
|
||||
loader_b.src -= k_jump_b;
|
||||
}
|
||||
|
||||
const TransformAdd<AccumType, AccumType> epilogue_op_add(
|
||||
addmm_params->alpha, addmm_params->beta);
|
||||
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
|
||||
addmm_params->alpha, addmm_params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
if (align_M && align_N) {
|
||||
// Do gemm
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Do epilogue
|
||||
if (use_out_source) {
|
||||
if (do_axpby) {
|
||||
mma_op.apply_epilogue(
|
||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
||||
} else {
|
||||
mma_op.apply_epilogue(
|
||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
return mma_op.store_result(D, params->ldd);
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MN unaligned loop
|
||||
else { // Loop over K - unaligned case
|
||||
const int leftover_bk = 0;
|
||||
|
||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||
// Do gemm
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, true>{});
|
||||
|
||||
// Do epilogue
|
||||
if (use_out_source) {
|
||||
if (do_axpby) {
|
||||
mma_op.apply_epilogue(
|
||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
||||
} else {
|
||||
mma_op.apply_epilogue(
|
||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
return mma_op.store_result(D, params->ldd);
|
||||
|
||||
} else if (align_N || tgp_bn == BN) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, true, true>{});
|
||||
|
||||
// Do epilogue
|
||||
if (use_out_source) {
|
||||
if (do_axpby) {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_axpby);
|
||||
} else {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_add);
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
|
||||
} else if (align_M || tgp_bm == BM) {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, false, true>{});
|
||||
|
||||
// Do epilogue
|
||||
if (use_out_source) {
|
||||
if (do_axpby) {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_axpby);
|
||||
} else {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_add);
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
|
||||
} else {
|
||||
gemm_kernel::gemm_loop(
|
||||
As,
|
||||
Bs,
|
||||
gemm_k_iterations,
|
||||
loader_a,
|
||||
loader_b,
|
||||
mma_op,
|
||||
tgp_bm,
|
||||
tgp_bn,
|
||||
leftover_bk,
|
||||
LoopAlignment<false, false, true>{});
|
||||
|
||||
// Do epilogue
|
||||
if (use_out_source) {
|
||||
if (do_axpby) {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_axpby);
|
||||
} else {
|
||||
mma_op.apply_epilogue_safe(
|
||||
C,
|
||||
addmm_params->ldc,
|
||||
addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op_add);
|
||||
}
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
@ -1,168 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11)]],
|
||||
const constant int* batch_shape_A [[buffer(12)]],
|
||||
const constant size_t* batch_strides_A [[buffer(13)]],
|
||||
const constant int* batch_shape_B [[buffer(14)]],
|
||||
const constant size_t* batch_strides_B [[buffer(15)]],
|
||||
const constant int2& operand_batch_ndim [[buffer(16)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
uint32_t indx_A;
|
||||
uint32_t indx_B;
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
indx_A = lhs_indices[batch_offsets.x];
|
||||
indx_B = rhs_indices[batch_offsets.y];
|
||||
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
}
|
||||
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
|
||||
if (batch_ndim_A > 1) {
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
} else {
|
||||
A += indx_A * batch_strides_A[0];
|
||||
}
|
||||
|
||||
if (batch_ndim_B > 1) {
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
} else {
|
||||
B += indx_B * batch_strides_B[0];
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned) \
|
||||
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
|
||||
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
|
||||
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const constant uint32_t* lhs_indices [[buffer(10)]], \
|
||||
const constant uint32_t* rhs_indices [[buffer(11)]], \
|
||||
const constant int* batch_shape_A [[buffer(12)]], \
|
||||
const constant size_t* batch_strides_A [[buffer(13)]], \
|
||||
const constant int* batch_shape_B [[buffer(14)]], \
|
||||
const constant size_t* batch_strides_B [[buffer(15)]], \
|
||||
const constant int2& operand_batch_ndim [[buffer(16)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -198,6 +198,73 @@ struct BlockMMA {
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
// Apply epilogue
|
||||
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
|
||||
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Apply epilogue */
|
||||
template <typename BinaryEpilogue>
|
||||
METAL_FUNC void apply_epilogue_safe(
|
||||
const device U* C,
|
||||
const int ldc,
|
||||
const int fdc,
|
||||
short2 dst_tile_dims,
|
||||
thread const BinaryEpilogue& epilogue_op) {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
return;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < TM; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
||||
|
||||
// Read C
|
||||
U c_elems[2] = {0};
|
||||
|
||||
if ((j * TN_stride + 1) < dst_tile_dims.x) {
|
||||
c_elems[0] = C[offset_c];
|
||||
c_elems[1] = C[offset_c + fdc];
|
||||
} else if ((j * TN_stride) < dst_tile_dims.x) {
|
||||
c_elems[0] = C[offset_c];
|
||||
}
|
||||
|
||||
// Apply epilogue
|
||||
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
|
||||
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(
|
||||
device U* D,
|
||||
|
@ -26,6 +26,10 @@ template <typename OutT, typename InT>
|
||||
struct TransformAdd {
|
||||
TransformAdd(const float, const float) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
||||
return static_cast<OutT>(x) + c;
|
||||
}
|
||||
@ -39,6 +43,10 @@ struct TransformAxpby {
|
||||
TransformAxpby(const float alpha_, const float beta_)
|
||||
: alpha(alpha_), beta(beta_) {}
|
||||
|
||||
static METAL_FUNC OutT apply(InT x) {
|
||||
return static_cast<OutT>(x);
|
||||
}
|
||||
|
||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
||||
return static_cast<OutT>(x * alpha + (beta * c));
|
||||
}
|
||||
|
@ -298,16 +298,46 @@ void steel_matmul_conv_groups(
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = false;
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
@ -351,10 +381,8 @@ void steel_matmul_conv_groups(
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
@ -521,16 +549,46 @@ void steel_matmul(
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = (batch_shape.size() > 1);
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
@ -576,10 +634,8 @@ void steel_matmul(
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
@ -957,13 +1013,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 8);
|
||||
|
||||
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
|
||||
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(
|
||||
C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
|
||||
set_vector_bytes(compute_encoder, batch_shape, 10);
|
||||
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
|
||||
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
|
||||
set_vector_bytes(compute_encoder, C_batch_stride, 13);
|
||||
|
||||
int bias_stride = c.strides()[c.ndim() - 1];
|
||||
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
|
||||
@ -1089,18 +1142,48 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned"
|
||||
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = (batch_shape.size() > 1);
|
||||
const bool use_out_source = true;
|
||||
const bool do_axpby = !(alpha_ == 1. && beta_ == 1.);
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = false;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int tn = (N + bn - 1) / bn;
|
||||
@ -1155,10 +1238,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
@ -1593,16 +1674,46 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const bool has_batch = batch_ndim > 1;
|
||||
const bool use_out_source = false;
|
||||
const bool do_axpby = false;
|
||||
const bool align_M = (M % bm) == 0;
|
||||
const bool align_N = (N % bn) == 0;
|
||||
const bool align_K = (K % bk) == 0;
|
||||
const bool do_gather = true;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_batch, MTL::DataType::DataTypeBool, 10},
|
||||
{&use_out_source, MTL::DataType::DataTypeBool, 100},
|
||||
{&do_axpby, MTL::DataType::DataTypeBool, 110},
|
||||
{&align_M, MTL::DataType::DataTypeBool, 200},
|
||||
{&align_N, MTL::DataType::DataTypeBool, 201},
|
||||
{&align_K, MTL::DataType::DataTypeBool, 202},
|
||||
{&do_gather, MTL::DataType::DataTypeBool, 300},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_has_batch_" << (has_batch ? 't' : 'n')
|
||||
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
|
||||
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
|
||||
<< "_align_M_" << (align_M ? 't' : 'n')
|
||||
<< "_align_N_" << (align_N ? 't' : 'n')
|
||||
<< "_align_K_" << (align_K ? 't' : 'n')
|
||||
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
@ -1650,11 +1761,19 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder.set_input_array(lhs_indices, 10);
|
||||
compute_encoder.set_input_array(rhs_indices, 11);
|
||||
|
||||
set_vector_bytes(compute_encoder, batch_shape_A, 12);
|
||||
set_vector_bytes(compute_encoder, batch_strides_A, 13);
|
||||
set_vector_bytes(compute_encoder, batch_shape_B, 14);
|
||||
set_vector_bytes(compute_encoder, batch_strides_B, 15);
|
||||
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
|
||||
std::vector operand_shape = batch_shape_A;
|
||||
operand_shape.insert(
|
||||
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
|
||||
|
||||
std::vector operand_strides = batch_strides_A;
|
||||
operand_strides.insert(
|
||||
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
|
||||
|
||||
operand_batch_ndim.push_back(0);
|
||||
|
||||
set_vector_bytes(compute_encoder, operand_shape, 13);
|
||||
set_vector_bytes(compute_encoder, operand_strides, 14);
|
||||
set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user