Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
This commit is contained in:
Jagrit Digani
2024-01-17 12:42:39 -08:00
committed by GitHub
parent 556cdf0e06
commit 78102a47ad
30 changed files with 2361 additions and 646 deletions

View File

@@ -166,13 +166,13 @@ if __name__ == "__main__":
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes: