Mirror of https://github.com/ml-explore/mlx.git, commit c4a471c99d
	
	
	
		
			
Commit message:

* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* Cleanup
* Update mlx/ops.cpp
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fixed groups on cpu
* Fix metal validation

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
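For context, in a grouped convolution the input channels are split into groups equal chunks, each chunk is convolved with its own subset of the filters, and the per-group outputs are concatenated along the channel axis; groups=1 is an ordinary convolution and groups equal to the channel count is a depthwise one. A minimal reference sketch of that semantics in MLX's channels-last layout, for orientation only (this is not code from the repo, and the helper name grouped_conv1d_reference is made up):

import mlx.core as mx

def grouped_conv1d_reference(a, b, stride=1, padding=0, groups=1):
    # a: (N, L, C_in), b: (C_out, K, C_in // groups) -- channels-last, as mx.conv1d expects.
    # groups must divide both C_in and C_out.
    C_in = a.shape[-1]
    C_out = b.shape[0]
    in_per_group = C_in // groups
    out_per_group = C_out // groups
    outs = []
    for g in range(groups):
        # Each group convolves its own input-channel slice with its own filter slice.
        a_g = a[..., g * in_per_group : (g + 1) * in_per_group]
        b_g = b[g * out_per_group : (g + 1) * out_per_group]
        outs.append(mx.conv1d(a_g, b_g, stride=stride, padding=padding))
    return mx.concatenate(outs, axis=-1)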
		
			
				
	
	
		
The conv1d benchmark script (124 lines, 3.6 KiB, Python):
import argparse
import math
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

# Name of the Apple Silicon chip (recorded for reference; not printed by this script).
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10  # un-timed iterations to warm up kernels and caches
N_iter_bench = 100  # timed outer iterations
N_iter_func = 5  # convolutions per timed call


def bench(f, a, b):
    # Warm up, synchronize, then time N_iter_bench calls of f(a, b) in seconds.
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_1D(strides=1, padding=0, groups=1):
    def mx_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)  # force MLX's lazy graph to execute before returning
        return ys

    return mx_conv_1D


def make_pt_conv_1D(strides=1, padding=0, groups=1):
    @torch.no_grad()
    def pt_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()  # wait for the MPS stream so the timing is fair
        return ys

    return pt_conv_1D


def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    scale = 1.0 / math.sqrt(wH * C)
    # MLX is channels-last: input (N, iH, C), weight (O, wH, C // groups).
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # PyTorch is channels-first, so transpose to (N, C, iH) and (O, C // groups, wH).
    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Sanity-check that MLX and PyTorch (on CPU) agree on the result.
    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv1d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C // groups)} [strides = {strides}, "
            f"padding = {padding}, groups = {groups}] "
            f"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each entry is (N, iH, C, wH, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N,  iH,  C),  (O,  wH,  C),   dtype,  stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            # diff > 0 means PyTorch took longer, i.e. MLX is faster on this shape.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, "
                f"{strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )

            # Flag shapes where MLX is at least 2x slower than PyTorch.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
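To run the benchmark on an Apple Silicon Mac with an MPS-enabled PyTorch build (the listing above does not name the file; conv1d_bench.py is assumed here):

    python conv1d_bench.py

Each output row compares one shape: diff% is time_torch / time_mlx - 1.0, so positive values mean the MLX convolution is faster than PyTorch's, and the ATTENTION marker flags shapes where MLX is at least twice as slow.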