Refactor split k axpby

This commit is contained in:
Jagrit Digani 2025-06-11 07:47:56 -07:00
parent 9fd8eb357c
commit c2f1c2a338

View File

@ -295,11 +295,13 @@ void steel_matmul_regular(
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
void steel_gemm_splitk( template <bool CHECK_AB = true>
void steel_gemm_splitk_axpby(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
const array& a, const array& a,
const array& b, const array& b,
const array& c,
array& out, array& out,
int M, int M,
int N, int N,
@ -309,7 +311,9 @@ void steel_gemm_splitk(
int ldb, int ldb,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
std::vector<array>& copies) { std::vector<array>& copies,
float alpha = 1.0f,
float beta = 0.0f) {
using namespace mlx::steel; using namespace mlx::steel;
int _tm = M / 16; int _tm = M / 16;
@ -393,11 +397,21 @@ void steel_gemm_splitk(
// Do accum kernel // Do accum kernel
{ {
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split); type_to_name(C_split);
auto kernel = if (do_axpby) {
get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false); kernel_name = kernel_name + "_axbpy";
}
auto kernel = get_steel_gemm_splitk_accum_kernel(
/* metal::Device& d = */ d,
/* const std::string& kernel_name = */ kernel_name,
/* const array& in = */ C_split,
/* const array& out = */ out,
/* bool axbpy = */ do_axpby);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
@ -407,6 +421,17 @@ void steel_gemm_splitk(
compute_encoder.set_bytes(split_k_partition_stride, 3); compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4); compute_encoder.set_bytes(N, 4);
if (do_axpby) {
int ldc = c.strides()[c.ndim() - 2];
int fdc = c.strides()[c.ndim() - 1];
compute_encoder.set_input_array(c, 5);
compute_encoder.set_bytes(ldc, 6);
compute_encoder.set_bytes(fdc, 7);
compute_encoder.set_bytes(alpha, 8);
compute_encoder.set_bytes(beta, 9);
}
// Launch enough thread groups for each output // Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1); auto group_dims = get_block_dims(N, M, 1);
@ -416,6 +441,39 @@ void steel_gemm_splitk(
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }
inline void steel_gemm_splitk(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
return steel_gemm_splitk_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies);
}
void steel_matmul( void steel_matmul(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
@ -899,106 +957,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int _tk = K / 16; int _tk = K / 16;
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32; return steel_gemm_splitk_axpby(
int bn = N < 40 ? 16 : 32; /* const Stream& s = */ s,
int bk = 16; /* metal::Device& d = */ d,
int wm = 2, wn = 2; /* const array& a = */ a,
/* const array& b = */ b,
int split_k_partitions = /* const array& c = */ c,
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); /* array& out = */ out,
int split_k_partition_stride = M * N; /* int M = */ M,
int gemm_k_iterations = (K / bk) / split_k_partitions; /* int N = */ N,
int split_k_partition_size = gemm_k_iterations * bk; /* int K = */ K,
/* int batch_size_out = */ batch_size_out,
array C_split({split_k_partitions, M, N}, float32, nullptr, {}); /* int lda = */ lda,
C_split.set_data(allocator::malloc(C_split.nbytes())); /* int ldb = */ ldb,
copies.push_back(C_split); /* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
bool mn_aligned = M % bm == 0 && N % bn == 0; /* std::vector<array>& copies = */ copies,
bool k_aligned = K % bk == 0; /* float alpha = */ alpha_,
/* float beta = */ beta_);
std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_splitk_kernel(
d,
kname.str(),
a,
C_split,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn,
mn_aligned,
k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
split_k_partitions,
split_k_partition_stride,
split_k_partition_size,
gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split) + "_axbpy";
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true);
compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
compute_encoder.set_input_array(c, 5);
compute_encoder.set_bytes(ldc, 6);
compute_encoder.set_bytes(fdc, 7);
compute_encoder.set_bytes(alpha_, 8);
compute_encoder.set_bytes(beta_, 9);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
return;
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////