mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00

Conv cpu improvements (#1410)

committed by GitHub
parent d6492b0163
commit adcc88e208
benchmarks/python/conv2d_bench_cpu.py  (new file, 127 lines)
@@ -0,0 +1,127 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
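For context, `bench` returns the total wall-clock seconds for `N_iter_bench` calls of the measured function, and each call launches `N_iter_func` convolutions before a single `mx.eval(ys)` forces MLX's lazy graph to execute, so a per-convolution figure can be derived as in this minimal sketch (`time_mlx` stands for a value returned by `bench_shape`; note the `argparse` parser above is created but no arguments are defined or parsed):

    # Average seconds per individual conv2d call, given the loop structure above.
    per_conv_s = time_mlx / (N_iter_bench * N_iter_func)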
benchmarks/python/conv2d_train_bench_cpu.py  (new file, 143 lines)
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
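The MLX loop is functional: `loss_fn` re-binds the parameter dict onto the module via `benchNet.update(params)`, `mx.value_and_grad` differentiates with respect to that dict, and `mx.eval(state)` inside the timed region materializes both model and optimizer state so each step's full cost is measured. The same step is often written with `mlx.nn.value_and_grad`, which differentiates with respect to a module's trainable parameters directly; a minimal sketch of that alternative (not what this benchmark uses):

    def module_loss(model, image):
        return (model(image) - image).abs().mean()

    # Differentiates module_loss w.r.t. benchNet's trainable parameters.
    loss_and_grad = mlx.nn.value_and_grad(benchNet, module_loss)
    loss, grads = loss_and_grad(benchNet, inputs)
    optim.update(benchNet, grads)
    mx.eval(benchNet.parameters(), optim.state)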
benchmarks/python/conv2d_transpose_bench_cpu.py  (new file, 129 lines)
@@ -0,0 +1,129 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
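One layout subtlety in this script: the transposed-conv weight is generated channels-last with shape (O / groups, kH, kW, C) and handed to PyTorch via `transpose((3, 0, 1, 2))`, which yields the (in_channels, out_channels / groups, kH, kW) layout that `torch.conv_transpose2d` expects. A small shape check (illustrative only, with hypothetical values):

    import numpy as np

    O, kH, kW, C, groups = 32, 5, 5, 16, 1
    b_np = np.zeros((O // groups, kH, kW, C), dtype=np.float32)
    # PyTorch's conv_transpose2d weight layout: (C_in, C_out / groups, kH, kW)
    assert b_np.transpose((3, 0, 1, 2)).shape == (C, O // groups, kH, kW)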
benchmarks/python/conv3d_bench_cpu.py  (new file, 110 lines)
@@ -0,0 +1,110 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
benchmarks/python/conv3d_train_bench_cpu.py  (new file, 143 lines)
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
benchmarks/python/conv3d_transpose_bench_cpu.py  (new file, 116 lines)
@@ -0,0 +1,116 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
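All six scripts are standalone: each compares MLX's CPU path against PyTorch on identical shapes, prints the relative timing difference per shape, and flags any case where MLX is at least 2x slower with an "ATTENTION" marker. Assuming `mlx`, `torch`, and `numpy` are installed, each can be run directly with Python (for example, `python benchmarks/python/conv2d_bench_cpu.py`); none of them consume command-line arguments.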