mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			520 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			520 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import argparse
 | |
| import math
 | |
| import os
 | |
| import time
 | |
| from functools import partial
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn as nn
 | |
| 
 | |
| 
 | |
def int_or_list(x):
    """Parse a CLI string as a single int, else a comma-separated int list."""
    try:
        return int(x)
    except ValueError:
        pass
    return [int(part) for part in x.split(",")]
def none_or_list(x):
    """Parse a comma-separated int list from a CLI string; "" means None."""
    if not x:
        return None
    return [int(part) for part in x.split(",")]
def dtype_from_str(x):
    """Parse a dtype name (e.g. "float16") into an mx.Dtype.

    An empty string defaults to float32.

    Raises:
        ValueError: if ``x`` does not name an mlx dtype. Fixed: the previous
            bare ``getattr(mx, x)`` raised AttributeError for unknown names,
            which argparse (used as ``type=dtype_from_str``) does not convert
            into a clean usage error — only ValueError/TypeError are caught.
    """
    if x == "":
        return mx.float32
    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
def bench(f, *args):
    """Return the wall-clock seconds taken by 100 calls of f(*args).

    Runs 10 untimed warmup calls first, then times 100 calls.
    """
    for _ in range(10):
        f(*args)
    start = time.time()
    for _ in range(100):
        f(*args)
    return time.time() - start
def matmul_square(x):
    """Benchmark a chain of 10 dependent matmuls of x with itself."""
    acc = x
    for _ in range(10):
        acc = acc @ x
    mx.eval(acc)
    return acc
def matmul(x, y):
    """Benchmark 10 independent matmuls of x and y, evaluated together."""
    outs = [x @ y for _ in range(10)]
    mx.eval(outs)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    """Benchmark 10 quantized matmuls with the given quantization settings."""
    outs = [
        mx.quantized_matmul(
            x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
        )
        for _ in range(10)
    ]
    mx.eval(outs)
# Lookup table mapping benchmark names such as "quant_matmul_t_64_4" to a
# _quant_matmul call pre-bound with the matching transpose/group_size/bits.
# Insertion order matches the original hand-written table: all non-transposed
# entries first, then transposed, each ordered by group size then bit width.
quant_matmul = {
    "quant_matmul_{}{}_{}".format("t_" if transpose else "", group_size, bits): partial(
        _quant_matmul, transpose=transpose, group_size=group_size, bits=bits
    )
    for transpose in (False, True)
    for group_size in (32, 64, 128)
    for bits in (2, 4, 8)
}
def conv1d(x, y):
    """Benchmark 10 independent 1D convolutions, evaluated together."""
    outs = [mx.conv1d(x, y) for _ in range(10)]
    mx.eval(outs)
def conv2d(x, y):
    """Benchmark 10 independent 2D convolutions, evaluated together."""
    outs = [mx.conv2d(x, y) for _ in range(10)]
    mx.eval(outs)
def binary(op, x, y):
    """Benchmark 100 chained applications of the mx binary op named ``op``.

    Each step feeds the previous result back in as the second operand, so one
    long dependent graph is built and evaluated once at the end.
    """
    acc = y
    for _ in range(100):
        acc = getattr(mx, op)(x, acc)
    mx.eval(acc)
def reduction(op, axis, x):
    """Benchmark 100 reductions of x with the mx op named ``op`` over axis."""
    outs = [getattr(mx, op)(x, axis=axis) for _ in range(100)]
    mx.eval(outs)
def sum_and_add(axis, x, y):
    """Benchmark a chain of 50 add-then-reduce steps along ``axis``."""
    acc = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        acc = (acc + y).sum(axis=axis, keepdims=True)
    mx.eval(acc)
def softmax(axis, x):
    """Benchmark 100 hand-written (unfused) softmax computations over axis.

    Uses the numerically stable form: subtract the per-axis max before exp.
    """
    outs = []
    for _ in range(100):
        num = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
        outs.append(num / mx.sum(num, axis=axis, keepdims=True))
    mx.eval(outs)
def softmax_fused(axis, x):
    """Benchmark 100 calls to the fused mx.softmax kernel over ``axis``."""
    outs = [mx.softmax(x, axis=axis) for _ in range(100)]
    mx.eval(outs)
def relu(x):
    """Benchmark 100 chained applications of nn.relu."""
    out = x
    for _ in range(100):
        out = nn.relu(out)
    mx.eval(out)
def leaky_relu(x: mx.array):
    """Benchmark 100 chained applications of nn.leaky_relu.

    NOTE(review): a behaviorally identical ``leaky_relu`` is defined again
    later in this file; at import time the later definition wins — consider
    deduplicating.
    """
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
def prelu(x: mx.array):
    """Benchmark 100 chained applications of nn.prelu with a unit alpha."""
    out = x
    for _ in range(100):
        out = nn.prelu(out, mx.ones(1))
    mx.eval(out)
def softplus(x: mx.array):
    """Benchmark 100 chained applications of nn.softplus.

    NOTE(review): a behaviorally identical ``softplus`` is defined again
    later in this file; at import time the later definition wins — consider
    deduplicating.
    """
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
def mish(x: mx.array):
    """Benchmark 100 chained applications of nn.mish."""
    out = x
    for _ in range(100):
        out = nn.mish(out)
    mx.eval(out)
# NOTE: a duplicate `def leaky_relu(x)` — behaviorally identical to the
# annotated definition earlier in this file — used to live here. It silently
# shadowed the earlier definition at import time and has been removed; the
# earlier definition is the one that now takes effect.
def elu(x):
    """Benchmark 100 chained applications of nn.elu."""
    out = x
    for _ in range(100):
        out = nn.elu(out)
    mx.eval(out)
def relu6(x):
    """Benchmark 100 chained applications of nn.relu6."""
    out = x
    for _ in range(100):
        out = nn.relu6(out)
    mx.eval(out)
# NOTE: a duplicate `def softplus(x)` — behaviorally identical to the
# annotated definition earlier in this file — used to live here. It silently
# shadowed the earlier definition at import time and has been removed; the
# earlier definition is the one that now takes effect.
def celu(x):
    """Benchmark 100 chained applications of nn.celu."""
    out = x
    for _ in range(100):
        out = nn.celu(out)
    mx.eval(out)
def log_sigmoid(x):
    """Benchmark 100 chained applications of nn.log_sigmoid."""
    out = x
    for _ in range(100):
        out = nn.log_sigmoid(out)
    mx.eval(out)
def scalar_mult(x):
    """Benchmark 100 chained multiplications by a varying Python scalar."""
    acc = x
    for step in range(100):
        acc = acc * (1.0 / (1 + step))
    mx.eval(acc)
def cross_entropy(targets, x):
    """Benchmark 100 mean cross-entropy evaluations from logits.

    Per-row loss is logsumexp over the last axis minus the logit at the
    target index; each iteration appends the mean loss.
    """
    losses = []
    for _ in range(100):
        lse = mx.logsumexp(x, axis=-1, keepdims=True)
        picked = mx.take_along_axis(x, mx.reshape(targets, (-1, 1)), axis=-1)
        losses.append(mx.mean(lse - picked))
    mx.eval(losses)
def logsumexp(axis, x):
    """Benchmark 100 logsumexp reductions of x over ``axis``."""
    outs = [mx.logsumexp(x, axis=axis) for _ in range(100)]
    mx.eval(outs)
def linear(w, b, x):
    """Benchmark 10 unfused linear layers: x @ w.T + b."""
    outs = [x @ mx.transpose(w, (1, 0)) + b for _ in range(10)]
    mx.eval(outs)
def linear_fused(w, b, x):
    """Benchmark 10 fused linear layers via mx.addmm(b, x, w.T)."""
    outs = [mx.addmm(b, x, mx.transpose(w, (1, 0))) for _ in range(10)]
    mx.eval(outs)
def rope(x):
    """Benchmark 10 hand-written rotary position embedding (RoPE) passes.

    Treats the last two axes of ``x`` as (sequence N, feature D); each
    iteration rebuilds the rotation angles and rotates even/odd feature
    pairs, collecting results and evaluating them together.
    """
    *_, N, D = x.shape
    ys = []
    for i in range(10):
        shape = x.shape  # NOTE(review): unused — `shape` is never read below
        x = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        # NOTE(review): canonical RoPE uses freqs = exp(-arange(D/2) *
        # log(10000) / (D/2)); here arange is *divided by* log(10000/(D/2-1)),
        # which grows rather than decays — looks suspicious, confirm intent
        # (for a throughput benchmark the exact values may not matter).
        freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
        # Outer product: one rotation angle per (position, frequency) pair.
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)
        # Split features into interleaved even/odd halves and rotate each pair.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        # Re-interleave the rotated halves back into a (-1, N, D) tensor.
        y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
        y = mx.reshape(y, (-1, N, D))
        ys.append(y)
    mx.eval(ys)
def concatenate(axis, x, y):
    """Benchmark 10 concatenations of x and y along ``axis``."""
    outs = [mx.concatenate([x, y], axis=axis) for _ in range(10)]
    mx.eval(outs)
def cumsum(axis, x):
    """Benchmark 10 cumulative sums of x along ``axis``."""
    outs = [mx.cumsum(x, axis) for _ in range(10)]
    mx.eval(outs)
def sort(axis, x):
    """Benchmark 10 sorts of x along ``axis``."""
    outs = [mx.sort(x, axis) for _ in range(10)]
    mx.eval(outs)
def topk(axis, x):
    """Benchmark 10 top-k selections, with k one third of the axis length."""
    k = x.shape[axis] // 3
    outs = [mx.topk(x, k, axis) for _ in range(10)]
    mx.eval(outs)
def step_function(x):
    """Benchmark 100 chained applications of nn.step.

    Fixed: the loop previously applied ``nn.step`` to the original input
    ``x`` every iteration, discarding the running result ``y``; it now feeds
    ``y`` back in, matching every other chained activation benchmark in this
    file (relu, elu, mish, ...).
    """
    y = x
    for i in range(100):
        y = nn.step(y)
    mx.eval(y)
def selu(x):
    """Benchmark 100 chained applications of nn.selu.

    Fixed: the loop previously applied ``nn.selu`` to the original input
    ``x`` every iteration, discarding the running result ``y``; it now feeds
    ``y`` back in, matching every other chained activation benchmark in this
    file (relu, elu, mish, ...).
    """
    y = x
    for i in range(100):
        y = nn.selu(y)
    mx.eval(y)
| if __name__ == "__main__":
 | |
|     parser = argparse.ArgumentParser()
 | |
|     parser.add_argument("benchmark", help="Choose the benchmark to run")
 | |
|     parser.add_argument(
 | |
|         "--size",
 | |
|         default=[(1024, 1024)],
 | |
|         type=lambda x: list(map(int, x.split("x"))),
 | |
|         help="Set the matrix size",
 | |
|         action="append",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--axis",
 | |
|         default=[1],
 | |
|         type=int_or_list,
 | |
|         help="Set a reduction axis",
 | |
|         action="append",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--transpose",
 | |
|         type=none_or_list,
 | |
|         default=[],
 | |
|         help="Permute the matrix",
 | |
|         action="append",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--print-pid", action="store_true", help="Print the PID and pause"
 | |
|     )
 | |
|     parser.add_argument("--cpu", action="store_true", help="Use the CPU")
 | |
|     parser.add_argument(
 | |
|         "--fused", action="store_true", help="Use fused functions where possible"
 | |
|     )
 | |
|     parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
 | |
| 
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     if len(args.size) > 1:
 | |
|         args.size.pop(0)
 | |
|     if len(args.axis) > 1:
 | |
|         args.axis.pop(0)
 | |
| 
 | |
|     if args.cpu:
 | |
|         mx.set_default_device(mx.cpu)
 | |
|     else:
 | |
|         mx.set_default_device(mx.gpu)
 | |
| 
 | |
|     types = args.dtype
 | |
|     if not types:
 | |
|         types = [mx.float32]
 | |
|     if len(types) < len(args.size):
 | |
|         types = types + [types[0]] * (len(args.size) - len(types))
 | |
| 
 | |
|     xs = []
 | |
|     for size, dtype in zip(args.size, types):
 | |
|         xs.append(mx.random.normal(size).astype(dtype))
 | |
|     for i, t in enumerate(args.transpose):
 | |
|         if t is None:
 | |
|             continue
 | |
|         xs[i] = mx.transpose(xs[i], t)
 | |
|     mx.eval(xs)
 | |
|     x = xs[0]
 | |
|     axis = args.axis[0]
 | |
| 
 | |
|     if args.print_pid:
 | |
|         print(os.getpid())
 | |
|         input("Press enter to run")
 | |
| 
 | |
|     if args.benchmark == "matmul_square":
 | |
|         print(bench(matmul_square, x))
 | |
| 
 | |
|     elif args.benchmark == "matmul":
 | |
|         print(bench(matmul, *xs))
 | |
| 
 | |
|     elif args.benchmark.startswith("quant_matmul"):
 | |
|         print(bench(quant_matmul[args.benchmark], *xs))
 | |
| 
 | |
|     elif args.benchmark == "linear":
 | |
|         if args.fused:
 | |
|             print(bench(linear_fused, *xs))
 | |
|         else:
 | |
|             print(bench(linear, *xs))
 | |
| 
 | |
|     elif args.benchmark == "sum_axis":
 | |
|         print(bench(reduction, "sum", axis, x))
 | |
| 
 | |
|     elif args.benchmark == "sum_all":
 | |
|         print(bench(reduction, "sum", None, x))
 | |
| 
 | |
|     elif args.benchmark == "argmax":
 | |
|         print(bench(reduction, "argmax", axis, x))
 | |
| 
 | |
|     elif args.benchmark == "add":
 | |
|         print(bench(binary, "add", *xs))
 | |
| 
 | |
|     elif args.benchmark == "mul":
 | |
|         print(bench(binary, "multiply", *xs))
 | |
| 
 | |
|     elif args.benchmark == "softmax":
 | |
|         if args.fused:
 | |
|             print(bench(softmax_fused, axis, x))
 | |
|         else:
 | |
|             print(bench(softmax, axis, x))
 | |
| 
 | |
|     elif args.benchmark == "relu":
 | |
|         print(bench(relu, x))
 | |
| 
 | |
|     elif args.benchmark == "elu":
 | |
|         print(bench(elu, x))
 | |
| 
 | |
|     elif args.benchmark == "relu6":
 | |
|         print(bench(relu6, x))
 | |
| 
 | |
|     elif args.benchmark == "celu":
 | |
|         print(bench(celu, x))
 | |
| 
 | |
|     elif args.benchmark == "log_sigmoid":
 | |
|         print(bench(log_sigmoid, x))
 | |
| 
 | |
|     elif args.benchmark == "leaky_relu":
 | |
|         print(bench(leaky_relu, x))
 | |
|     elif args.benchmark == "prelu":
 | |
|         print(bench(prelu, x))
 | |
|     elif args.benchmark == "softplus":
 | |
|         print(bench(softplus, x))
 | |
|     elif args.benchmark == "mish":
 | |
|         print(bench(mish, x))
 | |
|     elif args.benchmark == "scalar_mul":
 | |
|         print(bench(scalar_mult, x))
 | |
| 
 | |
|     elif args.benchmark == "cross_entropy":
 | |
|         if len(size) != 2:
 | |
|             raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
 | |
| 
 | |
|         targets = mx.zeros((len(x),), dtype=mx.uint32)
 | |
|         print(bench(cross_entropy, targets, x))
 | |
| 
 | |
|     elif args.benchmark == "logsumexp":
 | |
|         print(bench(logsumexp, axis, x))
 | |
| 
 | |
|     elif args.benchmark == "rope":
 | |
|         print(bench(rope, x))
 | |
| 
 | |
|     elif args.benchmark == "concatenate":
 | |
|         print(bench(concatenate, axis, *xs))
 | |
| 
 | |
|     elif args.benchmark == "cumsum":
 | |
|         print(bench(cumsum, axis, *xs))
 | |
| 
 | |
|     elif args.benchmark == "conv1d":
 | |
|         print(bench(conv1d, *xs))
 | |
| 
 | |
|     elif args.benchmark == "conv2d":
 | |
|         print(bench(conv2d, *xs))
 | |
| 
 | |
|     elif args.benchmark == "sort":
 | |
|         print(bench(sort, axis, x))
 | |
| 
 | |
|     elif args.benchmark == "topk":
 | |
|         print(bench(topk, axis, x))
 | |
| 
 | |
|     elif args.benchmark == "step":
 | |
|         print(bench(step_function, x))
 | |
| 
 | |
|     elif args.benchmark == "selu":
 | |
|         print(bench(selu, x))
 | |
| 
 | |
|     elif args.benchmark == "sum_and_add":
 | |
|         print(bench(sum_and_add, axis, *xs))
 | |
| 
 | |
|     else:
 | |
|         raise ValueError("Unknown benchmark")
 | 
