mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)

Conv cpu improvements (#1410)

Author:       Max-Heinrich Laves
Committed by: GitHub
Parent:       d6492b0163
Commit:       adcc88e208
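The op-level scripts (conv2d, conv3d, and their transposed variants) share one harness: build the same NumPy data as an NHWC array for MLX and an NCHW tensor for PyTorch, run a warmup pass, time a fixed number of CPU iterations in each framework, and cross-check the outputs with np.allclose. The two *_train_bench_cpu.py scripts instead time full training steps of a small encoder-decoder in both frameworks. A condensed sketch of the op-level pattern follows; the shapes, the time_it helper, and the printed summary are illustrative only and not part of the commit.

# Minimal sketch of the shared benchmark pattern (assumes mlx, numpy, and torch are installed).
import time

import mlx.core as mx
import numpy as np
import torch

mx.set_default_device(mx.cpu)

N, H, W, C, O, k, pad = 4, 32, 32, 32, 32, 5, 2
scale = 1.0 / np.sqrt(k * k * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np.float32)
b_np = np.random.uniform(-scale, scale, (O, k, k, C)).astype(np.float32)

a_mx, b_mx = mx.array(a_np), mx.array(b_np)          # MLX: NHWC input, OHWC weight
a_pt = torch.from_numpy(a_np.transpose(0, 3, 1, 2))  # PyTorch: NCHW input
b_pt = torch.from_numpy(b_np.transpose(0, 3, 1, 2))  # PyTorch: OCHW weight


def time_it(f, iters=10):
    f()  # warmup pass before timing
    start = time.perf_counter()
    for _ in range(iters):
        f()
    return time.perf_counter() - start


t_mlx = time_it(lambda: mx.eval(mx.conv2d(a_mx, b_mx, padding=pad)))
with torch.no_grad():
    t_pt = time_it(lambda: torch.conv2d(a_pt, b_pt, padding=pad))

# Cross-check the two implementations on the same data.
out_mx = mx.conv2d(a_mx, b_mx, padding=pad)
out_pt = torch.conv2d(a_pt, b_pt, padding=pad).permute(0, 2, 3, 1).numpy()
print("allclose:", np.allclose(out_pt, out_mx, atol=2e-5))
print(f"torch/mlx time ratio: {t_pt / t_mlx:.2f}")

Each op-level script additionally sweeps a table of shapes, prints the torch/mlx timing difference per shape, and flags any case where MLX is at least twice as slow as PyTorch.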

benchmarks/python/conv2d_bench_cpu.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   H,   W,   C), (  O, kH, kW,   C),   dtype, stride,   pads,  groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

benchmarks/python/conv2d_train_bench_cpu.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX:     {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX:       {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch:   {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()

benchmarks/python/conv2d_transpose_bench_cpu.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   H,   W,   C), (  O, kH, kW,   C),   dtype, stride,   pads,  groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

benchmarks/python/conv3d_bench_cpu.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")

benchmarks/python/conv3d_train_bench_cpu.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX:     {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX:       {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch:   {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()

benchmarks/python/conv3d_transpose_bench_cpu.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N,   D,   H,   W,   C), (  O, kD, kH, kW,   C),   dtype,    stride,      pads,  groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")