From adcc88e208d03b44c230f3c2c585a97efbb8e687 Mon Sep 17 00:00:00 2001 From: Max-Heinrich Laves <8014859+mlaves@users.noreply.github.com> Date: Mon, 16 Sep 2024 03:45:10 +0200 Subject: [PATCH] Conv cpu improvements (#1410) --- benchmarks/python/conv2d_bench_cpu.py | 127 +++++++++++++ benchmarks/python/conv2d_train_bench_cpu.py | 143 +++++++++++++++ .../python/conv2d_transpose_bench_cpu.py | 129 +++++++++++++ benchmarks/python/conv3d_bench_cpu.py | 110 +++++++++++ benchmarks/python/conv3d_train_bench_cpu.py | 143 +++++++++++++++ .../python/conv3d_transpose_bench_cpu.py | 116 ++++++++++++ mlx/backend/common/conv.cpp | 57 +++++- mlx/backend/common/copy.cpp | 173 ++++++++++++++++++ 8 files changed, 997 insertions(+), 1 deletion(-) create mode 100644 benchmarks/python/conv2d_bench_cpu.py create mode 100644 benchmarks/python/conv2d_train_bench_cpu.py create mode 100644 benchmarks/python/conv2d_transpose_bench_cpu.py create mode 100644 benchmarks/python/conv3d_bench_cpu.py create mode 100644 benchmarks/python/conv3d_train_bench_cpu.py create mode 100644 benchmarks/python/conv3d_transpose_bench_cpu.py diff --git a/benchmarks/python/conv2d_bench_cpu.py b/benchmarks/python/conv2d_bench_cpu.py new file mode 100644 index 000000000..d560ae1b3 --- /dev/null +++ b/benchmarks/python/conv2d_bench_cpu.py @@ -0,0 +1,127 @@ +import argparse +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 1 +N_iter_bench = 10 +N_iter_func = 5 +mx.set_default_device(mx.cpu) + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_2D + + +def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + return ys + + return pt_conv_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu") + b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu") + + f_mx = make_mx_conv_2D(strides, padding, groups) + f_pt = make_pt_conv_2D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) + out_pt = torch.conv2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), + (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), + # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2), + # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16), + # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64), + (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), + ) + + for dtype in dtypes: + print( + "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape( + N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/conv2d_train_bench_cpu.py b/benchmarks/python/conv2d_train_bench_cpu.py new file mode 100644 index 000000000..d85587909 --- /dev/null +++ b/benchmarks/python/conv2d_train_bench_cpu.py @@ -0,0 +1,143 @@ +import time + +import mlx.core as mx +import mlx.nn +import mlx.optimizers as opt +import torch + + +def bench_mlx(steps: int = 20) -> float: + mx.set_default_device(mx.cpu) + + class BenchNetMLX(mlx.nn.Module): + # simple encoder-decoder net + + def __init__(self, in_channels, hidden_channels=32): + super().__init__() + + self.net = mlx.nn.Sequential( + mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), + mlx.nn.ReLU(), + mlx.nn.Conv2d( + hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 + ), + mlx.nn.ReLU(), + mlx.nn.ConvTranspose2d( + 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 + ), + mlx.nn.ReLU(), + mlx.nn.ConvTranspose2d( + hidden_channels, in_channels, kernel_size=3, padding=1 + ), + ) + + def __call__(self, input): + return self.net(input) + + benchNet = BenchNetMLX(3) + mx.eval(benchNet.parameters()) + optim = opt.Adam(learning_rate=1e-3) + + inputs = mx.random.normal([10, 256, 256, 3]) + + params = benchNet.parameters() + optim.init(params) + + state = [benchNet.state, optim.state] + + def loss_fn(params, image): + benchNet.update(params) + pred_image = benchNet(image) + return (pred_image - image).abs().mean() + + def step(params, image): + loss, grads = mx.value_and_grad(loss_fn)(params, image) + optim.update(benchNet, grads) + return loss + + total_time = 0.0 + print("MLX:") + for i in range(steps): + start_time = time.perf_counter() + + step(benchNet.parameters(), inputs) + mx.eval(state) + end_time = time.perf_counter() + + print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") + total_time += (end_time - start_time) * 1000 + + return total_time + + +def bench_torch(steps: int = 20) -> float: + device = 
torch.device("cpu") + + class BenchNetTorch(torch.nn.Module): + # simple encoder-decoder net + + def __init__(self, in_channels, hidden_channels=32): + super().__init__() + + self.net = torch.nn.Sequential( + torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d( + hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 + ), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d( + 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 + ), + torch.nn.ReLU(), + torch.nn.ConvTranspose2d( + hidden_channels, in_channels, kernel_size=3, padding=1 + ), + ) + + def forward(self, input): + return self.net(input) + + benchNet = BenchNetTorch(3).to(device) + optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3) + + inputs = torch.randn(10, 3, 256, 256, device=device) + + def loss_fn(pred_image, image): + return (pred_image - image).abs().mean() + + total_time = 0.0 + print("PyTorch:") + for i in range(steps): + start_time = time.perf_counter() + + optim.zero_grad() + pred_image = benchNet(inputs) + loss = loss_fn(pred_image, inputs) + loss.backward() + optim.step() + + end_time = time.perf_counter() + + print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") + total_time += (end_time - start_time) * 1000 + + return total_time + + +def main(): + steps = 20 + time_mlx = bench_mlx(steps) + time_torch = bench_torch(steps) + + print(f"average time of MLX: {time_mlx/steps:9.2f} ms") + print(f"total time of MLX: {time_mlx:9.2f} ms") + print(f"average time of PyTorch: {time_torch/steps:9.2f} ms") + print(f"total time of PyTorch: {time_torch:9.2f} ms") + + diff = time_torch / time_mlx - 1.0 + print(f"torch/mlx diff: {100. * diff:+5.2f}%") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/python/conv2d_transpose_bench_cpu.py b/benchmarks/python/conv2d_transpose_bench_cpu.py new file mode 100644 index 000000000..28f31f66a --- /dev/null +++ b/benchmarks/python/conv2d_transpose_bench_cpu.py @@ -0,0 +1,129 @@ +import argparse +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 1 +N_iter_bench = 10 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_transpose_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv_transpose2d( + a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu + ) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_transpose_2D + + +def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_transpose_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv_transpose2d( + a, b, stride=strides, padding=padding, groups=groups + ) + ys.append(y) + return ys + + return pt_conv_transpose_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu") + b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu") + + f_mx = make_mx_conv_transpose_2D(strides, padding, groups) + f_pt = 
make_pt_conv_transpose_2D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv_transpose2d( + a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu + ) + out_pt = torch.conv_transpose2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), + (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), + ) + + for dtype in dtypes: + print( + "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape( + N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/conv3d_bench_cpu.py b/benchmarks/python/conv3d_bench_cpu.py new file mode 100644 index 000000000..d3b263da8 --- /dev/null +++ b/benchmarks/python/conv3d_bench_cpu.py @@ -0,0 +1,110 @@ +import argparse +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 1 +N_iter_bench = 10 +N_iter_func = 5 +mx.set_default_device(mx.cpu) + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): + def mx_conv_3D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_3D + + +def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1): + @torch.no_grad() + def pt_conv_3D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups) + ys.append(y) + return ys + + return pt_conv_3D + + +def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kD * kH * kW * C) + a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu") + b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu") + + f_mx = make_mx_conv_3D(strides, padding, groups) + f_pt = make_pt_conv_3D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups) + out_pt = torch.conv3d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1), + (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1), + ) + + for dtype in dtypes: + print( + "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape( + N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/conv3d_train_bench_cpu.py b/benchmarks/python/conv3d_train_bench_cpu.py new file mode 100644 index 000000000..fee99ee76 --- /dev/null +++ b/benchmarks/python/conv3d_train_bench_cpu.py @@ -0,0 +1,143 @@ +import time + +import mlx.core as mx +import mlx.nn +import mlx.optimizers as opt +import torch + + +def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float: + mx.set_default_device(mx.cpu) + + class BenchNetMLX(mlx.nn.Module): + # simple encoder-decoder net + + def __init__(self, in_channels, hidden_channels=16): + super().__init__() + + self.net = mlx.nn.Sequential( + mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1), + mlx.nn.ReLU(), + mlx.nn.Conv3d( + hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 + ), + mlx.nn.ReLU(), + mlx.nn.ConvTranspose3d( + 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 + ), + mlx.nn.ReLU(), + mlx.nn.ConvTranspose3d( + hidden_channels, in_channels, kernel_size=3, padding=1 + ), + ) + + def __call__(self, input): + return self.net(input) + + benchNet = BenchNetMLX(3) + mx.eval(benchNet.parameters()) + optim = opt.Adam(learning_rate=1e-3) + + inputs = mx.random.normal(shape) + + params = benchNet.parameters() + optim.init(params) + + state = [benchNet.state, optim.state] + + def loss_fn(params, image): + benchNet.update(params) + pred_image = benchNet(image) + return (pred_image - image).abs().mean() + + def step(params, image): + loss, grads = mx.value_and_grad(loss_fn)(params, image) + optim.update(benchNet, grads) + return loss + + total_time = 0.0 + print("MLX:") + for i in range(steps): + start_time = time.perf_counter() + + step(benchNet.parameters(), inputs) + mx.eval(state) + end_time = time.perf_counter() + + print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") + total_time += (end_time - start_time) * 1000 + + return total_time + + +def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float: + device = torch.device("cpu") + + class BenchNetTorch(torch.nn.Module): + # simple encoder-decoder net + + def __init__(self, in_channels, hidden_channels=16): + super().__init__() + + self.net = torch.nn.Sequential( + torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1), + torch.nn.ReLU(), + torch.nn.Conv3d( + hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1 + ), + torch.nn.ReLU(), + torch.nn.ConvTranspose3d( + 2 * hidden_channels, hidden_channels, kernel_size=3, padding=1 + ), + torch.nn.ReLU(), + torch.nn.ConvTranspose3d( + hidden_channels, in_channels, kernel_size=3, padding=1 + ), + ) + + def forward(self, input): + return self.net(input) + + benchNet = BenchNetTorch(3).to(device) + optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3) + + inputs = torch.randn(*shape, device=device) + + def loss_fn(pred_image, image): + return (pred_image - image).abs().mean() + + total_time = 0.0 + print("PyTorch:") + for i in range(steps): + start_time = time.perf_counter() + + optim.zero_grad() + pred_image = benchNet(inputs) + loss = loss_fn(pred_image, inputs) + loss.backward() + optim.step() + + end_time = time.perf_counter() + + print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms") + total_time += (end_time - start_time) * 1000 + + return total_time + + +def main(): + steps = 10 + time_mlx = bench_mlx(steps) + time_torch = bench_torch(steps) + + print(f"average time of MLX: {time_mlx/steps:9.2f} ms") + print(f"total time of MLX: 
{time_mlx:9.2f} ms") + print(f"average time of PyTorch: {time_torch/steps:9.2f} ms") + print(f"total time of PyTorch: {time_torch:9.2f} ms") + + diff = time_torch / time_mlx - 1.0 + print(f"torch/mlx diff: {100. * diff:+5.2f}%") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/python/conv3d_transpose_bench_cpu.py b/benchmarks/python/conv3d_transpose_bench_cpu.py new file mode 100644 index 000000000..742d9d46b --- /dev/null +++ b/benchmarks/python/conv3d_transpose_bench_cpu.py @@ -0,0 +1,116 @@ +import argparse +import math +import time + +import mlx.core as mx +import numpy as np +import torch + +N_warmup = 1 +N_iter_bench = 10 +N_iter_func = 5 +mx.set_default_device(mx.cpu) + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): + def mx_conv_3D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv_transpose3d( + a, b, stride=strides, padding=padding, groups=groups + ) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_3D + + +def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1): + @torch.no_grad() + def pt_conv_3D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv_transpose3d( + a, b, stride=strides, padding=padding, groups=groups + ) + ys.append(y) + return ys + + return pt_conv_3D + + +def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype): + scale = 1.0 / math.sqrt(kD * kH * kW * C) + a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype( + np_dtype + ) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu") + b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu") + + f_mx = make_mx_conv_3D(strides, padding, groups) + f_pt = make_pt_conv_3D(strides, padding, groups) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv_transpose3d( + a_mx, b_mx, stride=strides, padding=padding, groups=groups + ) + out_pt = torch.conv_transpose3d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1), + (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1), + ) + + for dtype in dtypes: + print( + "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape( + N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index e60d3bc9c..76edc9a27 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -684,6 +684,32 @@ void dispatch_slow_conv_3D( // Explicit gemm conv /////////////////////////////////////////////////////////////////////////////// +template +void flip_spatial_dims_inplace(array& wt) { + T* x = wt.data(); + size_t out_channels = wt.shape(0); + size_t in_channels = wt.shape(-1); + + // Calculate the total size of the spatial dimensions + int spatial_size = 1; + for (int d = 1; d < wt.ndim() - 1; ++d) { + spatial_size *= wt.shape(d); + } + + for (size_t i = 0; i < out_channels; i++) { + T* top = x + i * spatial_size * in_channels; + T* bottom = + x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels; + for (size_t j = 0; j < spatial_size / 2; j++) { + for (size_t k = 0; k < in_channels; k++) { + std::swap(top[k], bottom[k]); + } + top += in_channels; + bottom -= in_channels; + } + } +} + void explicit_gemm_conv_1D_cpu( const array& in, const array& wt, @@ -910,7 +936,8 @@ void explicit_gemm_conv_ND_cpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const bool flip) { const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const auto iDim = std::vector( in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim @@ -1000,6 +1027,14 @@ void explicit_gemm_conv_ND_cpu( copy(wt, gemm_wt, ctype); } + if (flip) { + auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {}); + copy(gemm_wt, gemm_wt_, CopyType::Vector); + + flip_spatial_dims_inplace(gemm_wt_); + gemm_wt = gemm_wt_; + } + if (out.dtype() != float32) { gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); @@ -1042,10 +1077,15 @@ void conv_1D_cpu( const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { + const int groups = in.shape().back() / wt.shape().back(); if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( in, wt, out, padding, wt_strides, wt_dilation); } + if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { + return explicit_gemm_conv_ND_cpu( + in, wt, out, padding, wt_strides, wt_dilation, flip); + } return dispatch_slow_conv_1D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); @@ -1060,6 +1100,13 @@ void conv_2D_cpu( const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { + const int groups = in.shape().back() / wt.shape().back(); + if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && + in_dilation[1] == 1 && groups == 1) { + return explicit_gemm_conv_ND_cpu( + in, wt, out, padding, wt_strides, wt_dilation, flip); + } + return dispatch_slow_conv_2D( in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } @@ -1073,6 +1120,14 @@ void conv_3D_cpu( const std::vector& wt_dilation, const std::vector& in_dilation, bool flip) { + const int groups = in.shape().back() / wt.shape().back(); + if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 && + in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && + groups == 1) { + return explicit_gemm_conv_ND_cpu( + in, wt, out, padding, wt_strides, wt_dilation, flip); + } + return dispatch_slow_conv_3D( in, wt, out, padding, wt_strides, 
wt_dilation, in_dilation, flip); } diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 34b84258b..ff0d00df5 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -136,6 +136,167 @@ inline void copy_general_dim4(const array& src, array& dst) { src, dst, src.shape(), src.strides(), 0); } +template +void copy_general_dim5( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { + const SrcT* src_ptr = src.data() + i_offset; + DstT* dst_ptr = dst.data(); + + // Pre-compute loop bounds and strides + const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], + d3 = data_shape[3], d4 = data_shape[4]; + const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2], + s3 = i_strides[3], s4 = i_strides[4]; + + // Pre-compute stride adjustments + const stride_t s3_adj = s3 - s4 * d4; + const stride_t s2_adj = s2 - s3 * d3; + const stride_t s1_adj = s1 - s2 * d2; + const stride_t s0_adj = s0 - s1 * d1; + + stride_t src_idx = 0; + stride_t dst_idx = 0; + + for (int i = 0; i < d0; ++i) { + for (int j = 0; j < d1; ++j) { + for (int k = 0; k < d2; ++k) { + for (int l = 0; l < d3; ++l) { + for (int m = 0; m < d4; ++m) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += s4; + } + src_idx += s3_adj; + } + src_idx += s2_adj; + } + src_idx += s1_adj; + } + src_idx += s0_adj; + } +} + +template +inline void copy_general_dim5(const array& src, array& dst) { + return copy_general_dim5( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general_dim6( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { + const SrcT* src_ptr = src.data() + i_offset; + DstT* dst_ptr = dst.data(); + + // Pre-compute loop bounds and strides + const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], + d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5]; + const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2], + s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5]; + + // Pre-compute stride adjustments + const stride_t s4_adj = s4 - s5 * d5; + const stride_t s3_adj = s3 - s4 * d4; + const stride_t s2_adj = s2 - s3 * d3; + const stride_t s1_adj = s1 - s2 * d2; + const stride_t s0_adj = s0 - s1 * d1; + + stride_t src_idx = 0; + stride_t dst_idx = 0; + + for (int i = 0; i < d0; ++i) { + for (int j = 0; j < d1; ++j) { + for (int k = 0; k < d2; ++k) { + for (int l = 0; l < d3; ++l) { + for (int m = 0; m < d4; ++m) { + for (int n = 0; n < d5; ++n) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += s5; + } + src_idx += s4_adj; + } + src_idx += s3_adj; + } + src_idx += s2_adj; + } + src_idx += s1_adj; + } + src_idx += s0_adj; + } +} + +template +inline void copy_general_dim6(const array& src, array& dst) { + return copy_general_dim6( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general_dim7( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { + const SrcT* src_ptr = src.data() + i_offset; + DstT* dst_ptr = dst.data(); + + // Pre-compute loop bounds and strides + const int d0 = data_shape[0], d1 = data_shape[1], d2 = data_shape[2], + d3 = data_shape[3], d4 = data_shape[4], d5 = data_shape[5], + d6 = data_shape[6]; + const stride_t s0 = i_strides[0], s1 = i_strides[1], s2 = i_strides[2], + s3 = i_strides[3], s4 = i_strides[4], s5 = i_strides[5], + s6 = 
i_strides[6]; + + // Pre-compute stride adjustments + const stride_t s5_adj = s5 - s6 * d6; + const stride_t s4_adj = s4 - s5 * d5; + const stride_t s3_adj = s3 - s4 * d4; + const stride_t s2_adj = s2 - s3 * d3; + const stride_t s1_adj = s1 - s2 * d2; + const stride_t s0_adj = s0 - s1 * d1; + + stride_t src_idx = 0; + stride_t dst_idx = 0; + + for (int i = 0; i < d0; ++i) { + for (int j = 0; j < d1; ++j) { + for (int k = 0; k < d2; ++k) { + for (int l = 0; l < d3; ++l) { + for (int m = 0; m < d4; ++m) { + for (int n = 0; n < d5; ++n) { + for (int p = 0; p < d6; ++p) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += s6; + } + src_idx += s5_adj; + } + src_idx += s4_adj; + } + src_idx += s3_adj; + } + src_idx += s2_adj; + } + src_idx += s1_adj; + } + src_idx += s0_adj; + } +} + +template +inline void copy_general_dim7(const array& src, array& dst) { + return copy_general_dim7( + src, dst, src.shape(), src.strides(), 0); +} + template void copy_general( const array& src, @@ -162,6 +323,18 @@ void copy_general( copy_general_dim4( src, dst, new_shape, new_strides[0], i_offset); return; + case 5: + copy_general_dim5( + src, dst, new_shape, new_strides[0], i_offset); + return; + case 6: + copy_general_dim6( + src, dst, new_shape, new_strides[0], i_offset); + return; + case 7: + copy_general_dim7( + src, dst, new_shape, new_strides[0], i_offset); + return; } auto src_ptr = src.data() + i_offset;