diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 0ee189e47..1c9ee5a6c 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -209,8 +209,8 @@ void steel_matmul_regular_axpby( std::string base_name = kname.str(); const bool has_batch = (batch_shape.size() > 1); - const bool use_out_source = false; - const bool do_axpby = false; + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; const bool align_K = (K % bk) == 0; @@ -294,6 +294,21 @@ void steel_matmul_regular_axpby( compute_encoder.set_vector_bytes(batch_shape, 6); compute_encoder.set_vector_bytes(batch_strides, 7); + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies