mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Comments and format
This commit is contained in:
parent
dd5e833068
commit
4b02d3e738
@ -201,10 +201,15 @@ void steel_matmul_regular_axpby(
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn;
|
||||
|
||||
// clang-format off
|
||||
kname << "steel_gemm_fused_"
|
||||
<< (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n')
|
||||
<< "_" << type_to_name(a)
|
||||
<< "_" << type_to_name(out)
|
||||
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||
|
||||
std::string base_name = kname.str();
|
||||
|
||||
@ -237,18 +242,18 @@ void steel_matmul_regular_axpby(
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_fused_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
func_consts,
|
||||
out,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
/* metal::Device& d = */ d,
|
||||
/* const std::string& kernel_name = */ base_name,
|
||||
/* const std::string& hash_name = */ hash_name,
|
||||
/* const metal::MTLFCList& func_consts = */ func_consts,
|
||||
/* const array& out = */ out,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* int bm = */ bm,
|
||||
/* int bn = */ bn,
|
||||
/* int bk = */ bk,
|
||||
/* int wm = */ wm,
|
||||
/* int wn = */ wn);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@ -722,9 +727,12 @@ void gemv_axbpy(
|
||||
|
||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
|
||||
// clang-format off
|
||||
kname << "_bm" << bm << "_bn" << bn
|
||||
<< "_sm" << sm << "_sn" << sn
|
||||
<< "_tm" << tm << "_tn" << tn
|
||||
<< "_nc" << !contiguous_kernel
|
||||
<< "_axpby" << do_axpby; // clang-format on
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
Loading…
Reference in New Issue
Block a user