mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
63 lines
1.3 KiB
Python
63 lines
1.3 KiB
Python
# Copyright © 2023 Apple Inc.
|
|
|
|
import argparse
|
|
|
|
import mlx.core as mx
|
|
from time_utils import time_fn
|
|
|
|
B = 8
|
|
T = 1024
|
|
D = 512
|
|
|
|
|
|
def time_batch_matmul():
|
|
mx.random.seed(3)
|
|
a = mx.random.uniform(shape=(B, T, D))
|
|
b = mx.random.uniform(shape=(D, D))
|
|
c = mx.random.uniform(shape=(B, T, D))
|
|
mx.eval(a, b, c)
|
|
|
|
time_fn(mx.matmul, a, b)
|
|
|
|
def batch_vjp_first():
|
|
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
|
|
|
|
time_fn(batch_vjp_first)
|
|
|
|
def batch_vjp_second():
|
|
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
|
|
|
|
time_fn(batch_vjp_second)
|
|
|
|
|
|
def time_unbatch_matmul():
|
|
mx.random.seed(3)
|
|
a = mx.random.uniform(shape=(B * T, D))
|
|
b = mx.random.uniform(shape=(D, D))
|
|
c = mx.random.uniform(shape=(B * T, D))
|
|
mx.eval(a, b, c)
|
|
time_fn(mx.matmul, a, b)
|
|
|
|
def unbatch_vjp_first():
|
|
return mx.matmul(c, mx.transpose(b))
|
|
|
|
time_fn(unbatch_vjp_first)
|
|
|
|
def unbatch_vjp_second():
|
|
return mx.matmul(mx.transpose(a), c)
|
|
|
|
time_fn(unbatch_vjp_second)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser("MLX benchmarks.")
|
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
|
args = parser.parse_args()
|
|
if args.gpu:
|
|
mx.set_default_device(mx.gpu)
|
|
else:
|
|
mx.set_default_device(mx.cpu)
|
|
|
|
time_batch_matmul()
|
|
time_unbatch_matmul()
|