mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-21 01:58:10 +08:00 
			
		
		
		
	Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed
This commit is contained in:
		| @@ -257,6 +257,13 @@ def linear(w, b, x): | ||||
|     mx.eval(ys) | ||||
|  | ||||
|  | ||||
| def linear_fused(w, b, x): | ||||
|     ys = [] | ||||
|     for i in range(10): | ||||
|         ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0)))) | ||||
|     mx.eval(ys) | ||||
|  | ||||
|  | ||||
| def rope(x): | ||||
|     *_, N, D = x.shape | ||||
|     ys = [] | ||||
| @@ -397,7 +404,10 @@ if __name__ == "__main__": | ||||
|         print(bench(quant_matmul[args.benchmark], *xs)) | ||||
|  | ||||
|     elif args.benchmark == "linear": | ||||
|         print(bench(linear, *xs)) | ||||
|         if args.fused: | ||||
|             print(bench(linear_fused, *xs)) | ||||
|         else: | ||||
|             print(bench(linear, *xs)) | ||||
|  | ||||
|     elif args.benchmark == "sum_axis": | ||||
|         print(bench(reduction, "sum", axis, x)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani