mlx/benchmarks/python/single_ops.py

109 lines
2.2 KiB
Python
Raw Normal View History

2023-12-01 03:12:53 +08:00
# Copyright © 2023 Apple Inc.
2023-11-30 02:30:41 +08:00
import argparse
import mlx.core as mx
from time_utils import time_fn
def time_add():
    """Benchmark ``mx.add`` over several operand layouts.

    Four cases are passed to ``time_fn``, in this order:

    1. two ``(32, 1024, 1024)`` arrays,
    2. the first array transposed on its last two axes, plus the second,
    3. the first array plus a ``(1024,)`` array,
    4. the first array plus that same data reshaped to ``(1, 1024, 1)``.

    Every input is materialized with ``mx.eval`` before it is timed.
    """
    full_shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=full_shape)
    rhs = mx.random.uniform(shape=full_shape)
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # The wrappers below all just call mx.add; they differ only by name.
    # NOTE(review): presumably time_fn labels its output with the callable's
    # name -- confirm in time_utils before merging these wrappers.

    lhs_t = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_t)

    def transpose_add(x, y):
        return mx.add(x, y)

    time_fn(transpose_add, lhs_t, rhs)

    vec = mx.random.uniform(shape=(1024,))
    mx.eval(vec)

    def slice_add(x, y):
        return mx.add(x, y)

    time_fn(slice_add, lhs, vec)

    # Same values, but laid out along the middle axis of the big operand.
    mid_vec = mx.reshape(vec, (1, 1024, 1))
    mx.eval(mid_vec)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    time_fn(mid_slice_add, lhs, mid_vec)
def time_matmul():
    """Benchmark ``mx.matmul`` on two random ``(1024, 1024)`` matrices.

    Both operands are materialized with ``mx.eval`` before being passed to
    ``time_fn``.
    """
    operands = [mx.random.uniform(shape=(1024, 1024)) for _ in range(2)]
    mx.eval(*operands)
    time_fn(mx.matmul, *operands)
def time_negative():
    """Benchmark unary negation on a random ``(10000, 1000)`` array.

    The input is materialized once with ``mx.eval`` and then passed to
    ``time_fn`` together with a small wrapper that applies ``-``.
    """
    a = mx.random.uniform(shape=(10000, 1000))
    # One evaluation is enough: the array is already materialized after this
    # call, so the second mx.eval(a) the code used to issue did nothing.
    mx.eval(a)

    def negative(a):
        return -a

    time_fn(negative, a)
def time_exp():
    """Benchmark ``mx.exp`` on a random ``(1000, 100)`` array.

    The input is materialized with ``mx.eval`` before being passed to
    ``time_fn``.
    """
    values = mx.random.uniform(shape=(1000, 100))
    mx.eval(values)
    time_fn(mx.exp, values)
def time_logsumexp():
    """Benchmark ``mx.logsumexp`` over the last axis of a ``(64, 10, 10000)`` array.

    The input is materialized with ``mx.eval`` first; ``axis=-1`` is passed
    to ``time_fn`` as a keyword argument alongside the array.
    """
    values = mx.random.uniform(shape=(64, 10, 10000))
    mx.eval(values)
    time_fn(mx.logsumexp, values, axis=-1)
def time_take():
    """Benchmark ``mx.take`` along axis 0 with random indices.

    A ``(10000, 500)`` source array is indexed by a list of flat index
    arrays, one per row of a ``(20, 10)`` random integer draw in
    ``[0, 10000)``. The timed callable performs one ``mx.take`` per index
    array and returns the list of results.
    """
    a = mx.random.uniform(shape=(10000, 500))
    ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
    ids = [mx.reshape(idx, (-1,)) for idx in ids]
    # Materialize the source array as well as the indices, as the other
    # benchmarks in this file do for their inputs. Otherwise the lazy random
    # generation of `a` would be deferred into the first call of the timed
    # function.
    mx.eval(a, ids)

    def random_take():
        return [mx.take(a, idx, 0) for idx in ids]

    time_fn(random_take)
def time_reshape_transposed():
    """Benchmark flattening a transposed ``(256, 256, 128)`` array.

    The timed callable swaps the first two axes of the input and reshapes
    the result to one dimension. The input is materialized with ``mx.eval``
    beforehand.
    """
    source = mx.random.uniform(shape=(256, 256, 128))
    mx.eval(source)

    def reshape_transposed():
        swapped = mx.transpose(source, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)
if __name__ == "__main__":
    # Script entry point: choose the device, then run every benchmark.
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # CPU is the default device unless --gpu is given.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    # Run order matches the original script.
    for benchmark in (
        time_add,
        time_matmul,
        time_exp,
        time_negative,
        time_logsumexp,
        time_take,
        time_reshape_transposed,
    ):
        benchmark()