2024-02-11 00:49:51 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
2023-12-01 03:12:53 +08:00
|
|
|
|
2023-11-30 02:42:59 +08:00
|
|
|
import time
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
|
|
|
|
|
|
|
|
def time_fn(fn, *args, **kwargs):
    """Time ``fn(*args, **kwargs)`` and print the mean latency per call.

    The result of every call is passed to ``mx.eval`` so the work produced by
    ``fn`` is evaluated inside the measured region rather than deferred.

    Args:
        fn: Callable to benchmark. Its return value must be acceptable to
            ``mx.eval``.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``. The optional ``msg``
            key is consumed here (not forwarded) and used as the label in the
            printed output.

    Returns:
        None. The timing is reported on stdout as ``Timing <label> ... <t> msec``.
    """
    msg = kwargs.pop("msg", None)
    # Fall back to the callable's name; objects such as functools.partial have
    # no __name__, so use their repr instead of raising AttributeError.
    label = msg if msg else getattr(fn, "__name__", repr(fn))
    print(f"Timing {label} ...", end=" ")

    # warmup: run a few untimed iterations so one-off setup costs are not
    # attributed to the measured loop.
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    # Mean wall time per iteration, converted from seconds to milliseconds.
    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
|
2024-02-11 00:49:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
def measure_runtime(fn, **kwargs):
    """Return the mean runtime of ``fn(**kwargs)`` in milliseconds.

    ``fn`` is called 5 times untimed as a warmup, then 100 times inside the
    timed region. The return value of ``fn`` is discarded.

    Args:
        fn: Callable to benchmark; invoked with keyword arguments only.
        **kwargs: Keyword arguments forwarded unchanged to every call of ``fn``.

    Returns:
        float: Average elapsed time per call, in milliseconds.
    """
    # NOTE(review): nothing here forces evaluation of fn's result, so fn is
    # presumably expected to evaluate/synchronize its own work - confirm
    # against callers.

    # Warmup
    for _ in range(5):
        fn(**kwargs)

    iters = 100
    # perf_counter is monotonic and high resolution, unlike time.time(), which
    # can jump if the system clock is adjusted mid-benchmark.
    tic = time.perf_counter()
    for _ in range(iters):
        fn(**kwargs)
    elapsed = time.perf_counter() - tic

    return elapsed * 1000 / iters
|