2023-12-01 03:12:53 +08:00
|
|
|
# Copyright © 2023 Apple Inc.
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.mps
|
|
|
|
|
|
|
|
|
|
|
|
def int_or_list(x):
    """Parse ``x`` as a single int, or as a comma-separated list of ints.

    Returns an ``int`` when the whole string converts with ``int``; otherwise
    the string is split on "," and a list of ints is returned. A
    ``ValueError`` propagates if a list element is not an integer.
    """
    try:
        parsed = int(x)
    except ValueError:
        parsed = list(map(int, x.split(",")))
    return parsed
|
|
|
|
|
|
|
|
|
|
|
|
def none_or_list(x):
    """Parse a comma-separated list of ints; the empty string maps to None."""
    if x != "":
        return list(map(int, x.split(",")))
    return None
|
|
|
|
|
|
|
|
|
2023-12-19 15:18:57 +08:00
|
|
|
def dtype_from_str(x):
    """Convert a dtype name such as ``"float16"`` into a ``torch.dtype``.

    Args:
        x: Name of an attribute of the ``torch`` module, or the empty string.

    Returns:
        ``torch.float32`` for the empty string, otherwise the named dtype.

    Raises:
        ValueError: If ``x`` does not name a ``torch.dtype`` (unknown
            attribute, or an attribute that is not a dtype).
    """
    if x == "":
        return torch.float32

    # Look up with a default so that an unknown name is reported as a
    # ValueError (which argparse turns into a usage error when this function
    # is used as a ``type=`` converter) rather than an AttributeError.
    dt = getattr(torch, x, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{x} is not a torch dtype")
    return dt
|
|
|
|
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to every call of ``f``.

    Returns:
        Elapsed time in seconds (float) for the 100 timed calls.
    """
    # Warm-up calls are excluded from the measurement.
    for _ in range(10):
        f(*args)

    # perf_counter is monotonic and high resolution, so the measurement is
    # not affected by system clock adjustments the way time.time() can be.
    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    end = time.perf_counter()
    return end - start
|
|
|
|
|
|
|
|
|
|
|
|
def sync_if_needed(x):
    """Call ``torch.mps.synchronize()`` unless ``x`` lives on the CPU device."""
    on_cpu = x.device == torch.device("cpu")
    if on_cpu:
        return
    torch.mps.synchronize()
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def matmul_square(x):
    """Right-multiply by ``x`` ten times in a chain, then sync the device."""
    product = x
    for _ in range(10):
        product = product @ x
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def matmul(x, y):
    """Compute ``x @ y`` ten times, then sync the device.

    All ten products stay referenced until the function returns.
    """
    products = [x @ y for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def conv1d(x, y):
    """Run ``torch.nn.functional.conv1d`` ten times, then sync the device.

    The last two dimensions of both the input ``x`` and the weight ``y`` are
    swapped before convolving (presumably channels-last to channels-first;
    confirm against the caller).
    """
    signal = x.transpose(-1, -2)
    kernel = y.transpose(-1, -2)
    outputs = [torch.nn.functional.conv1d(signal, kernel) for _ in range(10)]
    sync_if_needed(signal)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def conv2d(x, y):
    """Run ``torch.nn.functional.conv2d`` ten times, then sync the device.

    Both the input ``x`` and the weight ``y`` are permuted with
    ``(0, 3, 1, 2)`` first, i.e. the last dimension is moved to position 1.
    """
    image = x.permute(0, 3, 1, 2)
    kernel = y.permute(0, 3, 1, 2)
    outputs = [torch.nn.functional.conv2d(image, kernel) for _ in range(10)]
    sync_if_needed(image)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def binary(op, x, y):
    """Apply the binary function ``torch.<op>`` 100 times in a chain.

    Each step computes ``torch.<op>(x, previous)``, starting from ``y``; the
    device is synced once after the loop.
    """
    acc = y
    for _ in range(100):
        acc = getattr(torch, op)(x, acc)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def reduction(op, axis, x):
    """Call the tensor method ``x.<op>(axis)`` 100 times, then sync.

    ``axis`` is forwarded as the single positional argument; the caller may
    pass ``None``.
    """
    reduced = [getattr(x, op)(axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def softmax(axis, x):
    """Compute an unfused softmax along ``axis`` 100 times, then sync.

    Each pass subtracts the per-axis maximum, exponentiates, and divides by
    the per-axis sum of the exponentials.
    """
    outputs = []
    for _ in range(100):
        peak = torch.max(x, dim=axis, keepdims=True).values
        exps = torch.exp(x - peak)
        outputs.append(exps / torch.sum(exps, dim=axis, keepdims=True))
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def softmax_fused(axis, x):
    """Call ``torch.nn.functional.softmax`` along ``axis`` 100 times, then sync."""
    outputs = [torch.nn.functional.softmax(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def relu(x):
    """Chain 100 ReLU applications starting from ``x``, then sync the device."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2023-12-11 08:31:38 +08:00
|
|
|
@torch.no_grad()
def leaky_relu(x):
    """Chain 100 leaky-ReLU applications starting from ``x``, then sync."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.leaky_relu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def elu(x):
    """Chain 100 ELU applications starting from ``x``, then sync the device."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.elu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def celu(x):
    """Chain 100 CELU applications starting from ``x``, then sync the device."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.celu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def relu6(x):
    """Chain 100 ReLU6 applications starting from ``x``, then sync the device."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu6(out)
    sync_if_needed(x)
|
|
|
|
|
2025-06-15 17:35:33 +08:00
|
|
|
|
2025-06-15 17:34:10 +08:00
|
|
|
@torch.no_grad()
def relu_squared(x):
    """Chain 100 rounds of ReLU followed by squaring, then sync the device."""
    out = x
    for _ in range(100):
        rectified = torch.nn.functional.relu(out)
        out = torch.square(rectified)
    sync_if_needed(x)
|
2023-12-11 08:31:38 +08:00
|
|
|
|
2025-06-15 17:35:33 +08:00
|
|
|
|
2023-12-11 08:31:38 +08:00
|
|
|
@torch.no_grad()
def softplus(x):
    """Chain 100 softplus applications starting from ``x``, then sync."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.softplus(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def log_sigmoid(x):
    """Chain 100 log-sigmoid applications starting from ``x``, then sync."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.logsigmoid(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2023-12-12 11:40:57 +08:00
|
|
|
@torch.no_grad()
def prelu(x: torch.Tensor) -> None:
    """Chain 100 PReLU applications starting from ``x``, then sync the device.

    A fresh one-element weight of ones is created and moved to the current
    tensor's device on every iteration, so that cost is part of each call.
    Nothing is returned; the function exists only to be timed.
    """
    y = x
    for _ in range(100):
        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def mish(x: torch.Tensor) -> None:
    """Chain 100 Mish applications starting from ``x``, then sync the device.

    Nothing is returned; the function exists only to be timed.
    """
    y = x
    for _ in range(100):
        y = torch.nn.functional.mish(y)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
@torch.no_grad()
def scalar_mult(x):
    """Multiply by the Python scalars 1/1, 1/2, ..., 1/100 in a chain, then sync."""
    out = x
    for step in range(100):
        factor = 1.0 / (1 + step)
        out = out * factor
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def cross_entropy(targets, x):
    """Compute ``cross_entropy(x, targets)`` 100 times, then sync the device.

    ``x`` is passed as the input and ``targets`` as the target of
    ``torch.nn.functional.cross_entropy``.
    """
    losses = [torch.nn.functional.cross_entropy(x, targets) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def logsumexp(axis, x):
    """Compute ``torch.logsumexp`` along ``axis`` 100 times, then sync."""
    outputs = [torch.logsumexp(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def linear_fused(w, b, x):
    """Call ``torch.nn.functional.linear(x, w, b)`` ten times, then sync."""
    outputs = [torch.nn.functional.linear(x, w, b) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def linear(w, b, x):
    """Compute the unfused linear layer ``x @ w^T + b`` ten times, then sync.

    The transpose of ``w`` (last two dimensions swapped) is taken inside
    every iteration.
    """
    outputs = []
    for _ in range(10):
        projected = x @ torch.transpose(w, -2, -1)
        outputs.append(projected + b)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def rope(x):
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    The last two dimensions of ``x`` are read as ``(N, D)``. Each of the ten
    iterations flattens the leading dimensions, rebuilds the angle table
    (``arange(N)`` times ``10000 ** linspace(0, 1, D // 2)``), rotates the
    even/odd feature pairs, and stores the ``(-1, N, D)`` result. The device
    is synced once after the loop.
    """
    *_, N, D = x.shape
    rotated = []
    for _ in range(10):
        x = x.view(-1, N, D)
        pos = torch.arange(N, device=x.device)
        scales = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
        angles = pos[:, None] * scales[None]
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        even = x[..., ::2]
        odd = x[..., 1::2]
        first = even * cos - odd * sin
        second = even * sin + odd * cos
        # Interleave the two rotated halves back into D features.
        pairs = torch.cat([first[..., None], second[..., None]], dim=-1)
        rotated.append(pairs.reshape(-1, N, D))
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def concatenate(axis, x, y):
    """Concatenate ``x`` and ``y`` along ``axis`` ten times, then sync."""
    joined = [torch.cat([x, y], dim=axis) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def cumsum(axis, x):
    """Compute the cumulative sum of ``x`` along ``axis`` ten times, then sync."""
    sums = [x.cumsum(axis) for _ in range(10)]
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sort(axis, x):
    """Sort ``x`` along ``axis`` ten times, keeping only the sorted values."""
    sorted_values = []
    for _ in range(10):
        values = torch.sort(x, dim=axis)[0]
        sorted_values.append(values)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def topk(axis, x):
    """Take the top third of ``x`` along ``axis`` ten times, then sync.

    ``k`` is the size of ``x`` along ``axis`` floor-divided by 3; only the
    values (not the indices) of each result are kept.
    """
    k = x.shape[axis] // 3
    top_values = []
    for _ in range(10):
        values = torch.topk(x, k, dim=axis)[0]
        top_values.append(values)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2024-06-04 22:50:46 +08:00
|
|
|
@torch.no_grad()
def step_function(x):
    """Chain 100 step functions (0 where negative, else 1), then sync."""
    out = x
    for _ in range(100):
        negative = out < 0
        out = torch.where(negative, 0, 1)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2023-12-12 09:04:07 +08:00
|
|
|
@torch.no_grad()
def selu(x):
    """Chain 100 SELU applications starting from ``x``, then sync the device."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.selu(out)
    sync_if_needed(x)
|
|
|
|
|
|
|
|
|
2023-11-30 02:30:41 +08:00
|
|
|
if __name__ == "__main__":
    # Command-line driver: parse the options, build the input tensors, then
    # run exactly one benchmark and print its elapsed time in seconds.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()

    # With action="append" argparse appends user values after the default
    # entry, so drop the default whenever the user supplied at least one.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    torch.set_num_threads(1)
    device = "cpu" if args.cpu else "mps"

    # One dtype per size; missing dtypes are padded with the first one given
    # (float32 when none was given).
    types = args.dtype
    if not types:
        types = [torch.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = []
    for size, dtype in zip(args.size, types):
        xs.append(torch.randn(*size).to(device).to(dtype))
    # The i-th --transpose permutes the i-th tensor; an empty value skips it.
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = xs[i].permute(*t)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))

    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))

    elif args.benchmark == "linear":
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))

    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))

    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))

    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))

    elif args.benchmark == "mul":
        print(bench(binary, "mul", *xs))

    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))

    elif args.benchmark == "relu":
        print(bench(relu, x))

    elif args.benchmark == "leaky_relu":
        print(bench(leaky_relu, x))

    elif args.benchmark == "elu":
        print(bench(elu, x))

    elif args.benchmark == "relu6":
        print(bench(relu6, x))

    elif args.benchmark == "relu_squared":
        print(bench(relu_squared, x))

    elif args.benchmark == "softplus":
        print(bench(softplus, x))

    elif args.benchmark == "celu":
        print(bench(celu, x))

    elif args.benchmark == "log_sigmoid":
        print(bench(log_sigmoid, x))

    elif args.benchmark == "prelu":
        print(bench(prelu, x))
    elif args.benchmark == "mish":
        print(bench(mish, x))

    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))

    elif args.benchmark == "cross_entropy":
        # Check the tensor that is actually benchmarked rather than the loop
        # variable left over from building ``xs`` (which is the last --size).
        if x.dim() != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
        print(bench(cross_entropy, targets, x))

    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))

    elif args.benchmark == "rope":
        print(bench(rope, x))

    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))

    elif args.benchmark == "cumsum":
        # cumsum takes a single tensor, so pass the first one only; this
        # matches the other single-input benchmarks when several sizes are given.
        print(bench(cumsum, axis, x))

    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))

    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))

    elif args.benchmark == "sort":
        print(bench(sort, axis, x))

    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    elif args.benchmark == "step":
        print(bench(step_function, x))

    elif args.benchmark == "selu":
        print(bench(selu, x))

    else:
        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|