Move some kernels to get_template_definition (#1782)

This commit is contained in:
Angelos Katharopoulos
2025-01-21 08:59:44 -08:00
committed by GitHub
parent 90532b1f37
commit 1f4c127fb9
4 changed files with 79 additions and 247 deletions

View File

@@ -1,11 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -445,17 +442,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
<< get_template_definition(
lib_name,
"gemm",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
@@ -480,20 +477,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
<< get_template_definition(
lib_name,
"gemm_splitk",
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b,
mn_aligned,
k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -510,13 +507,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
<< get_template_definition(
lib_name,
axbpy ? "gemm_splitk_accum_axpby"
: "gemm_splitk_accum",
get_type_string(in.dtype()),
get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -547,21 +543,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
<< get_template_definition(
lib_name,
"block_masked_gemm",
get_type_string(out.dtype()),
out_mask_type,
op_mask_type,
bm,
bn,
bk,
wm,
wn,
transpose_a,
transpose_b,
mn_aligned,
k_aligned);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -590,20 +586,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
<< get_template_definition(
lib_name,
(transpose_mat) ? "gemv_t_masked" : "gemv_masked",
get_type_string(out.dtype()),
out_mask_type,
op_mask_type,
bm,
bn,
sm,
sn,
tm,
tn,
contiguous ? 0 : 1);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -624,17 +619,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
<< get_template_definition(
lib_name,
"implicit_gemm_conv_2d",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn,
n_channel_specialization,
small_filter);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
@@ -654,15 +649,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
<< get_template_definition(
lib_name,
"implicit_gemm_conv_2d_general",
get_type_string(out.dtype()),
bm,
bn,
bk,
wm,
wn);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);