mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Refactor split k axpby
This commit is contained in:
parent
9fd8eb357c
commit
c2f1c2a338
@ -295,11 +295,13 @@ void steel_matmul_regular(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void steel_gemm_splitk(
|
||||
template <bool CHECK_AB = true>
|
||||
void steel_gemm_splitk_axpby(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
@ -309,7 +311,9 @@ void steel_gemm_splitk(
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
std::vector<array>& copies,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
int _tm = M / 16;
|
||||
@ -393,11 +397,21 @@ void steel_gemm_splitk(
|
||||
|
||||
// Do accum kernel
|
||||
{
|
||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||
|
||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split);
|
||||
|
||||
auto kernel =
|
||||
get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false);
|
||||
if (do_axpby) {
|
||||
kernel_name = kernel_name + "_axbpy";
|
||||
}
|
||||
|
||||
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||
/* metal::Device& d = */ d,
|
||||
/* const std::string& kernel_name = */ kernel_name,
|
||||
/* const array& in = */ C_split,
|
||||
/* const array& out = */ out,
|
||||
/* bool axbpy = */ do_axpby);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
@ -407,6 +421,17 @@ void steel_gemm_splitk(
|
||||
compute_encoder.set_bytes(split_k_partition_stride, 3);
|
||||
compute_encoder.set_bytes(N, 4);
|
||||
|
||||
if (do_axpby) {
|
||||
int ldc = c.strides()[c.ndim() - 2];
|
||||
int fdc = c.strides()[c.ndim() - 1];
|
||||
|
||||
compute_encoder.set_input_array(c, 5);
|
||||
compute_encoder.set_bytes(ldc, 6);
|
||||
compute_encoder.set_bytes(fdc, 7);
|
||||
compute_encoder.set_bytes(alpha, 8);
|
||||
compute_encoder.set_bytes(beta, 9);
|
||||
}
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
@ -416,6 +441,39 @@ void steel_gemm_splitk(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
inline void steel_gemm_splitk(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
return steel_gemm_splitk_axpby<false>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies);
|
||||
}
|
||||
|
||||
void steel_matmul(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@ -899,106 +957,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int _tk = K / 16;
|
||||
|
||||
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
||||
int bm = M < 40 ? 16 : 32;
|
||||
int bn = N < 40 ? 16 : 32;
|
||||
int bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
int split_k_partitions =
|
||||
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
|
||||
int split_k_partition_stride = M * N;
|
||||
int gemm_k_iterations = (K / bk) / split_k_partitions;
|
||||
int split_k_partition_size = gemm_k_iterations * bk;
|
||||
|
||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
bool k_aligned = K % bk == 0;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_splitk_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
a,
|
||||
C_split,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
GEMMSpiltKParams params{
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
tn,
|
||||
tm,
|
||||
split_k_partitions,
|
||||
split_k_partition_stride,
|
||||
split_k_partition_size,
|
||||
gemm_k_iterations};
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
|
||||
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(C_split, 2);
|
||||
|
||||
compute_encoder.set_bytes(params, 3);
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
// Do accum kernel
|
||||
{
|
||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split) + "_axbpy";
|
||||
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||
d, kernel_name, C_split, out, true);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
compute_encoder.set_input_array(C_split, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(split_k_partitions, 2);
|
||||
compute_encoder.set_bytes(split_k_partition_stride, 3);
|
||||
compute_encoder.set_bytes(N, 4);
|
||||
compute_encoder.set_input_array(c, 5);
|
||||
compute_encoder.set_bytes(ldc, 6);
|
||||
compute_encoder.set_bytes(fdc, 7);
|
||||
compute_encoder.set_bytes(alpha_, 8);
|
||||
compute_encoder.set_bytes(beta_, 9);
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
return steel_gemm_splitk_axpby(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ c,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* float alpha = */ alpha_,
|
||||
/* float beta = */ beta_);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user