mlx/benchmarks/python/batch_matmul_bench.py

import argparse
import mlx.core as mx
from time_utils import time_fn
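
# Benchmark mx.matmul and its vector-Jacobian products (VJPs) for a batched
# (B, T, D) @ (D, D) product versus the equivalent flattened
# (B * T, D) @ (D, D) product. Pass --gpu to run on the Metal backend;
# the default is the CPU.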
B = 8  # batch size
T = 1024  # rows per batch element
D = 512  # feature dimension


def time_batch_matmul():
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B, T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B, T, D))
    mx.eval(a, b, c)  # materialize inputs so setup is excluded from timing

    # Forward pass: batched (B, T, D) @ (D, D).
    time_fn(mx.matmul, a, b)

    # VJP with respect to the first argument, using c as the cotangent.
    def batch_vjp_first():
        return mx.vjp(mx.matmul, [a, b], [c])[1][0]

    time_fn(batch_vjp_first)

    # VJP with respect to the second argument.
    def batch_vjp_second():
        return mx.vjp(mx.matmul, [a, b], [c])[1][1]

    time_fn(batch_vjp_second)


def time_unbatch_matmul():
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B * T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B * T, D))
    mx.eval(a, b, c)  # materialize inputs so setup is excluded from timing

    # Forward pass: the same problem flattened to (B * T, D) @ (D, D).
    time_fn(mx.matmul, a, b)

    # Hand-written VJPs for the flattened case.
    def unbatch_vjp_first():
        return mx.matmul(c, mx.transpose(b))

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        return mx.matmul(mx.transpose(a), c)

    time_fn(unbatch_vjp_second)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    if args.gpu:
        mx.set_default_device(mx.gpu)
    else:
        mx.set_default_device(mx.cpu)

    time_batch_matmul()
    time_unbatch_matmul()