mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
awni's commit files
This commit is contained in:
106
benchmarks/python/single_ops.py
Normal file
106
benchmarks/python/single_ops.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def time_add():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
b = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
mx.eval(a, b)
|
||||
time_fn(mx.add, a, b)
|
||||
|
||||
aT = mx.transpose(a, [0, 2, 1])
|
||||
mx.eval(aT)
|
||||
|
||||
def transpose_add(a, b):
|
||||
return mx.add(a, b)
|
||||
|
||||
time_fn(transpose_add, aT, b)
|
||||
|
||||
b = mx.random.uniform(shape=(1024,))
|
||||
mx.eval(b)
|
||||
|
||||
def slice_add(a, b):
|
||||
return mx.add(a, b)
|
||||
|
||||
time_fn(slice_add, a, b)
|
||||
|
||||
b = mx.reshape(b, (1, 1024, 1))
|
||||
mx.eval(b)
|
||||
|
||||
def mid_slice_add(a, b):
|
||||
return mx.add(a, b)
|
||||
|
||||
time_fn(mid_slice_add, a, b)
|
||||
|
||||
|
||||
def time_matmul():
|
||||
a = mx.random.uniform(shape=(1024, 1024))
|
||||
b = mx.random.uniform(shape=(1024, 1024))
|
||||
mx.eval(a, b)
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
|
||||
def negative(a):
|
||||
return -a
|
||||
|
||||
mx.eval(a)
|
||||
|
||||
time_fn(negative, a)
|
||||
|
||||
|
||||
def time_exp():
|
||||
a = mx.random.uniform(shape=(1000, 100))
|
||||
mx.eval(a)
|
||||
time_fn(mx.exp, a)
|
||||
|
||||
|
||||
def time_logsumexp():
|
||||
a = mx.random.uniform(shape=(64, 10, 10000))
|
||||
mx.eval(a)
|
||||
time_fn(mx.logsumexp, a, axis=-1)
|
||||
|
||||
|
||||
def time_take():
|
||||
a = mx.random.uniform(shape=(10000, 500))
|
||||
ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
|
||||
ids = [mx.reshape(idx, (-1,)) for idx in ids]
|
||||
mx.eval(ids)
|
||||
|
||||
def random_take():
|
||||
return [mx.take(a, idx, 0) for idx in ids]
|
||||
|
||||
time_fn(random_take)
|
||||
|
||||
|
||||
def time_reshape_transposed():
|
||||
x = mx.random.uniform(shape=(256, 256, 128))
|
||||
mx.eval(x)
|
||||
|
||||
def reshape_transposed():
|
||||
return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,))
|
||||
|
||||
time_fn(reshape_transposed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_exp()
|
||||
time_negative()
|
||||
time_logsumexp()
|
||||
time_take()
|
||||
time_reshape_transposed()
|
||||
Reference in New Issue
Block a user