# Mirror of https://github.com/ml-explore/mlx.git
# Synced 2025-06-25 01:41:17 +08:00
# 71 lines, 1.9 KiB, Python
import argparse
|
||
|
|
||
|
import matplotlib
|
||
|
import mlx.core as mx
|
||
|
import numpy as np
|
||
|
from time_utils import measure_runtime
|
||
|
|
||
|
matplotlib.use("Agg")
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
|
||
|
def had(x):
    """Apply ``mx.hadamard_transform`` to ``x`` and pass the result to ``mx.eval``.

    Nothing is returned; ``run`` hands this function to ``measure_runtime``
    as the operation being timed.
    """
    mx.eval(mx.hadamard_transform(x))
|
||
|
|
||
|
|
||
|
def copy(x):
    """Add ``1.0`` to ``x`` and pass the result to ``mx.eval``.

    Nothing is returned; ``run`` times this next to ``had`` and reports it
    under the "copy" label.
    """
    mx.eval(x + 1.0)
|
||
|
|
||
|
|
||
|
def run(dtype):
    """Benchmark ``had`` against ``copy`` and save a bandwidth scatter plot.

    For each test function and each size ``n = m * 2**k`` (``m`` in
    1, 12, 20, 28 and ``k`` in 7..13, skipping ``n > 2**15``), a random array
    of shape ``(2**26 // n, n)`` is timed with ``measure_runtime``. Each
    result is printed as ``n bandwidth`` and the collected points are written
    to ``bench_<dtype name>.png``.

    Args:
        dtype: NumPy scalar type such as ``np.float32``. The arrays are cast
            to it, and its ``__name__`` appears in the plot title and in the
            output file name.
    """
    system_size = 2**26
    max_n = 2**15
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    # Two transfers per element -- presumably one read plus one write.
    bytes_per_element = np.dtype(dtype).itemsize * 2

    outputs = {}
    for test_fn in (had, copy):
        for m in (1, 12, 20, 28):
            if test_fn is copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            series = outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > max_n:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                # NOTE(review): measure_runtime is assumed to return
                # milliseconds, as the variable name in the original implies.
                runtime_ms = measure_runtime(test_fn, x=mx.array(x_np))
                # Count the elements actually allocated: when n does not
                # divide system_size the floor division above drops a partial
                # row, so the nominal system_size would overstate the traffic.
                num_bytes = x_np.size * bytes_per_element
                bandwidth_gb = num_bytes / runtime_ms * ms_per_s / bytes_per_gb
                print(n, bandwidth_gb)
                series[n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, series in outputs.items():
        # Materialise the dict views so matplotlib receives plain sequences.
        plt.scatter(list(series), list(series.values()), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    # Clear the figure so a later call starts from an empty canvas.
    plt.clf()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Benchmark in float32 unless --fp16 is given on the command line.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fp16", action="store_true")
    use_fp16 = cli.parse_args().fp16
    run(np.float16 if use_fp16 else np.float32)
|