Add axpby routing to steel_matmul_regular

This commit is contained in:
Jagrit Digani 2025-06-11 08:54:42 -07:00
parent 13eccfa887
commit 9dbaa35be3

View File

@ -209,8 +209,8 @@ void steel_matmul_regular_axpby(
std::string base_name = kname.str(); std::string base_name = kname.str();
const bool has_batch = (batch_shape.size() > 1); const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = false; const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
const bool do_axpby = false; const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
const bool align_M = (M % bm) == 0; const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0; const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0; const bool align_K = (K % bk) == 0;
@ -294,6 +294,21 @@ void steel_matmul_regular_axpby(
compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_shape, 6);
compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_vector_bytes(batch_strides, 7);
if (use_out_source) {
int ldc = c.strides()[c.ndim() - 2];
int fdc = c.strides()[c.ndim() - 1];
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int64_t batch_stride_c = */ C_batch_stride,
/* const float alpha = */ alpha,
/* const float beta = */ beta};
compute_encoder.set_input_array(c, 2);
compute_encoder.set_bytes(params, 5);
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
// Record copies // Record copies