angelos's commit files

This commit is contained in:
Angelos Katharopoulos
2023-11-29 10:42:59 -08:00
parent 8ca7f9e8e9
commit d1f86272a2
56 changed files with 12350 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
import argparse
import mlx.core as mx
from time_utils import time_fn
B = 8
T = 1024
D = 512
def time_batch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B, T, D))
b = mx.random.uniform(shape=(D, D))
c = mx.random.uniform(shape=(B, T, D))
mx.eval(a, b, c)
time_fn(mx.matmul, a, b)
def batch_vjp_first():
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
time_fn(batch_vjp_first)
def batch_vjp_second():
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
time_fn(batch_vjp_second)
def time_unbatch_matmul(key):
mx.random.seed(3)
a = mx.random.uniform(shape=(B * T, D))
b = mx.random.uniform(shape=(D, D))
c = mx.random.uniform(shape=(B * T, D))
mx.eval(a, b, c)
time_fn(mx.matmul, a, b)
def unbatch_vjp_first():
return mx.matmul(c, mx.transpose(b))
time_fn(unbatch_vjp_first)
def unbatch_vjp_second():
return mx.matmul(mx.transpose(a), c)
time_fn(unbatch_vjp_second)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_batch_matmul()
time_unbatch_matmul()

View File

@@ -0,0 +1,20 @@
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")