mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Option to JIT steel gemm / conv (#1139)
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
@@ -336,7 +337,19 @@ void steel_matmul_conv_groups(
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -458,17 +471,31 @@ void steel_matmul(
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
bool k_aligned = K % bk == 0;
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = get_steel_gemm_splitk_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
a,
|
||||
C_split,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int tn = (N + bn - 1) / bn;
|
||||
@@ -504,10 +531,11 @@ void steel_matmul(
|
||||
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
|
||||
const class MTL::Resource* const resources[1] = {c_split_buf};
|
||||
compute_encoder->memoryBarrier(resources, 1);
|
||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split);
|
||||
|
||||
auto kernel = d.get_kernel(
|
||||
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split));
|
||||
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||
d, kernel_name, C_split, out, false);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
@@ -587,7 +615,19 @@ void steel_matmul(
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||
copies.push_back(C_split);
|
||||
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
bool k_aligned = K % bk == 0;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch gemm kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = get_steel_gemm_splitk_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
a,
|
||||
C_split,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int tn = (N + bn - 1) / bn;
|
||||
@@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Do accum kernel
|
||||
{
|
||||
auto kernel = d.get_kernel(
|
||||
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split) + "_axpby");
|
||||
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||
type_to_name(C_split) + "_axbpy";
|
||||
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||
d, kernel_name, C_split, out, true);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set the arguments for the kernel
|
||||
@@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Determine dispatch kernel
|
||||
int bm = block_size_, bn = block_size_, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||
bool k_aligned = K % bk == 0;
|
||||
|
||||
// Prepare kernel name
|
||||
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
|
||||
@@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = get_steel_gemm_masked_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
out,
|
||||
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
|
||||
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
@@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user