Comments and format

This commit is contained in:
Jagrit Digani 2025-06-11 09:35:52 -07:00
parent dd5e833068
commit 4b02d3e738

View File

@ -201,10 +201,15 @@ void steel_matmul_regular_axpby(
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
// clang-format off
kname << "steel_gemm_fused_"
<< (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a)
<< "_" << type_to_name(out)
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn; // clang-format on
std::string base_name = kname.str();
@ -237,18 +242,18 @@ void steel_matmul_regular_axpby(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
/* metal::Device& d = */ d,
/* const std::string& kernel_name = */ base_name,
/* const std::string& hash_name = */ hash_name,
/* const metal::MTLFCList& func_consts = */ func_consts,
/* const array& out = */ out,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* int bm = */ bm,
/* int bn = */ bn,
/* int bk = */ bk,
/* int wm = */ wm,
/* int wn = */ wn);
compute_encoder.set_compute_pipeline_state(kernel);
@ -722,9 +727,12 @@ void gemv_axbpy(
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
// clang-format off
kname << "_bm" << bm << "_bn" << bn
<< "_sm" << sm << "_sn" << sn
<< "_tm" << tm << "_tn" << tn
<< "_nc" << !contiguous_kernel
<< "_axpby" << do_axpby; // clang-format on
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);