"""Benchmark MLX's Hadamard transform against a plain elementwise copy.

For each input width N the script measures the runtime of
``mx.hadamard_transform`` (and of an ``x + 1.0`` baseline), converts it to an
effective memory bandwidth, prints it, and writes a scatter plot to
``bench_<dtype>.png``.
"""

import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

# Select a non-interactive backend before pyplot is imported so the script
# can run headless; it only ever writes PNG files.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402


def had(x):
    """Apply ``mx.hadamard_transform`` to *x* and force its evaluation."""
    y = mx.hadamard_transform(x)
    mx.eval(y)


def copy(x):
    """Bandwidth baseline: an elementwise op, forced to evaluate."""
    y = x + 1.0
    mx.eval(y)


def _bandwidth_gb(num_elements, itemsize, runtime_ms):
    """Return effective GB/s, counting one read and one write per element."""
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    bytes_moved = num_elements * itemsize * 2
    return bytes_moved / runtime_ms * ms_per_s / bytes_per_gb


def _plot(outputs, dtype_name):
    """Scatter-plot ``{series: {N: GB/s}}`` and save it as ``bench_<dtype>.png``."""
    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        # Materialize the dict views; matplotlib expects sequence-like data.
        plt.scatter(
            list(output.keys()),
            list(output.values()),
            color=colors[key],
            label=key,
        )
    plt.title(f"MLX Hadamard Benchmark -- {dtype_name}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype_name}.png")
    plt.clf()


def run(dtype, system_size=2**26):
    """Benchmark ``had`` and ``copy`` over a range of widths and plot the results.

    Args:
        dtype: numpy dtype (or anything ``np.dtype`` accepts) for the inputs.
        system_size: approximate total number of elements per test array; the
            array shape is ``(system_size // n, n)`` for each width ``n``.
    """
    dtype_name = np.dtype(dtype).name
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn is copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                # n grows with k, so once past the limit every later k is too.
                if n > 2**15:
                    break
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                # Use the real element count: (system_size // n) * n is smaller
                # than system_size whenever n does not divide it (m = 12, 20, 28).
                bandwidth_gb = _bandwidth_gb(x_np.size, x_np.itemsize, runtime_ms)
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    _plot(outputs, dtype_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)