Option to JIT steel gemm / conv (#1139)

Author: Awni Hannun
Date: 2024-05-23 18:07:34 -07:00
Committed by: GitHub
Parent: eab2685c67
Commit: 7e26fd8032
31 changed files with 2504 additions and 1540 deletions

@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/steel/gemm/params.h"
 #include "mlx/backend/metal/matmul.h"
@@ -336,7 +337,19 @@ void steel_matmul_conv_groups(
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  auto kernel = get_steel_gemm_fused_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
   compute_encoder->setComputePipelineState(kernel);
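
Note: get_steel_gemm_fused_kernel comes from the newly included mlx/backend/metal/kernels.h and hides whether the pipeline state is fetched from the prebuilt metallib or JIT-compiled from source. A sketch of the declaration implied by this call site; the parameter names, and the MTLFCList type for func_consts, are assumptions based on the d.get_kernel overload being replaced:

MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn);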
@@ -458,17 +471,31 @@ void steel_matmul(
     C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
     copies.push_back(C_split);
+    bool mn_aligned = M % bm == 0 && N % bn == 0;
+    bool k_aligned = K % bk == 0;
     std::ostringstream kname;
     kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-          << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
-          << ((K % bk == 0) ? "t" : "n") << "aligned";
+          << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
+          << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
     // Encode and dispatch gemm kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = get_steel_gemm_splitk_kernel(
+        d,
+        kname.str(),
+        a,
+        C_split,
+        transpose_a,
+        transpose_b,
+        bm,
+        bn,
+        bk,
+        wm,
+        wn,
+        mn_aligned,
+        k_aligned);
     compute_encoder->setComputePipelineState(kernel);
     int tn = (N + bn - 1) / bn;
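
For reference, a minimal standalone sketch (plain C++, independent of MLX; tile and problem sizes are example values) of the split-k kernel name the stream above produces:

#include <iostream>
#include <sstream>

int main() {
  // Example values: float32 inputs and accumulator, no transposes.
  int bm = 32, bn = 32, bk = 16, wm = 2, wn = 2;
  int M = 100, N = 256, K = 64;
  bool mn_aligned = M % bm == 0 && N % bn == 0; // false: 100 % 32 != 0
  bool k_aligned = K % bk == 0; // true: 64 % 16 == 0
  std::ostringstream kname;
  kname << "steel_gemm_splitk_nn_float32_float32_bm" << bm << "_bn" << bn
        << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
        << (mn_aligned ? "t" : "n") << "aligned" << "_K_"
        << (k_aligned ? "t" : "n") << "aligned";
  // Prints:
  // steel_gemm_splitk_nn_float32_float32_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned
  std::cout << kname.str() << std::endl;
}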
@@ -504,10 +531,11 @@ void steel_matmul(
         static_cast<const MTL::Resource*>(C_split.buffer().ptr());
     const class MTL::Resource* const resources[1] = {c_split_buf};
     compute_encoder->memoryBarrier(resources, 1);
-    auto kernel = d.get_kernel(
-        "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
-        type_to_name(C_split));
+    auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+        type_to_name(C_split);
+    auto kernel = get_steel_gemm_splitk_accum_kernel(
+        d, kernel_name, C_split, out, false);
     compute_encoder->setComputePipelineState(kernel);
     // Set the arguments for the kernel
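
The trailing bool in get_steel_gemm_splitk_accum_kernel appears to select between the plain accumulate kernel (false here) and the axpby variant used in the AddMM path further down (true). A sketch of the declaration implied by the two call sites; parameter names are guesses:

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy);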
@@ -587,7 +615,19 @@ void steel_matmul(
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  auto kernel = get_steel_gemm_fused_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
   compute_encoder->setComputePipelineState(kernel);
@@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
     copies.push_back(C_split);
+    bool mn_aligned = M % bm == 0 && N % bn == 0;
+    bool k_aligned = K % bk == 0;
     std::ostringstream kname;
     kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
           << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
           << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-          << "_wm" << wm << "_wn" << wn << "_MN_"
-          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
-          << ((K % bk == 0) ? "t" : "n") << "aligned";
+          << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
+          << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
     // Encode and dispatch gemm kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = get_steel_gemm_splitk_kernel(
+        d,
+        kname.str(),
+        a,
+        C_split,
+        transpose_a,
+        transpose_b,
+        bm,
+        bn,
+        bk,
+        wm,
+        wn,
+        mn_aligned,
+        k_aligned);
     compute_encoder->setComputePipelineState(kernel);
     int tn = (N + bn - 1) / bn;
@@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Do accum kernel
     {
-      auto kernel = d.get_kernel(
-          "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
-          type_to_name(C_split) + "_axpby");
+      auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+          type_to_name(C_split) + "_axbpy";
+      auto kernel = get_steel_gemm_splitk_accum_kernel(
+          d, kernel_name, C_split, out, true);
       compute_encoder->setComputePipelineState(kernel);
     // Set the arguments for the kernel
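
For context, the axpby accumulate step this kernel name selects computes, per output element, roughly alpha * (sum of the K-split partials) + beta * c. A conceptual C++ sketch of the semantics, not the Metal kernel itself:

// One output element of the split-k accumulate-with-axpby step.
float splitk_accum_axpby(
    const float* partials, // num_splits partial results for this element
    int num_splits,
    float c, // corresponding element of the AddMM input C
    float alpha,
    float beta) {
  float acc = 0.0f;
  for (int i = 0; i < num_splits; ++i) {
    acc += partials[i]; // sum the per-split partial products
  }
  return alpha * acc + beta * c; // fuse the AddMM scaling into the reduction
}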
@@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  auto kernel = get_steel_gemm_fused_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
   compute_encoder->setComputePipelineState(kernel);
@@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Determine dispatch kernel
   int bm = block_size_, bn = block_size_, bk = 16;
   int wm = 2, wn = 2;
+  bool mn_aligned = M % bm == 0 && N % bn == 0;
+  bool k_aligned = K % bk == 0;
   // Prepare kernel name
   std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
@@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
         << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
         << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
         << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-        << "_wm" << wm << "_wn" << wn << "_MN_"
-        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
-        << ((K % bk == 0) ? "t" : "n") << "aligned";
+        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
+        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = get_steel_gemm_masked_kernel(
+      d,
+      kname.str(),
+      out,
+      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
+      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      mn_aligned,
+      k_aligned);
   compute_encoder->setComputePipelineState(kernel);
   // Use problem size to determine threadblock swizzle
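
The masked getter threads the optional mask arrays through so the kernel can be specialized on their dtypes (or their absence). The declaration implied by this call site; parameter names are assumptions:

MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned);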
@@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  auto kernel = get_steel_gemm_fused_kernel(
+      d,
+      base_name,
+      hash_name,
+      func_consts,
+      out,
+      transpose_a,
+      transpose_b,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn);
   compute_encoder->setComputePipelineState(kernel);
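
The pattern throughout the commit is uniform: each direct d.get_kernel lookup becomes a getter that can serve the pipeline state either from the prebuilt metallib or from JIT-compiled source. A plausible sketch of the non-JIT side of one getter, assuming a build-time switch; the MLX_METAL_JIT flag name and the forwarding body are assumptions, with the body mirroring the removed call sites:

#ifndef MLX_METAL_JIT // assumed build-time switch
// Non-JIT build: forward to the same precompiled-metallib lookup the old
// call sites performed directly; the extra shape/tile arguments would only
// be needed by a JIT build to instantiate the kernel source.
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* out */,
    bool /* transpose_a */,
    bool /* transpose_b */,
    int /* bm */,
    int /* bn */,
    int /* bk */,
    int /* wm */,
    int /* wn */) {
  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
#endif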