mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
109 lines
2.2 KiB
Python
109 lines
2.2 KiB
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import argparse
|
|
import mlx.core as mx
|
|
|
|
from time_utils import time_fn
|
|
|
|
|
|
def time_add():
|
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
|
b = mx.random.uniform(shape=(32, 1024, 1024))
|
|
mx.eval(a, b)
|
|
time_fn(mx.add, a, b)
|
|
|
|
aT = mx.transpose(a, [0, 2, 1])
|
|
mx.eval(aT)
|
|
|
|
def transpose_add(a, b):
|
|
return mx.add(a, b)
|
|
|
|
time_fn(transpose_add, aT, b)
|
|
|
|
b = mx.random.uniform(shape=(1024,))
|
|
mx.eval(b)
|
|
|
|
def slice_add(a, b):
|
|
return mx.add(a, b)
|
|
|
|
time_fn(slice_add, a, b)
|
|
|
|
b = mx.reshape(b, (1, 1024, 1))
|
|
mx.eval(b)
|
|
|
|
def mid_slice_add(a, b):
|
|
return mx.add(a, b)
|
|
|
|
time_fn(mid_slice_add, a, b)
|
|
|
|
|
|
def time_matmul():
|
|
a = mx.random.uniform(shape=(1024, 1024))
|
|
b = mx.random.uniform(shape=(1024, 1024))
|
|
mx.eval(a, b)
|
|
time_fn(mx.matmul, a, b)
|
|
|
|
|
|
def time_negative():
|
|
a = mx.random.uniform(shape=(10000, 1000))
|
|
mx.eval(a)
|
|
|
|
def negative(a):
|
|
return -a
|
|
|
|
mx.eval(a)
|
|
|
|
time_fn(negative, a)
|
|
|
|
|
|
def time_exp():
|
|
a = mx.random.uniform(shape=(1000, 100))
|
|
mx.eval(a)
|
|
time_fn(mx.exp, a)
|
|
|
|
|
|
def time_logsumexp():
|
|
a = mx.random.uniform(shape=(64, 10, 10000))
|
|
mx.eval(a)
|
|
time_fn(mx.logsumexp, a, axis=-1)
|
|
|
|
|
|
def time_take():
|
|
a = mx.random.uniform(shape=(10000, 500))
|
|
ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
|
|
ids = [mx.reshape(idx, (-1,)) for idx in ids]
|
|
mx.eval(ids)
|
|
|
|
def random_take():
|
|
return [mx.take(a, idx, 0) for idx in ids]
|
|
|
|
time_fn(random_take)
|
|
|
|
|
|
def time_reshape_transposed():
|
|
x = mx.random.uniform(shape=(256, 256, 128))
|
|
mx.eval(x)
|
|
|
|
def reshape_transposed():
|
|
return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,))
|
|
|
|
time_fn(reshape_transposed)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser("MLX benchmarks.")
|
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
|
args = parser.parse_args()
|
|
if args.gpu:
|
|
mx.set_default_device(mx.gpu)
|
|
else:
|
|
mx.set_default_device(mx.cpu)
|
|
|
|
time_add()
|
|
time_matmul()
|
|
time_exp()
|
|
time_negative()
|
|
time_logsumexp()
|
|
time_take()
|
|
time_reshape_transposed()
|