Add load_safe to the general conv loaders (#2258)

This commit is contained in:
Angelos Katharopoulos
2025-06-10 20:58:16 -07:00
committed by GitHub
parent 095163b8d1
commit 8590c0941e
8 changed files with 302 additions and 22 deletions

View File

@@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
bool align_C = conv_params.C % bk == 0;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
@@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
std::string kname;
kname.reserve(64);
concatenate(
kname,
"implicit_gemm_conv_2d_general_",
type_to_name(out),
"_bm",
bm,
"_bn",
bn,
"_bk",
bk,
"_wm",
wm,
"_wn",
wn);
std::string hash_name;
hash_name.reserve(64);
concatenate(hash_name, kname, "_alC_", align_C);
metal::MTLFCList func_consts = {
{&align_C, MTL::DataType::DataTypeBool, 200},
};
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
auto kernel = get_steel_conv_general_kernel(
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
@@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
bool out_large =
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
@@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}