mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			97 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			97 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2023-2024 Apple Inc.
 | 
						|
 | 
						|
import argparse
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import torch
 | 
						|
from time_utils import measure_runtime
 | 
						|
 | 
						|
 | 
						|
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    """Time an MLX scatter (``dst[idx] = x``) and print the runtime.

    Args:
        dst_shape: Shape of the float32 destination array written into.
        x_shape: Shape of the float32 values assigned at the indexed positions.
        idx_shapes: One shape per index array; the i-th index array indexes
            axis ``i`` of ``dst``.
    """

    def scatter(dst, x, idx):
        dst[tuple(idx)] = x
        # MLX evaluates lazily; force the scatter so it is actually timed.
        mx.eval(dst)

    # randint's upper bound is exclusive, so pass the full axis length so that
    # every position (including the last) can be targeted, and bound each
    # index array by the axis it indexes rather than always by axis 0.
    idx = [
        mx.random.randint(0, dst_shape[axis], shape)
        for axis, shape in enumerate(idx_shapes)
    ]
    x = mx.random.normal(x_shape, dtype=mx.float32)
    dst = mx.random.normal(dst_shape, dtype=mx.float32)
    # Materialize the random inputs up front so their generation is not
    # attributed to the first timed scatter.
    mx.eval(x, dst, *idx)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")
 | 
						|
 | 
						|
 | 
						|
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    """Time a PyTorch scatter (``dst[idx] = x``) on ``device`` and print it.

    Args:
        dst_shape: Shape of the float32 destination tensor written into.
        x_shape: Shape of the float32 values assigned at the indexed positions.
        idx_shapes: One shape per index tensor; the i-th index tensor indexes
            axis ``i`` of ``dst``.
        device: ``torch.device`` on which all tensors are created and the
            scatter runs.
    """

    def synchronize(device):
        # GPU backends execute asynchronously; wait for queued work so the
        # measured time covers the kernel itself. Checking ``device.type``
        # also matches indexed devices such as ``mps:0``.
        if device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "cuda":
            torch.cuda.synchronize()

    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        synchronize(device)

    # randint's upper bound is exclusive, so pass the full axis length so that
    # every position (including the last) can be targeted, and bound each
    # index tensor by the axis it indexes rather than always by axis 0.
    idx = [
        torch.randint(0, dst_shape[axis], shape, device=device)
        for axis, shape in enumerate(idx_shapes)
    ]
    # Allocate directly on the target device instead of building on the CPU
    # and copying over.
    x = torch.randn(x_shape, dtype=torch.float32, device=device)
    dst = torch.randn(dst_shape, dtype=torch.float32, device=device)
    # Make sure input generation has finished before timing starts.
    synchronize(device)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name), so the help text must be passed as ``description``. This script
    # benchmarks scatter, not gather.
    parser = argparse.ArgumentParser(description="Scatter benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    # --cpu pins both frameworks to the CPU; otherwise MLX keeps its default
    # device and PyTorch runs on MPS.
    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    # Each case is (dst_shape, x_shape, idx_shapes). Keeping the shapes of a
    # case together avoids the misalignment risk of three parallel lists.
    cases = [
        ((10, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((1_000_000, 64), (100_000, 64), [(100_000,)]),
        ((100_000,), (1_000_000,), [(1_000_000,)]),
        ((200_000,), (20_000_000,), [(20_000_000,)]),
        ((20_000_000,), (20_000_000,), [(20_000_000,)]),
        ((10_000, 64), (1_000_000, 64), [(1_000_000,)]),
        ((100, 64), (10_000_000, 64), [(10_000_000,)]),
        ((100, 10_000, 64), (1_000, 10_000, 64), [(1_000,)]),
        ((10, 100, 100, 21), (10_000, 100, 100, 21), [(10_000,)]),
        ((1_000, 1_000, 10), (1_000, 10), [(1_000,), (1_000,)]),
    ]

    for dst_shape, x_shape, idx_shape in cases:
        print("=" * 20)
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)