Update GEMM (#424)

* Organize and collect Metal subroutine templates and elements in `metal/kernels/steel/`
* Update GEMM elements for better performance
* Add split-K specialization for GEMM
* Add `addmm` primitive, op, and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
This commit is contained in:
Jagrit Digani
2024-01-17 12:42:39 -08:00
committed by GitHub
parent 556cdf0e06
commit 78102a47ad
30 changed files with 2361 additions and 646 deletions

View File

@@ -70,7 +70,7 @@ void explicit_gemm_conv_1D_gpu(
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
return steel_matmul(
s,
d,
/*a = */ in_strided,
@@ -262,7 +262,7 @@ void explicit_gemm_conv_2D_gpu(
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
return steel_matmul(
s,
d,
/*a = */ in_strided,
@@ -411,7 +411,7 @@ void winograd_conv_2D_gpu(
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
mlx_matmul(
steel_matmul(
s,
d,
/*a = */ inp_wg,