Fused GEMM (#1123)

* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
Jagrit Digani authored on 2024-05-15 10:30:41 -07:00; committed by GitHub.
parent 631dfbe673
commit 358e1fd6ab
8 changed files with 709 additions and 690 deletions
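In short, the separate steel_gemm, steel_addmm, and block-sparse GEMM kernels below are collapsed into a single steel_gemm_fused kernel whose batch, epilogue, gather, and alignment paths are selected with Metal function constants when the compute pipeline is built. A minimal sketch of that mechanism (illustrative only: it bypasses the Device::get_kernel wrapper this commit extends, and assumes a compiled MTL::Library* named library):

// Specialize the fused kernel at pipeline-build time via function constants.
// The indices match the [[function_constant(...)]] declarations in the new
// kernel below; all seven must be set since none has a default value.
MTL::FunctionConstantValues* consts = MTL::FunctionConstantValues::alloc()->init();
bool has_batch = false, use_out_source = false, do_axpby = false;
bool align_M = true, align_N = true, align_K = false, do_gather = false;
consts->setConstantValue(&has_batch, MTL::DataType::DataTypeBool, 10);
consts->setConstantValue(&use_out_source, MTL::DataType::DataTypeBool, 100);
consts->setConstantValue(&do_axpby, MTL::DataType::DataTypeBool, 110);
consts->setConstantValue(&align_M, MTL::DataType::DataTypeBool, 200);
consts->setConstantValue(&align_N, MTL::DataType::DataTypeBool, 201);
consts->setConstantValue(&align_K, MTL::DataType::DataTypeBool, 202);
consts->setConstantValue(&do_gather, MTL::DataType::DataTypeBool, 300);
NS::Error* error = nullptr;
MTL::Function* fn = library->newFunction(
    NS::String::string("steel_gemm_fused_nn_float32_float32_bm32_bn32_bk16_wm2_wn2",
                       NS::UTF8StringEncoding),
    consts, &error);
consts->release();

MLX routes this through Device::get_kernel(base_name, "mlx", hash_name, func_consts), as the matmul.cpp hunks below show.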

View File

@@ -357,7 +357,6 @@ MTL::Function* Device::get_function_(
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
@@ -526,11 +525,13 @@ MTL::ComputePipelineState* Device::get_kernel(
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}

View File

@@ -1,136 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname)]] [[kernel]] void \
gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -1,340 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void addmm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned,
AccumType,
Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
A_bstrides,
B_bstrides,
C_bstrides,
params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
C += batch_offsets.z;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += addmm_params->batch_stride_c * tid.z;
}
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
int leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(
D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D,
params->ldd,
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
ep_name, \
epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm \
"_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname \
"_K_" #kname "_" #ep_name)]] [[kernel]] void \
addmm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
float, \
epilogue<itype, float>>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
const device itype* C [[buffer(2)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* gemm_params [[buffer(4)]], \
const constant GEMMAddMMParams* params [[buffer(5)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) // clang-format on
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -0,0 +1,468 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
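// These indices (10, 100, 110, 200-202, 300) must line up with the MTLFCList
// entries the host builds in matmul.cpp below; none has a default value, so
// every dispatch supplies all seven.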
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
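// MN_aligned and K_aligned are pinned to true here: edge tiles and the K
// remainder are now handled through the align_M / align_N / align_K function
// constants instead of separate template instantiations.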
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
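// Whichever path was taken, A, B (and C when use_out_source) now point at
// the start of this threadgroup's batch element.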
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
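// align_M / align_N are function constants, so when they are true these
// min() expressions fold away at pipeline-build time and tgp_bm / tgp_bn
// become the compile-time constants BM / BN.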
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
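// Handling the K remainder up front with bounds-checked load_safe lets every
// iteration of the main loops below take the faster load_unsafe path.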
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
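For reference, each instantiate_gemm expansion above yields a host-visible function name of the form steel_gemm_fused_nn_float32_float32_bm32_bn32_bk16_wm2_wn2; the per-specialization suffix (_has_batch_… through _do_gather_…) that the host code below appends lives only in the pipeline-cache hash name, not in the compiled function name.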

View File

@@ -1,168 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10)]],
const constant uint32_t* rhs_indices [[buffer(11)]],
const constant int* batch_shape_A [[buffer(12)]],
const constant size_t* batch_strides_A [[buffer(13)]],
const constant int* batch_shape_B [[buffer(14)]],
const constant size_t* batch_strides_B [[buffer(15)]],
const constant int2& operand_batch_ndim [[buffer(16)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
uint32_t indx_A;
uint32_t indx_B;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
indx_A = lhs_indices[batch_offsets.x];
indx_B = rhs_indices[batch_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
int batch_ndim_A = operand_batch_ndim.x;
int batch_ndim_B = operand_batch_ndim.y;
if (batch_ndim_A > 1) {
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
} else {
A += indx_A * batch_strides_A[0];
}
if (batch_ndim_B > 1) {
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
} else {
B += indx_B * batch_strides_B[0];
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10)]], \
const constant uint32_t* rhs_indices [[buffer(11)]], \
const constant int* batch_shape_A [[buffer(12)]], \
const constant size_t* batch_strides_A [[buffer(13)]], \
const constant int* batch_shape_B [[buffer(14)]], \
const constant size_t* batch_strides_B [[buffer(15)]], \
const constant int2& operand_batch_ndim [[buffer(16)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -198,6 +198,73 @@ struct BlockMMA {
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
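// Each thread of the simdgroup matrix owns two adjacent result elements,
// which is why the epilogue is applied pairwise at offset_c and
// offset_c + fdc.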
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Read C
U c_elems[2] = {0};
if ((j * TN_stride + 1) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
c_elems[1] = C[offset_c + fdc];
} else if ((j * TN_stride) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
}
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,

View File

@@ -26,6 +26,10 @@ template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
@@ -39,6 +43,10 @@ struct TransformAxpby {
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
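// In other words, D = alpha * (A @ B) + beta * C; TransformAdd above is the
// alpha == beta == 1 special case dispatched when do_axpby is false.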

View File

@@ -298,16 +298,46 @@ void steel_matmul_conv_groups(
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
<< "_wm" << wm << "_wn" << wn;
+ std::string base_name = kname.str();
+ const bool has_batch = false;
+ const bool use_out_source = false;
+ const bool do_axpby = false;
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+ const bool do_gather = false;
+ metal::MTLFCList func_consts = {
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ {&do_gather, MTL::DataType::DataTypeBool, 300},
+ };
+ // clang-format off
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
+ << "_align_M_" << (align_M ? 't' : 'n')
+ << "_align_N_" << (align_N ? 't' : 'n')
+ << "_align_K_" << (align_K ? 't' : 'n')
+ << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
+ std::string hash_name = kname.str();
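// base_name selects the compiled template instantiation; hash_name folds in
// the function-constant values so each specialization gets its own entry in
// the device's pipeline cache.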
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
- auto kernel = d.get_kernel(kname.str());
+ auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
@@ -351,10 +381,8 @@ void steel_matmul_conv_groups(
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
- compute_encoder->setBytes(
- batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
- compute_encoder->setBytes(
- batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
+ set_vector_bytes(compute_encoder, batch_shape, 6);
+ set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
@@ -521,16 +549,46 @@ void steel_matmul(
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
<< "_wm" << wm << "_wn" << wn;
+ std::string base_name = kname.str();
+ const bool has_batch = (batch_shape.size() > 1);
+ const bool use_out_source = false;
+ const bool do_axpby = false;
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+ const bool do_gather = false;
+ metal::MTLFCList func_consts = {
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ {&do_gather, MTL::DataType::DataTypeBool, 300},
+ };
+ // clang-format off
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
+ << "_align_M_" << (align_M ? 't' : 'n')
+ << "_align_N_" << (align_N ? 't' : 'n')
+ << "_align_K_" << (align_K ? 't' : 'n')
+ << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
+ std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
- auto kernel = d.get_kernel(kname.str());
+ auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
@@ -576,10 +634,8 @@ void steel_matmul(
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
- compute_encoder->setBytes(
- batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
- compute_encoder->setBytes(
- batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
+ set_vector_bytes(compute_encoder, batch_shape, 6);
+ set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
@@ -957,13 +1013,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
- compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
- compute_encoder->setBytes(
- batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
- compute_encoder->setBytes(
- batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
- compute_encoder->setBytes(
- C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
+ set_vector_bytes(compute_encoder, batch_shape, 10);
+ set_vector_bytes(compute_encoder, batch_strides_vec, 11);
+ set_vector_bytes(compute_encoder, batch_strides_mat, 12);
+ set_vector_bytes(compute_encoder, C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
@@ -1089,18 +1142,48 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned"
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
<< "_wm" << wm << "_wn" << wn;
+ std::string base_name = kname.str();
+ const bool has_batch = (batch_shape.size() > 1);
+ const bool use_out_source = true;
+ const bool do_axpby = !(alpha_ == 1. && beta_ == 1.);
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+ const bool do_gather = false;
+ metal::MTLFCList func_consts = {
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ {&do_gather, MTL::DataType::DataTypeBool, 300},
+ };
+ // clang-format off
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
+ << "_align_M_" << (align_M ? 't' : 'n')
+ << "_align_N_" << (align_N ? 't' : 'n')
+ << "_align_K_" << (align_K ? 't' : 'n')
+ << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
+ std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
- auto kernel = d.get_kernel(kname.str());
+ auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
@@ -1155,10 +1238,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
- compute_encoder->setBytes(
- batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
- compute_encoder->setBytes(
- batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
+ set_vector_bytes(compute_encoder, batch_shape, 6);
+ set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
@@ -1593,16 +1674,46 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare kernel name
std::ostringstream kname;
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
<< "_wm" << wm << "_wn" << wn;
+ std::string base_name = kname.str();
+ const bool has_batch = batch_ndim > 1;
+ const bool use_out_source = false;
+ const bool do_axpby = false;
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+ const bool do_gather = true;
+ metal::MTLFCList func_consts = {
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ {&do_gather, MTL::DataType::DataTypeBool, 300},
+ };
+ // clang-format off
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
+ << "_align_M_" << (align_M ? 't' : 'n')
+ << "_align_N_" << (align_N ? 't' : 'n')
+ << "_align_K_" << (align_K ? 't' : 'n')
+ << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
+ std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
- auto kernel = d.get_kernel(kname.str());
+ auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
@@ -1650,11 +1761,19 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
- set_vector_bytes(compute_encoder, batch_shape_A, 12);
- set_vector_bytes(compute_encoder, batch_strides_A, 13);
- set_vector_bytes(compute_encoder, batch_shape_B, 14);
- set_vector_bytes(compute_encoder, batch_strides_B, 15);
- set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
+ std::vector operand_shape = batch_shape_A;
+ operand_shape.insert(
+ operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
+ std::vector operand_strides = batch_strides_A;
+ operand_strides.insert(
+ operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
+ operand_batch_ndim.push_back(0);
+ set_vector_bytes(compute_encoder, operand_shape, 13);
+ set_vector_bytes(compute_encoder, operand_strides, 14);
+ set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
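// A's and B's batch shapes/strides are packed into single operand buffers
// (13 and 14), with a trailing ndim of 0 for C, matching the fused kernel's
// operand_shape / operand_strides / operand_batch_ndim arguments.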
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);