mlx/benchmarks/python/comparative/bench_mlx.py

530 lines
12 KiB
Python
Raw Normal View History

2023-12-01 03:12:53 +08:00
# Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
import argparse
import math
import os
import time
from functools import partial
2023-11-30 02:30:41 +08:00
import mlx.core as mx
import mlx.nn as nn
2023-11-30 02:30:41 +08:00
def int_or_list(x):
    """Parse a CLI string as one int, or as a comma-separated list of ints.

    "2" -> 2 and "0,1" -> [0, 1]. A ValueError propagates if any piece is not
    an integer literal.
    """
    try:
        parsed = int(x)
    except ValueError:
        parsed = list(map(int, x.split(",")))
    return parsed
def none_or_list(x):
    """Parse a CLI string as a comma-separated int list; "" means None.

    Used for ``--transpose`` so an empty value leaves that input unpermuted.
    """
    if x == "":
        return None
    return list(map(int, x.split(",")))
def dtype_from_str(x):
    """Parse a ``--dtype`` CLI value into an ``mx.Dtype``.

    An empty string selects ``mx.float32``. Any other value must be the name
    of an attribute of ``mlx.core`` that is an ``mx.Dtype`` (e.g. "float16").

    Raises:
        ValueError: if ``x`` does not name an mlx dtype. argparse reports a
            ValueError from a ``type=`` callable as a clean usage error.
    """
    if x == "":
        return mx.float32
    # Use a default so that unknown names fall through to the ValueError
    # below. A bare getattr raises AttributeError, which argparse does not
    # translate, so a typo would surface as a raw traceback.
    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
2023-11-30 02:30:41 +08:00
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Returns:
        Total elapsed seconds for the 100 timed calls, as a float. The
        benchmark functions in this file call ``mx.eval`` themselves, so the
        computation happens inside the timed region.
    """
    for _ in range(10):
        f(*args)
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # adjustable system clock and can jump or have coarse resolution.
    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    return time.perf_counter() - start
def matmul_square(x):
    """Multiply ``x`` into a running product ten times, evaluate, and return it.

    The returned array is ``x @ x @ ... @ x`` with eleven factors.
    """
    product = x
    for _ in range(10):
        product = product @ x
    mx.eval(product)
    return product
def matmul(x, y):
    """Build ten independent ``x @ y`` products and evaluate them together."""
    products = [x @ y for _ in range(10)]
    mx.eval(products)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    """Run ``mx.quantized_matmul`` ten times with fixed settings and evaluate.

    ``w``, ``s`` and ``b`` are forwarded positionally as the 2nd-4th
    arguments of ``mx.quantized_matmul`` (per the MLX API: packed weights,
    scales, biases). ``transpose``, ``group_size`` and ``bits`` are forwarded
    as keyword arguments.
    """
    options = dict(transpose=transpose, group_size=group_size, bits=bits)
    results = [mx.quantized_matmul(x, w, s, b, **options) for _ in range(10)]
    mx.eval(results)
# One pre-bound _quant_matmul variant per (transpose, group_size, bits)
# combination, keyed by the benchmark name accepted on the command line,
# e.g. "quant_matmul_64_4" or, for transpose=True, "quant_matmul_t_128_8".
quant_matmul = {
    f"quant_matmul{'_t' if transpose else ''}_{group_size}_{bits}": partial(
        _quant_matmul, transpose=transpose, group_size=group_size, bits=bits
    )
    for transpose in (False, True)
    for group_size in (32, 64, 128)
    for bits in (2, 4, 8)
}
2023-11-30 02:30:41 +08:00
def conv1d(x, y):
    """Compute ``mx.conv1d(x, y)`` ten times and evaluate all results together."""
    outputs = [mx.conv1d(x, y) for _ in range(10)]
    mx.eval(outputs)
def conv2d(x, y):
    """Compute ``mx.conv2d(x, y)`` ten times and evaluate all results together."""
    outputs = [mx.conv2d(x, y) for _ in range(10)]
    mx.eval(outputs)
def binary(op, x, y):
    """Apply the binary op ``mx.<op>`` 100 times, feeding each result back in.

    Each step computes ``op(x, previous)``, starting from ``y``; only the
    final array is evaluated.
    """
    op_fn = getattr(mx, op)
    acc = y
    for _ in range(100):
        acc = op_fn(x, acc)
    mx.eval(acc)
def reduction(op, axis, x):
    """Compute ``mx.<op>(x, axis=axis)`` 100 times and evaluate them together.

    ``axis=None`` is passed through unchanged (used by the "sum_all" run).
    """
    reduce_fn = getattr(mx, op)
    outputs = [reduce_fn(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
2024-11-05 14:25:16 +08:00
def sum_and_add(axis, x, y):
    """Interleave keepdims sums over ``axis`` with additions of ``y``.

    Starts from ``x.sum(axis, keepdims=True)``, then 50 times adds ``y`` and
    re-sums over ``axis``; only the final array is evaluated.
    """
    running = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        running = (running + y).sum(axis=axis, keepdims=True)
    mx.eval(running)
2023-11-30 02:30:41 +08:00
def softmax(axis, x):
    """Compute an unfused (max-shifted exp / sum) softmax 100 times and evaluate."""

    def _unfused_softmax():
        shifted = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
        return shifted / mx.sum(shifted, axis=axis, keepdims=True)

    outputs = [_unfused_softmax() for _ in range(100)]
    mx.eval(outputs)
def softmax_fused(axis, x):
    """Compute ``mx.softmax(x, axis=axis)`` 100 times and evaluate them together."""
    outputs = [mx.softmax(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
def relu(x):
    """Chain ``nn.relu`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.relu(out)
    mx.eval(out)
def leaky_relu(x: mx.array):
    """Chain ``nn.leaky_relu`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
def prelu(x: mx.array):
    """Chain ``nn.prelu`` 100 times with alpha ``mx.ones(1)`` and evaluate the result."""
    out = x
    for _ in range(100):
        # NOTE(review): a fresh mx.ones(1) is built every iteration, so its
        # cost is inside the timed region. Hoist it only if parity with the
        # other comparative benchmarks allows.
        out = nn.prelu(out, mx.ones(1))
    mx.eval(out)
def softplus(x: mx.array):
    """Chain ``nn.softplus`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
def mish(x: mx.array):
    """Chain ``nn.mish`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.mish(out)
    mx.eval(out)
# leaky_relu is defined once, earlier in this file (annotated version). A
# byte-identical unannotated redefinition that silently shadowed it was
# removed from here.
def elu(x):
    """Chain ``nn.elu`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.elu(out)
    mx.eval(out)
def relu6(x):
    """Chain ``nn.relu6`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.relu6(out)
    mx.eval(out)
2025-06-15 17:35:33 +08:00
def relu_squared(x):
    """Chain ``nn.relu_squared`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.relu_squared(out)
    mx.eval(out)
2025-06-15 17:35:33 +08:00
# softplus is defined once, earlier in this file (annotated version). A
# byte-identical unannotated redefinition that silently shadowed it was
# removed from here.
def celu(x):
    """Chain ``nn.celu`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.celu(out)
    mx.eval(out)
def log_sigmoid(x):
    """Chain ``nn.log_sigmoid`` 100 times (each call on the previous output) and evaluate."""
    out = x
    for _ in range(100):
        out = nn.log_sigmoid(out)
    mx.eval(out)
2023-11-30 02:30:41 +08:00
def scalar_mult(x):
    """Multiply ``x`` successively by 1/1, 1/2, ..., 1/100 and evaluate the result."""
    out = x
    for k in range(1, 101):
        out = out * (1.0 / k)
    mx.eval(out)
def cross_entropy(targets, x):
    """Compute a mean cross-entropy of logits ``x`` against ``targets`` 100 times.

    Each loss is ``mean(logsumexp(x, -1) - x[row, targets[row]])``. The target
    index per row is gathered with ``take_along_axis`` on a column-shaped
    ``targets``. All 100 scalars are evaluated together.
    """

    def _mean_loss():
        log_norm = mx.logsumexp(x, axis=-1, keepdims=True)
        picked = mx.take_along_axis(x, mx.reshape(targets, (-1, 1)), axis=-1)
        return mx.mean(log_norm - picked)

    losses = [_mean_loss() for _ in range(100)]
    mx.eval(losses)
def logsumexp(axis, x):
    """Compute ``mx.logsumexp(x, axis=axis)`` 100 times and evaluate them together."""
    outputs = [mx.logsumexp(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
def linear(w, b, x):
    """Compute the unfused affine map ``x @ w.T + b`` ten times and evaluate."""
    outputs = [x @ mx.transpose(w, (1, 0)) + b for _ in range(10)]
    mx.eval(outputs)
def linear_fused(w, b, x):
    """Compute ``x @ w.T + b`` via the fused ``mx.addmm`` ten times and evaluate."""
    outputs = [mx.addmm(b, x, mx.transpose(w, (1, 0))) for _ in range(10)]
    mx.eval(outputs)
2023-11-30 02:30:41 +08:00
def rope(x):
    """Benchmark an unfused rotary-embedding-style rotation of ``x``.

    The last two axes of ``x`` are read as (N, D). Any leading axes are folded
    into one batch axis during the computation. Each (even, odd) feature pair
    is rotated by the angle ``position * freq``. Ten independent results are
    built and evaluated together, each reshaped back to the shape of ``x``.

    Raises:
        ValueError: if the last axis is shorter than 4. For such sizes the
            frequency expression below would divide by zero or take the log
            of a negative number.
    """
    *_, N, D = x.shape
    shape = x.shape
    if D // 2 - 1 < 1:
        raise ValueError(f"[rope] last dimension must be at least 4, got {D}")
    ys = []
    for _ in range(10):
        # Reshape into a local instead of rebinding `x`, so every iteration
        # starts from the caller's input and `shape` stays the caller's shape.
        xb = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        # NOTE(review): this is not the conventional RoPE schedule
        # exp(-arange(D // 2) * log(10000) / (D // 2)). It is kept unchanged
        # because it performs the same amount of work; confirm it mirrors the
        # counterpart comparative benchmarks before changing it.
        freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)
        x1 = xb[..., ::2]
        x2 = xb[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        # Restore the caller's original (possibly batched) shape.
        y = mx.reshape(y, shape)
        ys.append(y)
    mx.eval(ys)
def concatenate(axis, x, y):
    """Concatenate ``x`` and ``y`` along ``axis`` ten times and evaluate all results."""
    outputs = [mx.concatenate([x, y], axis=axis) for _ in range(10)]
    mx.eval(outputs)
def cumsum(axis, x):
    """Compute ``mx.cumsum(x, axis)`` ten times and evaluate all results together."""
    outputs = [mx.cumsum(x, axis) for _ in range(10)]
    mx.eval(outputs)
def sort(axis, x):
    """Compute ``mx.sort(x, axis)`` ten times and evaluate all results together."""
    outputs = [mx.sort(x, axis) for _ in range(10)]
    mx.eval(outputs)
def topk(axis, x):
    """Take ``mx.topk`` with k = one third of ``axis`` ten times and evaluate."""
    k = x.shape[axis] // 3
    outputs = [mx.topk(x, k, axis) for _ in range(10)]
    mx.eval(outputs)
def step_function(x):
    """Chain ``nn.step`` 100 times (each call on the previous output) and evaluate.

    This matches the other activation benchmarks (relu, elu, ...).
    """
    y = x
    for _ in range(100):
        # Feed the previous result back in. Under MLX's lazy evaluation, an
        # array that is overwritten before mx.eval is never computed, so
        # calling nn.step(x) each time would time only the last call.
        y = nn.step(y)
    mx.eval(y)
def selu(x):
    """Chain ``nn.selu`` 100 times (each call on the previous output) and evaluate.

    This matches the other activation benchmarks (relu, elu, ...).
    """
    y = x
    for _ in range(100):
        # Feed the previous result back in. Under MLX's lazy evaluation, an
        # array that is overwritten before mx.eval is never computed, so
        # calling nn.selu(x) each time would time only the last call.
        y = nn.selu(y)
    mx.eval(y)
2023-11-30 02:30:41 +08:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()
    # action="append" adds user values after the default entry, so drop the
    # default as soon as the user supplied at least one value.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)
    if args.cpu:
        mx.set_default_device(mx.cpu)
    else:
        mx.set_default_device(mx.gpu)

    # One dtype per input: pad with the first given dtype (float32 if none).
    types = args.dtype
    if not types:
        types = [mx.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    # One random input per --size. The i-th --transpose permutes input i; an
    # empty --transpose value (None) leaves that input as-is.
    xs = []
    for size, dtype in zip(args.size, types):
        xs.append(mx.random.normal(size).astype(dtype))
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = mx.transpose(xs[i], t)
    mx.eval(xs)
    x = xs[0]
    axis = args.axis[0]
    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))
    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))
    elif args.benchmark.startswith("quant_matmul"):
        print(bench(quant_matmul[args.benchmark], *xs))

    elif args.benchmark == "linear":
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))
    elif args.benchmark == "sum_all":
        print(bench(reduction, "sum", None, x))
    elif args.benchmark == "argmax":
        print(bench(reduction, "argmax", axis, x))
    elif args.benchmark == "add":
        print(bench(binary, "add", *xs))
    elif args.benchmark == "mul":
        print(bench(binary, "multiply", *xs))
    elif args.benchmark == "softmax":
        if args.fused:
            print(bench(softmax_fused, axis, x))
        else:
            print(bench(softmax, axis, x))
    elif args.benchmark == "relu":
        print(bench(relu, x))
    elif args.benchmark == "elu":
        print(bench(elu, x))
    elif args.benchmark == "relu6":
        print(bench(relu6, x))
    elif args.benchmark == "relu_squared":
        print(bench(relu_squared, x))
    elif args.benchmark == "celu":
        print(bench(celu, x))
    elif args.benchmark == "log_sigmoid":
        print(bench(log_sigmoid, x))
    elif args.benchmark == "leaky_relu":
        print(bench(leaky_relu, x))
    elif args.benchmark == "prelu":
        print(bench(prelu, x))
    elif args.benchmark == "softplus":
        print(bench(softplus, x))
    elif args.benchmark == "mish":
        print(bench(mish, x))

    elif args.benchmark == "scalar_mul":
        print(bench(scalar_mult, x))
    elif args.benchmark == "cross_entropy":
        # Validate the array actually benchmarked (x). The leftover `size`
        # loop variable holds the LAST input's shape, not x's.
        if x.ndim != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
        targets = mx.zeros((len(x),), dtype=mx.uint32)
        print(bench(cross_entropy, targets, x))
    elif args.benchmark == "logsumexp":
        print(bench(logsumexp, axis, x))
    elif args.benchmark == "rope":
        print(bench(rope, x))
    elif args.benchmark == "concatenate":
        print(bench(concatenate, axis, *xs))
    elif args.benchmark == "cumsum":
        print(bench(cumsum, axis, *xs))
    elif args.benchmark == "conv1d":
        print(bench(conv1d, *xs))
    elif args.benchmark == "conv2d":
        print(bench(conv2d, *xs))
    elif args.benchmark == "sort":
        print(bench(sort, axis, x))
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))
    elif args.benchmark == "step":
        print(bench(step_function, x))
    elif args.benchmark == "selu":
        print(bench(selu, x))

    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))

    else:
        raise ValueError("Unknown benchmark")