mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	
		
			
	
	
		
			124 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			124 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import argparse
							 | 
						||
| 
								 | 
							
								import math
							 | 
						||
| 
								 | 
							
								import os
							 | 
						||
| 
								 | 
							
								import subprocess
							 | 
						||
| 
								 | 
							
								import time
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import mlx.core as mx
							 | 
						||
| 
								 | 
							
								import numpy as np
							 | 
						||
| 
								 | 
							
								import torch
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
							 | 
						||
| 
								 | 
							
								device_name = device_name.decode("utf-8").strip("\n")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								N_warmup = 10
							 | 
						||
| 
								 | 
							
								N_iter_bench = 100
							 | 
						||
| 
								 | 
							
								N_iter_func = 5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def bench(f, a, b):
							 | 
						||
| 
								 | 
							
								    for i in range(N_warmup):
							 | 
						||
| 
								 | 
							
								        f(a, b)
							 | 
						||
| 
								 | 
							
								    torch.mps.synchronize()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    s = time.perf_counter_ns()
							 | 
						||
| 
								 | 
							
								    for i in range(N_iter_bench):
							 | 
						||
| 
								 | 
							
								        f(a, b)
							 | 
						||
| 
								 | 
							
								    e = time.perf_counter_ns()
							 | 
						||
| 
								 | 
							
								    return (e - s) * 1e-9
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def make_mx_conv_1D(strides=1, padding=0, groups=1):
							 | 
						||
| 
								 | 
							
								    def mx_conv_1D(a, b):
							 | 
						||
| 
								 | 
							
								        ys = []
							 | 
						||
| 
								 | 
							
								        for _ in range(N_iter_func):
							 | 
						||
| 
								 | 
							
								            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
							 | 
						||
| 
								 | 
							
								            ys.append(y)
							 | 
						||
| 
								 | 
							
								        mx.eval(ys)
							 | 
						||
| 
								 | 
							
								        return ys
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return mx_conv_1D
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def make_pt_conv_1D(strides=1, padding=0, groups=1):
							 | 
						||
| 
								 | 
							
								    @torch.no_grad()
							 | 
						||
| 
								 | 
							
								    def pt_conv_1D(a, b):
							 | 
						||
| 
								 | 
							
								        ys = []
							 | 
						||
| 
								 | 
							
								        for _ in range(N_iter_func):
							 | 
						||
| 
								 | 
							
								            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
							 | 
						||
| 
								 | 
							
								            ys.append(y)
							 | 
						||
| 
								 | 
							
								        torch.mps.synchronize()
							 | 
						||
| 
								 | 
							
								        return ys
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return pt_conv_1D
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
							 | 
						||
| 
								 | 
							
								    scale = 1.0 / math.sqrt(wH * C)
							 | 
						||
| 
								 | 
							
								    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
							 | 
						||
| 
								 | 
							
								    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    a_mx = mx.array(a_np)
							 | 
						||
| 
								 | 
							
								    b_mx = mx.array(b_np)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
							 | 
						||
| 
								 | 
							
								    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    torch.mps.synchronize()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    f_mx = make_mx_conv_1D(strides, padding, groups)
							 | 
						||
| 
								 | 
							
								    f_pt = make_pt_conv_1D(strides, padding, groups)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    time_torch = bench(f_pt, a_pt, b_pt)
							 | 
						||
| 
								 | 
							
								    time_mlx = bench(f_mx, a_mx, b_mx)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
							 | 
						||
| 
								 | 
							
								    out_pt = torch.conv1d(
							 | 
						||
| 
								 | 
							
								        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    out_pt = torch.permute(out_pt, (0, 2, 1))
							 | 
						||
| 
								 | 
							
								    out_pt = out_pt.numpy(force=True)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    atol = 2e-5 if np_dtype == np.float32 else 1e-4
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if not np.allclose(out_pt, out_mx, atol=atol):
							 | 
						||
| 
								 | 
							
								        print(
							 | 
						||
| 
								 | 
							
								            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return time_mlx, time_torch
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								if __name__ == "__main__":
							 | 
						||
| 
								 | 
							
								    parser = argparse.ArgumentParser(description="Run conv benchmarks")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    dtypes = ("float32",)
							 | 
						||
| 
								 | 
							
								    shapes = (
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 1),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 2),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 4),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 8),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 8),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 16),
							 | 
						||
| 
								 | 
							
								        (4, 32, 32, 5, 32, 1, 2, 32),
							 | 
						||
| 
								 | 
							
								        (4, 32, 256, 5, 512, 1, 2, 2),
							 | 
						||
| 
								 | 
							
								        (4, 32, 256, 5, 512, 1, 2, 128),
							 | 
						||
| 
								 | 
							
								        (4, 32, 256, 5, 512, 1, 2, 256),
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    for dtype in dtypes:
							 | 
						||
| 
								 | 
							
								        print("(N,  iH,  C),  (O,  wH,  C),   dtype,  stride, pads, groups, diff%")
							 | 
						||
| 
								 | 
							
								        for N, iH, C, wH, O, strides, padding, groups in shapes:
							 | 
						||
| 
								 | 
							
								            np_dtype = getattr(np, dtype)
							 | 
						||
| 
								 | 
							
								            time_mlx, time_torch = bench_shape(
							 | 
						||
| 
								 | 
							
								                N, iH, C, wH, O, strides, padding, np_dtype, groups
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								            diff = time_torch / time_mlx - 1.0
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            print(
							 | 
						||
| 
								 | 
							
								                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if time_mlx >= 2.0 * time_torch:
							 | 
						||
| 
								 | 
							
								                print("ATTENTION ^^^^^^^")
							 |