mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Refactor split k axpby
This commit is contained in:
parent
9fd8eb357c
commit
c2f1c2a338
@ -295,11 +295,13 @@ void steel_matmul_regular(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
void steel_gemm_splitk(
|
template <bool CHECK_AB = true>
|
||||||
|
void steel_gemm_splitk_axpby(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
|
const array& c,
|
||||||
array& out,
|
array& out,
|
||||||
int M,
|
int M,
|
||||||
int N,
|
int N,
|
||||||
@ -309,7 +311,9 @@ void steel_gemm_splitk(
|
|||||||
int ldb,
|
int ldb,
|
||||||
bool transpose_a,
|
bool transpose_a,
|
||||||
bool transpose_b,
|
bool transpose_b,
|
||||||
std::vector<array>& copies) {
|
std::vector<array>& copies,
|
||||||
|
float alpha = 1.0f,
|
||||||
|
float beta = 0.0f) {
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
int _tm = M / 16;
|
int _tm = M / 16;
|
||||||
@ -393,11 +397,21 @@ void steel_gemm_splitk(
|
|||||||
|
|
||||||
// Do accum kernel
|
// Do accum kernel
|
||||||
{
|
{
|
||||||
|
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||||
|
|
||||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||||
type_to_name(C_split);
|
type_to_name(C_split);
|
||||||
|
|
||||||
auto kernel =
|
if (do_axpby) {
|
||||||
get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false);
|
kernel_name = kernel_name + "_axbpy";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
|
/* const std::string& kernel_name = */ kernel_name,
|
||||||
|
/* const array& in = */ C_split,
|
||||||
|
/* const array& out = */ out,
|
||||||
|
/* bool axbpy = */ do_axpby);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Set the arguments for the kernel
|
// Set the arguments for the kernel
|
||||||
@ -407,6 +421,17 @@ void steel_gemm_splitk(
|
|||||||
compute_encoder.set_bytes(split_k_partition_stride, 3);
|
compute_encoder.set_bytes(split_k_partition_stride, 3);
|
||||||
compute_encoder.set_bytes(N, 4);
|
compute_encoder.set_bytes(N, 4);
|
||||||
|
|
||||||
|
if (do_axpby) {
|
||||||
|
int ldc = c.strides()[c.ndim() - 2];
|
||||||
|
int fdc = c.strides()[c.ndim() - 1];
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(c, 5);
|
||||||
|
compute_encoder.set_bytes(ldc, 6);
|
||||||
|
compute_encoder.set_bytes(fdc, 7);
|
||||||
|
compute_encoder.set_bytes(alpha, 8);
|
||||||
|
compute_encoder.set_bytes(beta, 9);
|
||||||
|
}
|
||||||
|
|
||||||
// Launch enough thread groups for each output
|
// Launch enough thread groups for each output
|
||||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||||
auto group_dims = get_block_dims(N, M, 1);
|
auto group_dims = get_block_dims(N, M, 1);
|
||||||
@ -416,6 +441,39 @@ void steel_gemm_splitk(
|
|||||||
d.add_temporaries(std::move(copies), s.index);
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void steel_gemm_splitk(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
return steel_gemm_splitk_axpby<false>(
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
|
/* const array& a = */ a,
|
||||||
|
/* const array& b = */ b,
|
||||||
|
/* const array& c = */ b,
|
||||||
|
/* array& out = */ out,
|
||||||
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
|
/* int batch_size_out = */ batch_size_out,
|
||||||
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
|
/* bool transpose_a = */ transpose_a,
|
||||||
|
/* bool transpose_b = */ transpose_b,
|
||||||
|
/* std::vector<array>& copies = */ copies);
|
||||||
|
}
|
||||||
|
|
||||||
void steel_matmul(
|
void steel_matmul(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@ -899,106 +957,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
int _tk = K / 16;
|
int _tk = K / 16;
|
||||||
|
|
||||||
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
||||||
int bm = M < 40 ? 16 : 32;
|
return steel_gemm_splitk_axpby(
|
||||||
int bn = N < 40 ? 16 : 32;
|
/* const Stream& s = */ s,
|
||||||
int bk = 16;
|
/* metal::Device& d = */ d,
|
||||||
int wm = 2, wn = 2;
|
/* const array& a = */ a,
|
||||||
|
/* const array& b = */ b,
|
||||||
int split_k_partitions =
|
/* const array& c = */ c,
|
||||||
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
|
/* array& out = */ out,
|
||||||
int split_k_partition_stride = M * N;
|
/* int M = */ M,
|
||||||
int gemm_k_iterations = (K / bk) / split_k_partitions;
|
/* int N = */ N,
|
||||||
int split_k_partition_size = gemm_k_iterations * bk;
|
/* int K = */ K,
|
||||||
|
/* int batch_size_out = */ batch_size_out,
|
||||||
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
|
/* int lda = */ lda,
|
||||||
C_split.set_data(allocator::malloc(C_split.nbytes()));
|
/* int ldb = */ ldb,
|
||||||
copies.push_back(C_split);
|
/* bool transpose_a = */ transpose_a,
|
||||||
|
/* bool transpose_b = */ transpose_b,
|
||||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
/* std::vector<array>& copies = */ copies,
|
||||||
bool k_aligned = K % bk == 0;
|
/* float alpha = */ alpha_,
|
||||||
|
/* float beta = */ beta_);
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
|
||||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
|
||||||
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
|
||||||
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
|
||||||
|
|
||||||
// Encode and dispatch gemm kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = get_steel_gemm_splitk_kernel(
|
|
||||||
d,
|
|
||||||
kname.str(),
|
|
||||||
a,
|
|
||||||
C_split,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
bm,
|
|
||||||
bn,
|
|
||||||
bk,
|
|
||||||
wm,
|
|
||||||
wn,
|
|
||||||
mn_aligned,
|
|
||||||
k_aligned);
|
|
||||||
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
int tn = (N + bn - 1) / bn;
|
|
||||||
int tm = (M + bm - 1) / bm;
|
|
||||||
|
|
||||||
GEMMSpiltKParams params{
|
|
||||||
M,
|
|
||||||
N,
|
|
||||||
K,
|
|
||||||
lda,
|
|
||||||
ldb,
|
|
||||||
N,
|
|
||||||
tn,
|
|
||||||
tm,
|
|
||||||
split_k_partitions,
|
|
||||||
split_k_partition_stride,
|
|
||||||
split_k_partition_size,
|
|
||||||
gemm_k_iterations};
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
|
||||||
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(a, 0);
|
|
||||||
compute_encoder.set_input_array(b, 1);
|
|
||||||
compute_encoder.set_output_array(C_split, 2);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(params, 3);
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
|
|
||||||
// Do accum kernel
|
|
||||||
{
|
|
||||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
|
||||||
type_to_name(C_split) + "_axbpy";
|
|
||||||
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
|
||||||
d, kernel_name, C_split, out, true);
|
|
||||||
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
// Set the arguments for the kernel
|
|
||||||
compute_encoder.set_input_array(C_split, 0);
|
|
||||||
compute_encoder.set_output_array(out, 1);
|
|
||||||
compute_encoder.set_bytes(split_k_partitions, 2);
|
|
||||||
compute_encoder.set_bytes(split_k_partition_stride, 3);
|
|
||||||
compute_encoder.set_bytes(N, 4);
|
|
||||||
compute_encoder.set_input_array(c, 5);
|
|
||||||
compute_encoder.set_bytes(ldc, 6);
|
|
||||||
compute_encoder.set_bytes(fdc, 7);
|
|
||||||
compute_encoder.set_bytes(alpha_, 8);
|
|
||||||
compute_encoder.set_bytes(beta_, 9);
|
|
||||||
|
|
||||||
// Launch enough thread groups for each output
|
|
||||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
|
||||||
auto group_dims = get_block_dims(N, M, 1);
|
|
||||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
Loading…
Reference in New Issue
Block a user