mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
angelos's commit files
This commit is contained in:
60
benchmarks/python/batch_matmul_bench.py
Normal file
60
benchmarks/python/batch_matmul_bench.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
|
||||
from time_utils import time_fn
|
||||
|
||||
B = 8
|
||||
T = 1024
|
||||
D = 512
|
||||
|
||||
|
||||
def time_batch_matmul():
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform(shape=(B, T, D))
|
||||
b = mx.random.uniform(shape=(D, D))
|
||||
c = mx.random.uniform(shape=(B, T, D))
|
||||
mx.eval(a, b, c)
|
||||
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
def batch_vjp_first():
|
||||
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
|
||||
|
||||
time_fn(batch_vjp_first)
|
||||
|
||||
def batch_vjp_second():
|
||||
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
|
||||
|
||||
time_fn(batch_vjp_second)
|
||||
|
||||
|
||||
def time_unbatch_matmul(key):
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform(shape=(B * T, D))
|
||||
b = mx.random.uniform(shape=(D, D))
|
||||
c = mx.random.uniform(shape=(B * T, D))
|
||||
mx.eval(a, b, c)
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
def unbatch_vjp_first():
|
||||
return mx.matmul(c, mx.transpose(b))
|
||||
|
||||
time_fn(unbatch_vjp_first)
|
||||
|
||||
def unbatch_vjp_second():
|
||||
return mx.matmul(mx.transpose(a), c)
|
||||
|
||||
time_fn(unbatch_vjp_second)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_batch_matmul()
|
||||
time_unbatch_matmul()
|
20
benchmarks/python/time_utils.py
Normal file
20
benchmarks/python/time_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def time_fn(fn, *args, **kwargs):
|
||||
print(f"Timing {fn.__name__} ...", end=" ")
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
mx.eval(fn(*args, **kwargs))
|
||||
|
||||
num_iters = 100
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
x = mx.eval(fn(*args, **kwargs))
|
||||
toc = time.perf_counter()
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
print(f"{msec:.5f} msec")
|
Reference in New Issue
Block a user