mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Comments and format
This commit is contained in:
parent
dd5e833068
commit
4b02d3e738
@ -201,10 +201,15 @@ void steel_matmul_regular_axpby(
|
|||||||
|
|
||||||
// Prepare kernel name
|
// Prepare kernel name
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
|
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
// clang-format off
|
||||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
kname << "steel_gemm_fused_"
|
||||||
<< "_wm" << wm << "_wn" << wn;
|
<< (transpose_a ? 't' : 'n')
|
||||||
|
<< (transpose_b ? 't' : 'n')
|
||||||
|
<< "_" << type_to_name(a)
|
||||||
|
<< "_" << type_to_name(out)
|
||||||
|
<< "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
|
<< "_wm" << wm << "_wn" << wn; // clang-format on
|
||||||
|
|
||||||
std::string base_name = kname.str();
|
std::string base_name = kname.str();
|
||||||
|
|
||||||
@ -237,18 +242,18 @@ void steel_matmul_regular_axpby(
|
|||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = get_steel_gemm_fused_kernel(
|
auto kernel = get_steel_gemm_fused_kernel(
|
||||||
d,
|
/* metal::Device& d = */ d,
|
||||||
base_name,
|
/* const std::string& kernel_name = */ base_name,
|
||||||
hash_name,
|
/* const std::string& hash_name = */ hash_name,
|
||||||
func_consts,
|
/* const metal::MTLFCList& func_consts = */ func_consts,
|
||||||
out,
|
/* const array& out = */ out,
|
||||||
transpose_a,
|
/* bool transpose_a = */ transpose_a,
|
||||||
transpose_b,
|
/* bool transpose_b = */ transpose_b,
|
||||||
bm,
|
/* int bm = */ bm,
|
||||||
bn,
|
/* int bn = */ bn,
|
||||||
bk,
|
/* int bk = */ bk,
|
||||||
wm,
|
/* int wm = */ wm,
|
||||||
wn);
|
/* int wn = */ wn);
|
||||||
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
@ -722,9 +727,12 @@ void gemv_axbpy(
|
|||||||
|
|
||||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||||
|
|
||||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
// clang-format off
|
||||||
<< tm << "_tn" << tn;
|
kname << "_bm" << bm << "_bn" << bn
|
||||||
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
|
<< "_sm" << sm << "_sn" << sn
|
||||||
|
<< "_tm" << tm << "_tn" << tn
|
||||||
|
<< "_nc" << !contiguous_kernel
|
||||||
|
<< "_axpby" << do_axpby; // clang-format on
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
Loading…
Reference in New Issue
Block a user