mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
This commit is contained in:
@@ -70,7 +70,7 @@ void explicit_gemm_conv_1D_gpu(
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_strided,
|
||||
@@ -262,7 +262,7 @@ void explicit_gemm_conv_2D_gpu(
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
mlx_matmul(
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_strided,
|
||||
@@ -411,7 +411,7 @@ void winograd_conv_2D_gpu(
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
std::vector<array> empty_copies;
|
||||
mlx_matmul(
|
||||
steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ inp_wg,
|
||||
|
||||
Reference in New Issue
Block a user