mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 08:38:12 +08:00 
			
		
		
		
	
		
			
	
	
		
			117 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			117 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import argparse | ||
|  | import math | ||
|  | import time | ||
|  | 
 | ||
|  | import mlx.core as mx | ||
|  | import numpy as np | ||
|  | import torch | ||
|  | 
 | ||
|  | N_warmup = 1 | ||
|  | N_iter_bench = 10 | ||
|  | N_iter_func = 5 | ||
|  | mx.set_default_device(mx.cpu) | ||
|  | 
 | ||
|  | 
 | ||
|  | def bench(f, a, b): | ||
|  |     for i in range(N_warmup): | ||
|  |         f(a, b) | ||
|  | 
 | ||
|  |     s = time.perf_counter_ns() | ||
|  |     for i in range(N_iter_bench): | ||
|  |         f(a, b) | ||
|  |     e = time.perf_counter_ns() | ||
|  |     return (e - s) * 1e-9 | ||
|  | 
 | ||
|  | 
 | ||
|  | def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): | ||
|  |     def mx_conv_3D(a, b): | ||
|  |         ys = [] | ||
|  |         for i in range(N_iter_func): | ||
|  |             y = mx.conv_transpose3d( | ||
|  |                 a, b, stride=strides, padding=padding, groups=groups | ||
|  |             ) | ||
|  |             ys.append(y) | ||
|  |         mx.eval(ys) | ||
|  |         return ys | ||
|  | 
 | ||
|  |     return mx_conv_3D | ||
|  | 
 | ||
|  | 
 | ||
|  | def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): | ||
|  |     @torch.no_grad() | ||
|  |     def pt_conv_3D(a, b): | ||
|  |         ys = [] | ||
|  |         for i in range(N_iter_func): | ||
|  |             y = torch.conv_transpose3d( | ||
|  |                 a, b, stride=strides, padding=padding, groups=groups | ||
|  |             ) | ||
|  |             ys.append(y) | ||
|  |         return ys | ||
|  | 
 | ||
|  |     return pt_conv_3D | ||
|  | 
 | ||
|  | 
 | ||
|  | def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): | ||
|  |     scale = 1.0 / math.sqrt(kD * kH * kW * C) | ||
|  |     a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype) | ||
|  |     b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype( | ||
|  |         np_dtype | ||
|  |     ) | ||
|  | 
 | ||
|  |     a_mx = mx.array(a_np) | ||
|  |     b_mx = mx.array(b_np) | ||
|  | 
 | ||
|  |     a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu") | ||
|  |     b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu") | ||
|  | 
 | ||
|  |     f_mx = make_mx_conv_3D(strides, padding, groups) | ||
|  |     f_pt = make_pt_conv_3D(strides, padding, groups) | ||
|  | 
 | ||
|  |     time_torch = bench(f_pt, a_pt, b_pt) | ||
|  |     time_mlx = bench(f_mx, a_mx, b_mx) | ||
|  | 
 | ||
|  |     out_mx = mx.conv_transpose3d( | ||
|  |         a_mx, b_mx, stride=strides, padding=padding, groups=groups | ||
|  |     ) | ||
|  |     out_pt = torch.conv_transpose3d( | ||
|  |         a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups | ||
|  |     ) | ||
|  |     out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) | ||
|  |     out_pt = out_pt.numpy(force=True) | ||
|  | 
 | ||
|  |     atol = 2e-5 if np_dtype == np.float32 else 1e-4 | ||
|  | 
 | ||
|  |     if not np.allclose(out_pt, out_mx, atol=atol): | ||
|  |         print( | ||
|  |             f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" | ||
|  |         ) | ||
|  | 
 | ||
|  |     return time_mlx, time_torch | ||
|  | 
 | ||
|  | 
 | ||
|  | if __name__ == "__main__": | ||
|  |     parser = argparse.ArgumentParser(description="Run conv benchmarks") | ||
|  | 
 | ||
|  |     dtypes = ("float32",) | ||
|  |     shapes = ( | ||
|  |         (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1), | ||
|  |         (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1), | ||
|  |     ) | ||
|  | 
 | ||
|  |     for dtype in dtypes: | ||
|  |         print( | ||
|  |             "(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%" | ||
|  |         ) | ||
|  |         for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: | ||
|  |             np_dtype = getattr(np, dtype) | ||
|  |             time_mlx, time_torch = bench_shape( | ||
|  |                 N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype | ||
|  |             ) | ||
|  |             diff = time_torch / time_mlx - 1.0 | ||
|  | 
 | ||
|  |             print( | ||
|  |                 f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%" | ||
|  |             ) | ||
|  |             if time_mlx >= 2.0 * time_torch: | ||
|  |                 print("ATTENTION ^^^^^^^") |