Comments and format

This commit is contained in:
Jagrit Digani 2025-06-11 09:35:52 -07:00
parent dd5e833068
commit 4b02d3e738

View File

@@ -201,10 +201,15 @@ void steel_matmul_regular_axpby(
// Prepare kernel name // Prepare kernel name
std::ostringstream kname; std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" // clang-format off
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk kname << "steel_gemm_fused_"
<< "_wm" << wm << "_wn" << wn; << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn; // clang-format on
std::string base_name = kname.str(); std::string base_name = kname.str();
@@ -237,18 +242,18 @@ void steel_matmul_regular_axpby(
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel( auto kernel = get_steel_gemm_fused_kernel(
d, /* metal::Device& d = */ d,
base_name, /* const std::string& kernel_name = */ base_name,
hash_name, /* const std::string& hash_name = */ hash_name,
func_consts, /* const metal::MTLFCList& func_consts = */ func_consts,
out, /* const array& out = */ out,
transpose_a, /* bool transpose_a = */ transpose_a,
transpose_b, /* bool transpose_b = */ transpose_b,
bm, /* int bm = */ bm,
bn, /* int bn = */ bn,
bk, /* int bk = */ bk,
wm, /* int wm = */ wm,
wn); /* int wn = */ wn);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -722,9 +727,12 @@ void gemv_axbpy(
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f); const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" // clang-format off
<< tm << "_tn" << tn; kname << "_bm" << bm << "_bn" << bn
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby; << "_sm" << sm << "_sn" << sn
<< "_tm" << tm << "_tn" << tn
<< "_nc" << !contiguous_kernel
<< "_axpby" << do_axpby; // clang-format on
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);