Refactor splitk step 1

Jagrit Digani 2025-06-11 07:29:36 -07:00
parent f828c5b5ae
commit d192587cdf
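
This step lifts the split-k dispatch path out of steel_matmul into a standalone steel_gemm_splitk helper so later steps can call it directly. Conceptually, split-k divides the reduction dimension K across several partial GEMMs that each write an M x N slab of a temporary C_split buffer; a second accumulation kernel then sums the slabs into the output. A minimal CPU sketch of those semantics (illustrative only, not the Metal kernels; it assumes K divides evenly across partitions, which the real kernels relax via the K_aligned specialization):

#include <vector>

// Reference semantics of split-k GEMM: each partition computes a partial
// GEMM over its slice of K into its own M x N slab of c_split, and a second
// pass sums the slabs into the output. Row-major, no transposes.
void splitk_reference(
    const std::vector<float>& a, // M x K
    const std::vector<float>& b, // K x N
    std::vector<float>& out,     // M x N
    int M, int N, int K, int partitions) {
  std::vector<float> c_split(partitions * M * N, 0.0f);
  int part_k = K / partitions; // assumes K divides evenly across partitions
  for (int p = 0; p < partitions; ++p) {
    // split_k_partition_stride = M * N: each partition owns one slab
    float* slab = c_split.data() + p * M * N;
    for (int m = 0; m < M; ++m)
      for (int k = p * part_k; k < (p + 1) * part_k; ++k)
        for (int n = 0; n < N; ++n)
          slab[m * N + n] += a[m * K + k] * b[k * N + n];
  }
  // Accumulation step: sum the per-partition slabs (the accum kernel's job).
  out.assign(M * N, 0.0f);
  for (int p = 0; p < partitions; ++p)
    for (int i = 0; i < M * N; ++i)
      out[i] += c_split[p * M * N + i];
}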


@@ -295,6 +295,127 @@ void steel_matmul_regular(
d.add_temporaries(std::move(copies), s.index);
}
void steel_gemm_splitk(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
using namespace mlx::steel;
int _tm = M / 16;
int _tn = N / 16;
int _tk = K / 16;
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
int bk = 16;
int wm = 2, wn = 2;
int split_k_partitions = _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
int split_k_partition_stride = M * N;
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;
std::ostringstream kname;
// clang-format off
kname << "steel_gemm_splitk_"
<< (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(C_split)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn
<< "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
<< "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
// Encode and dispatch gemm kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_splitk_kernel(
/* metal::Device& d = */ d,
/* const std::string& kernel_name = */ kname.str(),
/* const array& in = */ a,
/* const array& out = */ C_split,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* int bm = */ bm,
/* int bn = */ bn,
/* int bk = */ bk,
/* int wm = */ wm,
/* int wn = */ wn,
/* bool mn_aligned = */ mn_aligned,
/* bool k_aligned = */ k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldc = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int split_k_partitions = */ split_k_partitions,
/* const int split_k_partition_stride = */ split_k_partition_stride,
/* const int split_k_partition_size = */ split_k_partition_size,
/* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);
auto kernel =
get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false);
compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
}
void steel_matmul(
const Stream& s,
metal::Device& d,
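
For reference, the tile and partition heuristics in steel_gemm_splitk above, worked through on a concrete shape (hypothetical numbers, not from the commit):

// Worked example for M = N = 32, K = 4096:
//   _tk = 4096 / 16 = 256              -> split_k_partitions = 16 (since _tk >= 64)
//   bm = bn = 16 (M, N < 40), bk = 16, wm = wn = 2
//   gemm_k_iterations        = (4096 / 16) / 16 = 16
//   split_k_partition_size   = 16 * 16 = 256 (K elements per partition)
//   split_k_partition_stride = 32 * 32 = 1024 (one M x N slab of C_split)
//   tm = tn = (32 + 15) / 16 = 2       -> grid (2, 2, 16), group (32, 2, 2)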
@@ -346,99 +467,21 @@ void steel_matmul(
int _tk = K / 16;
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
int bk = 16;
int wm = 2, wn = 2;
int split_k_partitions =
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
int split_k_partition_stride = M * N;
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc(C_split.nbytes()));
copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;
std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_splitk_kernel(
d,
kname.str(),
a,
C_split,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn,
mn_aligned,
k_aligned);
compute_encoder.set_compute_pipeline_state(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldc = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int split_k_partitions = */ split_k_partitions,
/* const int split_k_partition_stride = */ split_k_partition_stride,
/* const int split_k_partition_size = */ split_k_partition_size,
/* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(C_split, 2);
compute_encoder.set_bytes(params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, false);
compute_encoder.set_compute_pipeline_state(kernel);
// Set the arguments for the kernel
compute_encoder.set_input_array(C_split, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(split_k_partitions, 2);
compute_encoder.set_bytes(split_k_partition_stride, 3);
compute_encoder.set_bytes(N, 4);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
return;
return steel_gemm_splitk(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies);
}
/////////////////////////////////////////////////////////////////////////////
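
To make the name mangling above concrete, here is what the two kernel names come out to for a float16 input accumulating into a float32 C_split, with the 16 x 16 tiles from the worked example (this assumes type_to_name returns "float16" / "float32" for those dtypes):

// GEMM kernel (transpose_a = transpose_b = false, aligned M, N and K):
//   steel_gemm_splitk_nn_float16_float32_bm16_bn16_bk16_wm2_wn2_MN_taligned_K_taligned
// Accumulation kernel (float16 out, float32 partials):
//   steel_gemm_splitk_accum_float16_float32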
@@ -447,27 +490,27 @@ void steel_matmul(
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
steel_matmul_regular(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
lda,
ldb,
N,
transpose_a,
transpose_b,
std::move(batch_shape),
std::move(batch_strides),
A_batch_stride.back(),
B_batch_stride.back(),
matrix_stride_out,
copies);
return steel_matmul_regular(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ N,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* Shape batch_shape = */ std::move(batch_shape),
/* Strides batch_strides = */ std::move(batch_strides),
/* int64_t A_batch_stride = */ A_batch_stride.back(),
/* int64_t B_batch_stride = */ B_batch_stride.back(),
/* int64_t matrix_stride_out = */ matrix_stride_out,
/* std::vector<array>& copies = */ copies);
}
template <bool CHECK_AB = true>
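
One pattern worth noting in the rewritten call sites: `return f(...);` where both caller and callee return void is legal C++ and replaces the old call-then-`return;` pair, while the /* name = */ comments keep long positional argument lists honest. A minimal illustration with hypothetical names:

void run_fast();
void run_general();

// `return` from a void function is valid when the callee is also void;
// it reads as "dispatch and we are done".
void dispatch(bool fast_path) {
  if (fast_path) {
    return run_fast(); // equivalent to: run_fast(); return;
  }
  run_general();
}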
@@ -715,9 +758,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
/* Shape batch_shape = */ std::move(batch_shape),
/* Strides A_batch_stride = */ std::move(A_batch_stride),
/* Strides B_batch_stride = */ std::move(B_batch_stride));
}
/////////////////////////////////////////////////////////////////////////////
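
The change in this hunk and the next swaps by-value copies for moves. Shape and Strides are vector-like, so passing a named lvalue by value copies its heap buffer; std::move transfers ownership instead, which is safe here because the locals are not read afterwards. A small sketch with a hypothetical consume function:

#include <utility>
#include <vector>

void consume(std::vector<int> shape) {} // by-value parameter, as in these call sites

int main() {
  std::vector<int> batch_shape{2, 3, 4};
  consume(std::move(batch_shape)); // buffer moved; batch_shape left valid but unspecified
}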
@@ -738,9 +781,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
/* Shape batch_shape = */ std::move(batch_shape),
/* Strides A_batch_stride = */ std::move(A_batch_stride),
/* Strides B_batch_stride = */ std::move(B_batch_stride));
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
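
Finally, a note on the accumulation launch inside steel_gemm_splitk: it dispatches one thread per output element (grid (N, M, 1) via dispatch_threads), unlike the threadgroup grid used for the main GEMM. A plain-C++ model of what each accumulation thread computes (a sketch, not the Metal source):

// Thread (x, y) sums C_split[p][y][x] over all partitions into out[y][x].
float accum_one_element(
    const float* c_split,
    int x, int y,           // thread position: x in [0, N), y in [0, M)
    int N,                  // ldc of each slab
    int partitions,         // split_k_partitions
    int partition_stride) { // split_k_partition_stride = M * N
  float acc = 0.0f;
  for (int p = 0; p < partitions; ++p) {
    acc += c_split[p * partition_stride + y * N + x];
  }
  return acc;
}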