mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
This commit is contained in:
@@ -257,6 +257,13 @@ def linear(w, b, x):
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def linear_fused(w, b, x):
|
||||
ys = []
|
||||
for i in range(10):
|
||||
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
|
||||
mx.eval(ys)
|
||||
|
||||
|
||||
def rope(x):
|
||||
*_, N, D = x.shape
|
||||
ys = []
|
||||
@@ -397,7 +404,10 @@ if __name__ == "__main__":
|
||||
print(bench(quant_matmul[args.benchmark], *xs))
|
||||
|
||||
elif args.benchmark == "linear":
|
||||
print(bench(linear, *xs))
|
||||
if args.fused:
|
||||
print(bench(linear_fused, *xs))
|
||||
else:
|
||||
print(bench(linear, *xs))
|
||||
|
||||
elif args.benchmark == "sum_axis":
|
||||
print(bench(reduction, "sum", axis, x))
|
||||
|
Reference in New Issue
Block a user