mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
124 lines
3.6 KiB
Python
124 lines
3.6 KiB
Python
![]() |
import argparse
|
||
|
import math
|
||
|
import os
|
||
|
import subprocess
|
||
|
import time
|
||
|
|
||
|
import mlx.core as mx
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
# CPU brand string reported by ``sysctl``, decoded to text with the
# surrounding newline characters removed.
device_name = (
    subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
    .decode("utf-8")
    .strip("\n")
)

# Untimed calls made by ``bench`` before the clock is started.
N_warmup = 10
# Timed calls made by ``bench``.
N_iter_bench = 100
# Convolutions issued inside every single benchmarked call.
N_iter_func = 5
|
||
|
|
||
|
|
||
|
def bench(f, a, b):
    """Return the wall-clock seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``f`` is first called ``N_warmup`` times without timing, after which
    ``torch.mps.synchronize()`` is invoked and the clock is started.
    No synchronization is issued after the timed loop, so ``f`` is expected
    to block until its own work has completed.
    """
    # Untimed warm-up calls.
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()

    # Nanoseconds -> seconds.
    elapsed_ns = stop_ns - start_ns
    return elapsed_ns * 1e-9
|
||
|
|
||
|
|
||
|
def make_mx_conv_1D(strides=1, padding=0, groups=1):
    """Build a two-argument callable that runs ``mx.conv1d`` with fixed settings.

    The returned function issues ``N_iter_func`` convolutions of ``a`` with
    ``b``, passes the collected results to ``mx.eval`` and returns them as a list.
    """

    def mx_conv_1D(a, b):
        outputs = [
            mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Evaluate all results together before handing them back.
        mx.eval(outputs)
        return outputs

    return mx_conv_1D
|
||
|
|
||
|
|
||
|
def make_pt_conv_1D(strides=1, padding=0, groups=1):
    """Build a two-argument callable that runs ``torch.conv1d`` with fixed settings.

    The returned function runs under ``torch.no_grad()``, issues
    ``N_iter_func`` convolutions of ``a`` with ``b``, calls
    ``torch.mps.synchronize()`` and returns the results as a list.
    """

    @torch.no_grad()
    def pt_conv_1D(a, b):
        outputs = [
            torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Synchronize with the MPS device before returning.
        torch.mps.synchronize()
        return outputs

    return pt_conv_1D
|
||
|
|
||
|
|
||
|
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    """Benchmark MLX against PyTorch (MPS) for one 1D convolution setup.

    Random data is generated with NumPy: an input of shape ``(N, iH, C)``
    and a weight of shape ``(O, wH, C // groups)``. MLX receives the arrays
    as generated; PyTorch receives them with the last two axes swapped.
    After timing, one convolution from each framework is compared and a
    message is printed if they disagree beyond the tolerance.

    Args:
        N: First dimension of the input array.
        iH: Second dimension of the input array.
        C: Last dimension of the input array; must be divisible by
            ``groups`` for the weight shape to be meaningful.
        wH: Second dimension of the weight array.
        O: First dimension of the weight array.
        strides: Stride forwarded to both convolution calls.
        padding: Padding forwarded to both convolution calls.
        np_dtype: NumPy dtype the random data is cast to.
        groups: Group count forwarded to both convolution calls.

    Returns:
        ``(time_mlx, time_torch)``: the seconds reported by ``bench`` for
        the MLX and the PyTorch callable respectively.
    """
    scale = 1.0 / math.sqrt(wH * C)
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    # Integer floor division keeps the shape entry an exact int instead of
    # round-tripping through float division.
    b_np = np.random.uniform(-scale, scale, (O, wH, C // groups)).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # PyTorch gets the same data with the last two axes swapped.
    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check: the PyTorch reference is computed on the CPU and
    # permuted back to the layout of the MLX output.
    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv1d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 1))
    out_pt = out_pt.numpy(force=True)

    # Tighter absolute tolerance for float32 than for any other dtype.
    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")
    # Actually parse the command line so that ``--help`` prints the usage
    # text and exits instead of silently starting the benchmark run.
    parser.parse_args()

    dtypes = ("float32",)
    # Each entry is (N, iH, C, wH, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        # NOTE(review): identical to the previous entry -- confirm whether
        # the repeated groups=8 run is intentional.
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        # The NumPy dtype only depends on ``dtype``; resolve it once per
        # dtype rather than once per shape.
        np_dtype = getattr(np, dtype)
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            # Positive when MLX took less time than PyTorch.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )

            # Flag rows where MLX needed at least twice PyTorch's time.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
|