mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 14:58:13 +08:00
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
This commit is contained in:
@@ -166,13 +166,13 @@ if __name__ == "__main__":
|
||||
dtypes = ("float32", "float16")
|
||||
transposes = ("nn", "nt", "tn")
|
||||
shapes = (
|
||||
(16, 234, 768, 3072),
|
||||
(1, 64, 64, 25344),
|
||||
(16, 1024, 1024, 1024),
|
||||
(1, 1024, 1024, 2048),
|
||||
(4, 1024, 1024, 4096),
|
||||
(4, 1024, 4096, 1024),
|
||||
(1, 4096, 4096, 4096),
|
||||
(15, 1023, 1023, 1023),
|
||||
(17, 1025, 1025, 1025),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
|
@@ -257,6 +257,13 @@ def linear(w, b, x):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def linear_fused(w, b, x):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def rope(x):
|
||||
*_, N, D = x.shape
|
||||
ys = []
|
||||
@@ -397,7 +404,10 @@ if __name__ == "__main__":
|
||||
print(bench(quant_matmul[args.benchmark], *xs))
|
||||
|
||||
elif args.benchmark == "linear":
|
||||
print(bench(linear, *xs))
|
||||
if args.fused:
|
||||
print(bench(linear_fused, *xs))
|
||||
else:
|
||||
print(bench(linear, *xs))
|
||||
|
||||
elif args.benchmark == "sum_axis":
|
||||
print(bench(reduction, "sum", axis, x))
|
||||
|
Reference in New Issue
Block a user