mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add axpby routing to steel_matmul_regular
This commit is contained in:
parent
13eccfa887
commit
9dbaa35be3
@ -209,8 +209,8 @@ void steel_matmul_regular_axpby(
|
|||||||
std::string base_name = kname.str();
|
std::string base_name = kname.str();
|
||||||
|
|
||||||
const bool has_batch = (batch_shape.size() > 1);
|
const bool has_batch = (batch_shape.size() > 1);
|
||||||
const bool use_out_source = false;
|
const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
|
||||||
const bool do_axpby = false;
|
const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
|
||||||
const bool align_M = (M % bm) == 0;
|
const bool align_M = (M % bm) == 0;
|
||||||
const bool align_N = (N % bn) == 0;
|
const bool align_N = (N % bn) == 0;
|
||||||
const bool align_K = (K % bk) == 0;
|
const bool align_K = (K % bk) == 0;
|
||||||
@ -294,6 +294,21 @@ void steel_matmul_regular_axpby(
|
|||||||
compute_encoder.set_vector_bytes(batch_shape, 6);
|
compute_encoder.set_vector_bytes(batch_shape, 6);
|
||||||
compute_encoder.set_vector_bytes(batch_strides, 7);
|
compute_encoder.set_vector_bytes(batch_strides, 7);
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
int ldc = c.strides()[c.ndim() - 2];
|
||||||
|
int fdc = c.strides()[c.ndim() - 1];
|
||||||
|
|
||||||
|
GEMMAddMMParams params{
|
||||||
|
/* const int ldc = */ ldc,
|
||||||
|
/* const int fdc = */ fdc,
|
||||||
|
/* const int64_t batch_stride_c = */ C_batch_stride,
|
||||||
|
/* const float alpha = */ alpha,
|
||||||
|
/* const float beta = */ beta};
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(c, 2);
|
||||||
|
compute_encoder.set_bytes(params, 5);
|
||||||
|
}
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
// Record copies
|
// Record copies
|
||||||
|
Loading…
Reference in New Issue
Block a user