mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			144 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			144 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import time
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import mlx.nn
 | 
						|
import mlx.optimizers as opt
 | 
						|
import torch
 | 
						|
 | 
						|
 | 
						|
def bench_mlx(steps: int = 20) -> float:
    """Time training iterations of a small convolutional net with MLX on CPU.

    Sets the MLX default device to the CPU, builds a four-layer
    convolution / transposed-convolution network, and repeatedly performs a
    forward pass, a mean-absolute-error loss against the input itself, the
    gradient computation, and an Adam update on one fixed random batch.
    The wall-clock time of every iteration is printed as it completes.

    Args:
        steps: Number of timed training iterations to run.

    Returns:
        The sum of all per-iteration wall-clock times, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        """Two convolutions followed by two transposed convolutions, with
        ReLU activations between consecutive layers."""

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            wide_channels = 2 * hidden_channels
            # Layers are created in the same order they are applied.
            layers = [
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(hidden_channels, wide_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    wide_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = mlx.nn.Sequential(*layers)

        def __call__(self, x):
            return self.net(x)

    model = BenchNetMLX(3)
    # Materialize the initial parameters before any timing starts.
    mx.eval(model.parameters())
    optimizer = opt.Adam(learning_rate=1e-3)

    # One fixed random batch, reused for every iteration.
    batch = mx.random.normal([10, 256, 256, 3])

    optimizer.init(model.parameters())

    # Everything that must be evaluated to finish one training step.
    tracked_state = [model.state, optimizer.state]

    def l1_loss(params, image):
        # Load the supplied parameters into the model, then compare the
        # model's output with its own input.
        model.update(params)
        reconstruction = model(image)
        return (reconstruction - image).abs().mean()

    def train_step(params, image):
        loss, grads = mx.value_and_grad(l1_loss)(params, image)
        optimizer.update(model, grads)
        return loss

    elapsed_total_ms = 0.0
    print("MLX:")
    for i in range(steps):
        tic = time.perf_counter()

        train_step(model.parameters(), batch)
        # Force the updated parameters and optimizer state to be computed
        # inside the timed region.
        mx.eval(tracked_state)
        toc = time.perf_counter()

        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        elapsed_total_ms += elapsed_ms

    return elapsed_total_ms
 | 
						|
 | 
						|
 | 
						|
def bench_torch(steps: int = 20) -> float:
    """Time training iterations of a small convolutional net with PyTorch on CPU.

    Builds a four-layer convolution / transposed-convolution network on the
    CPU device and repeatedly performs gradient reset, a forward pass, a
    mean-absolute-error loss against the input itself, backpropagation, and
    an Adam update on one fixed random batch. The wall-clock time of every
    iteration is printed as it completes.

    Args:
        steps: Number of timed training iterations to run.

    Returns:
        The sum of all per-iteration wall-clock times, in milliseconds.
    """
    cpu = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        """Two convolutions followed by two transposed convolutions, with
        ReLU activations between consecutive layers."""

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            wide_channels = 2 * hidden_channels
            # Layers are created in the same order they are applied.
            layers = [
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, wide_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    wide_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = torch.nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    model = BenchNetTorch(3).to(cpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One fixed random batch, reused for every iteration.
    batch = torch.randn(10, 3, 256, 256, device=cpu)

    def l1_loss(pred_image, image):
        return (pred_image - image).abs().mean()

    elapsed_total_ms = 0.0
    print("PyTorch:")
    for i in range(steps):
        tic = time.perf_counter()

        # The timed region covers the full training step.
        optimizer.zero_grad()
        loss = l1_loss(model(batch), batch)
        loss.backward()
        optimizer.step()

        toc = time.perf_counter()

        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        elapsed_total_ms += elapsed_ms

    return elapsed_total_ms
 | 
						|
 | 
						|
 | 
						|
def main(steps: int = 20) -> None:
    """Run the MLX and PyTorch benchmarks and print a timing comparison.

    Runs ``bench_mlx`` and then ``bench_torch`` with the same number of
    iterations, prints the average and total time of each, and finally
    prints the relative difference ``time_torch / time_mlx - 1`` as a
    percentage (a positive value means PyTorch took longer than MLX).

    Args:
        steps: Number of timed training iterations per framework. Must be
            positive, because the averages divide by it and the final ratio
            divides by the accumulated MLX time.

    Raises:
        ValueError: If ``steps`` is not a positive number.
    """
    # Fail early with a clear message instead of a ZeroDivisionError after
    # both benchmarks have already been set up.
    if steps <= 0:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")

    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX:     {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX:       {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch:   {time_torch:9.2f} ms")

    # Relative slowdown of PyTorch with respect to MLX.
    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")
 | 
						|
 | 
						|
 | 
						|
# Run the benchmark comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
 |