Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition 
* Update tests and benchmarks as needed
This commit is contained in:
Jagrit Digani
2024-01-17 12:42:39 -08:00
committed by GitHub
parent 556cdf0e06
commit 78102a47ad
30 changed files with 2361 additions and 646 deletions

View File

@@ -257,6 +257,13 @@ def linear(w, b, x):
mx.eval(ys)
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
@@ -397,7 +404,10 @@ if __name__ == "__main__":
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))