mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
angelos's commit files
This commit is contained in:
38
benchmarks/cpp/time_utils.h
Normal file
38
benchmarks/cpp/time_utils.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
#define milliseconds(x) \
|
||||
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
|
||||
#define time_now() std::chrono::high_resolution_clock::now()
|
||||
|
||||
#define TIME(FUNC, ...) \
|
||||
std::cout << "Timing " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " \
|
||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
template <typename F, typename... Args>
|
||||
double time_fn(F fn, Args... args) {
|
||||
// warmup
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
eval(fn(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
int num_iters = 100;
|
||||
auto start = time_now();
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
eval(fn(std::forward<Args>(args)...));
|
||||
}
|
||||
auto end = time_now();
|
||||
return milliseconds(end - start) / static_cast<double>(num_iters);
|
||||
}
|
||||
60
benchmarks/python/batch_matmul_bench.py
Normal file
60
benchmarks/python/batch_matmul_bench.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import mlx.core as mx
|
||||
|
||||
from time_utils import time_fn
|
||||
|
||||
B = 8
|
||||
T = 1024
|
||||
D = 512
|
||||
|
||||
|
||||
def time_batch_matmul():
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform(shape=(B, T, D))
|
||||
b = mx.random.uniform(shape=(D, D))
|
||||
c = mx.random.uniform(shape=(B, T, D))
|
||||
mx.eval(a, b, c)
|
||||
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
def batch_vjp_first():
|
||||
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
|
||||
|
||||
time_fn(batch_vjp_first)
|
||||
|
||||
def batch_vjp_second():
|
||||
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
|
||||
|
||||
time_fn(batch_vjp_second)
|
||||
|
||||
|
||||
def time_unbatch_matmul(key):
|
||||
mx.random.seed(3)
|
||||
a = mx.random.uniform(shape=(B * T, D))
|
||||
b = mx.random.uniform(shape=(D, D))
|
||||
c = mx.random.uniform(shape=(B * T, D))
|
||||
mx.eval(a, b, c)
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
def unbatch_vjp_first():
|
||||
return mx.matmul(c, mx.transpose(b))
|
||||
|
||||
time_fn(unbatch_vjp_first)
|
||||
|
||||
def unbatch_vjp_second():
|
||||
return mx.matmul(mx.transpose(a), c)
|
||||
|
||||
time_fn(unbatch_vjp_second)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_batch_matmul()
|
||||
time_unbatch_matmul()
|
||||
20
benchmarks/python/time_utils.py
Normal file
20
benchmarks/python/time_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def time_fn(fn, *args, **kwargs):
|
||||
print(f"Timing {fn.__name__} ...", end=" ")
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
mx.eval(fn(*args, **kwargs))
|
||||
|
||||
num_iters = 100
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
x = mx.eval(fn(*args, **kwargs))
|
||||
toc = time.perf_counter()
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
print(f"{msec:.5f} msec")
|
||||
Reference in New Issue
Block a user