# Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00).
# 144 lines, 4.1 KiB, Python.
import time
|
||
|
|
||
|
import mlx.core as mx
|
||
|
import mlx.nn
|
||
|
import mlx.optimizers as opt
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    """Time ``steps`` Adam training steps of a small Conv3d encoder-decoder in MLX.

    Note: this switches MLX's global default device to the CPU and does not
    restore it afterwards.

    Args:
        steps: Number of timed optimization steps.
        shape: Shape of the random input batch. The network's input/output
            channel count is taken from the last axis (channels-last layout,
            ``(N, D, H, W, C)`` for the default shape).

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net; kernel 3 / padding 1 everywhere so the
        # reconstruction can be compared element-wise against the input.

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    # Derive the channel count from the requested input shape instead of
    # hard-coding 3, so non-default shapes still produce a matching network.
    bench_net = BenchNetMLX(shape[-1])
    mx.eval(bench_net.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = bench_net.parameters()
    optim.init(params)

    # Everything that changes per step; evaluated explicitly because MLX is lazy.
    state = [bench_net.state, optim.state]

    def loss_fn(params, image):
        # Load the (traced) parameters into the model before the forward pass
        # so gradients are taken with respect to ``params``.
        bench_net.update(params)
        pred_image = bench_net(image)
        return (pred_image - image).abs().mean()

    # Build the gradient transform once, outside the timed loop.
    loss_and_grad = mx.value_and_grad(loss_fn)

    def step(params, image):
        loss, grads = loss_and_grad(params, image)
        optim.update(bench_net, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        loss = step(bench_net.parameters(), inputs)
        # Force the loss as well as parameters/optimizer state so the same
        # work as the PyTorch loop (which computes the loss) lands in the
        # timed region.
        mx.eval(loss, state)
        end_time = time.perf_counter()

        elapsed_ms = (end_time - start_time) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
|
||
|
|
||
|
|
||
|
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    """Time ``steps`` Adam training steps of a small Conv3d encoder-decoder in PyTorch.

    Runs on the CPU device.

    Args:
        steps: Number of timed optimization steps.
        shape: Shape of the random input batch in PyTorch's channels-first
            layout ``(N, C, D, H, W)``. The network's input/output channel
            count is taken from ``shape[1]``.

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net; kernel 3 / padding 1 everywhere so the
        # reconstruction can be compared element-wise against the input.

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    # Derive the channel count from the requested input shape instead of
    # hard-coding 3, so non-default shapes still produce a matching network.
    bench_net = BenchNetTorch(shape[1]).to(device)
    optim = torch.optim.Adam(bench_net.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = bench_net(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        elapsed_ms = (end_time - start_time) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
|
||
|
|
||
|
|
||
|
def main():
    """Benchmark MLX and PyTorch on the same training loop and print a summary.

    Both frameworks run the same number of timed steps. The final line reports
    ``torch_total / mlx_total - 1`` as a signed percentage: positive means the
    PyTorch run took longer in total, negative means it was faster.
    """
    n_steps = 10
    mlx_ms = bench_mlx(n_steps)
    torch_ms = bench_torch(n_steps)

    report = [
        f"average time of MLX: {mlx_ms / n_steps:9.2f} ms",
        f"total time of MLX: {mlx_ms:9.2f} ms",
        f"average time of PyTorch: {torch_ms / n_steps:9.2f} ms",
        f"total time of PyTorch: {torch_ms:9.2f} ms",
    ]
    for line in report:
        print(line)

    relative = torch_ms / mlx_ms - 1.0
    print(f"torch/mlx diff: {100. * relative:+5.2f}%")


if __name__ == "__main__":
    # Only run the benchmarks when executed as a script, not on import.
    main()
|