mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	
		
			
	
	
		
			71 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			71 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import argparse
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import matplotlib
							 | 
						||
| 
								 | 
							
								import mlx.core as mx
							 | 
						||
| 
								 | 
							
								import numpy as np
							 | 
						||
| 
								 | 
							
								from time_utils import measure_runtime
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								matplotlib.use("Agg")
							 | 
						||
| 
								 | 
							
								import matplotlib.pyplot as plt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def had(x):
								    """Run ``mx.hadamard_transform`` on ``x`` and evaluate the result.

								    The output is handed straight to ``mx.eval`` and then discarded;
								    nothing is returned.
								    """
								    mx.eval(mx.hadamard_transform(x))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def copy(x):
								    """Add ``1.0`` to ``x`` and evaluate the result.

								    ``run`` times this function as the "copy" comparison series. The
								    output is handed to ``mx.eval`` and then discarded; nothing is
								    returned.
								    """
								    mx.eval(x + 1.0)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def run(dtype):
								    """Benchmark ``had`` against ``copy`` and save a bandwidth scatter plot.

								    For every multiplier ``m`` in ``(1, 12, 20, 28)`` and exponent ``k`` in
								    ``range(7, 14)``, a random normal matrix with rows of length
								    ``n = m * 2**k`` (sizes above ``2**15`` are skipped) and
								    ``2**26 // n`` rows is timed with ``measure_runtime``. Each result is
								    printed as ``n bandwidth`` and collected into one of three series:
								    ``"copy"``, ``"had_2^k"`` (``m == 1``) or ``"had_m*2^k"``.

								    The figure is written to ``bench_<dtype.__name__>.png`` and the current
								    figure is cleared afterwards.

								    Args:
								        dtype: NumPy scalar type the random input is cast to. Its
								            ``__name__`` is used in the plot title and the output file
								            name, so it must be a class such as ``np.float32`` rather
								            than a ``np.dtype`` instance.
								    """
								    system_size = 2**26
								    bytes_per_gb = 1e9
								    ms_per_s = 1e3
								    # Each element is counted twice: once read, once written.
								    bytes_per_element = np.dtype(dtype).itemsize * 2

								    outputs = {}
								    for test_fn in (had, copy):
								        for m in (1, 12, 20, 28):
								            if test_fn is copy:
								                key = "copy"
								            elif m == 1:
								                key = "had_2^k"
								            else:
								                key = "had_m*2^k"
								            series = outputs.setdefault(key, {})
								            for k in range(7, 14):
								                n = m * 2**k
								                if n > 2**15:
								                    continue
								                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
								                x = mx.array(x_np)
								                # NOTE(review): assumes measure_runtime returns milliseconds,
								                # as the variable name suggests -- confirm in time_utils.
								                runtime_ms = measure_runtime(test_fn, x=x)
								                # Count the elements actually allocated: when n does not
								                # divide system_size the floor division above yields
								                # slightly fewer than system_size elements.
								                moved_bytes = x_np.size * bytes_per_element
								                bandwidth_gb = moved_bytes / runtime_ms * ms_per_s / bytes_per_gb
								                print(n, bandwidth_gb)
								                series[n] = bandwidth_gb

								    colors = {
								        "copy": "black",
								        "had_2^k": "steelblue",
								        "had_m*2^k": "skyblue",
								    }
								    for key, output in outputs.items():
								        # Hand scatter plain lists instead of dict views.
								        sizes = list(output.keys())
								        bandwidths = list(output.values())
								        plt.scatter(sizes, bandwidths, color=colors[key], label=key)
								    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
								    plt.xlabel("N")
								    plt.ylabel("Bandwidth (GB/s)")
								    plt.legend()
								    plt.savefig(f"bench_{dtype.__name__}.png")
								    # Clear the figure so a later call starts from an empty plot.
								    plt.clf()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if __name__ == "__main__":
								    # Command line: pass --fp16 to benchmark float16; float32 otherwise.
								    cli = argparse.ArgumentParser()
								    cli.add_argument("--fp16", action="store_true")
								    args = cli.parse_args()
								    if args.fp16:
								        dtype = np.float16
								    else:
								        dtype = np.float32
								    run(dtype)
							 |