Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-01 00:28:11 +08:00).
			
		
		
		
	
		
			
				
	
	
		
			144 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			144 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import time
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn
 | |
| import mlx.optimizers as opt
 | |
| import torch
 | |
| 
 | |
| 
 | |
def bench_mlx(steps: int = 20) -> float:
    """Benchmark an MLX encoder-decoder training loop on the CPU.

    Runs `steps` optimization steps (forward, backward, Adam update) on a
    small convolutional autoencoder, printing the per-step wall-clock time.

    Args:
        steps: Number of training iterations to time.

    Returns:
        Total elapsed time across all steps, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # Small encoder-decoder: two conv layers down, two transposed convs back.

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            layers = [
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = mlx.nn.Sequential(*layers)

        def __call__(self, input):
            return self.net(input)

    model = BenchNetMLX(3)
    # Force lazy parameter initialization before timing starts.
    mx.eval(model.parameters())

    optimizer = opt.Adam(learning_rate=1e-3)
    optimizer.init(model.parameters())

    # MLX uses NHWC layout, hence [batch, H, W, channels].
    batch = mx.random.normal([10, 256, 256, 3])

    # Everything that must be materialized after each step.
    state = [model.state, optimizer.state]

    def loss_fn(params, image):
        # L1 reconstruction loss of the autoencoder output against its input.
        model.update(params)
        return (model(image) - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optimizer.update(model, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        tic = time.perf_counter()

        step(model.parameters(), batch)
        # mx.eval forces the lazy graph to actually execute before we stop the clock.
        mx.eval(state)

        toc = time.perf_counter()
        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
| 
 | |
| 
 | |
def bench_torch(steps: int = 20) -> float:
    """Benchmark the equivalent PyTorch encoder-decoder training loop on CPU.

    Mirrors `bench_mlx`: runs `steps` Adam optimization steps on a small
    convolutional autoencoder and prints the per-step wall-clock time.

    Args:
        steps: Number of training iterations to time.

    Returns:
        Total elapsed time across all steps, in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # Small encoder-decoder: two conv layers down, two transposed convs back.

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()
            layers = [
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = torch.nn.Sequential(*layers)

        def forward(self, input):
            return self.net(input)

    model = BenchNetTorch(3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # PyTorch uses NCHW layout, hence (batch, channels, H, W).
    batch = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        # L1 reconstruction loss of the autoencoder output against its input.
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        tic = time.perf_counter()

        optimizer.zero_grad()
        loss = loss_fn(model(batch), batch)
        loss.backward()
        optimizer.step()

        toc = time.perf_counter()

        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
| 
 | |
| 
 | |
def main():
    """Run both benchmarks and report average, total, and relative timings."""
    steps = 20
    mlx_total = bench_mlx(steps)
    torch_total = bench_torch(steps)

    print(f"average time of MLX:     {mlx_total/steps:9.2f} ms")
    print(f"total time of MLX:       {mlx_total:9.2f} ms")
    print(f"average time of PyTorch: {torch_total/steps:9.2f} ms")
    print(f"total time of PyTorch:   {torch_total:9.2f} ms")

    # Positive means PyTorch was slower than MLX.
    diff = torch_total / mlx_total - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")
| 
 | |
| 
 | |
# Run the benchmark comparison only when executed as a script.
if __name__ == "__main__":
    main()
