mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			133 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import argparse
 | |
| 
 | |
| import mlx.core as mx
 | |
| from time_utils import time_fn
 | |
| 
 | |
| 
 | |
def time_add():
    """Benchmark elementwise ``mx.add`` under several operand layouts.

    Cases, timed in this order:
      1. two (32, 1024, 1024) random arrays;
      2. the first operand with its last two axes swapped, plus the second;
      3. a (1024,) vector broadcast against the 3-D operand;
      4. the same vector reshaped to (1, 1024, 1), broadcast along axis 1.

    The nested wrappers all just call ``mx.add``. They exist so that each case
    has its own function name, which is presumably what ``time_fn`` reports.
    Confirm this in time_utils.
    """
    shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=shape)
    rhs = mx.random.uniform(shape=shape)
    # MLX is lazy: materialize inputs so their generation is not timed.
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # Case 2: permute axes (0, 2, 1) on the left operand.
    lhs_t = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_t)

    def transpose_add(x, y):
        return mx.add(x, y)

    time_fn(transpose_add, lhs_t, rhs)

    # Case 3: a 1-D operand broadcast over the trailing axis.
    vec = mx.random.uniform(shape=(1024,))
    mx.eval(vec)

    def slice_add(x, y):
        return mx.add(x, y)

    time_fn(slice_add, lhs, vec)

    # Case 4: the same values laid out so they broadcast over the middle axis.
    mid = mx.reshape(vec, (1, 1024, 1))
    mx.eval(mid)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    time_fn(mid_slice_add, lhs, mid)
 | |
| 
 | |
| 
 | |
def time_matmul():
    """Benchmark ``mx.matmul`` on two random 1024x1024 matrices."""
    n = 1024
    # Generator is consumed left to right, so the two draws keep their order.
    lhs, rhs = (mx.random.uniform(shape=(n, n)) for _ in range(2))
    # Materialize operands up front (MLX is lazy) so generation isn't timed.
    mx.eval(lhs, rhs)
    time_fn(mx.matmul, lhs, rhs)
 | |
| 
 | |
| 
 | |
def time_maximum():
    """Benchmark elementwise ``mx.maximum`` of two (32, 1024, 1024) arrays."""
    shape = (32, 1024, 1024)
    lhs, rhs = (mx.random.uniform(shape=shape) for _ in range(2))
    # Materialize operands up front (MLX is lazy) so generation isn't timed.
    mx.eval(lhs, rhs)
    time_fn(mx.maximum, lhs, rhs)
 | |
| 
 | |
| 
 | |
def time_max():
    """Benchmark ``mx.max`` reducing axis 0 of a (32, 1024, 1024) array.

    The 1-D slice ``[1, 1, :]`` is overwritten with NaN before timing. This is
    presumably to exercise NaN handling in the reduction; confirm the intent.
    """
    data = mx.random.uniform(shape=(32, 1024, 1024))
    data[1, 1] = mx.nan
    mx.eval(data)
    reduce_axis = 0
    time_fn(mx.max, data, reduce_axis)
 | |
| 
 | |
| 
 | |
def time_min():
    """Benchmark ``mx.min`` reducing axis 0 of a (32, 1024, 1024) array.

    The 1-D slice ``[1, 1, :]`` is overwritten with NaN before timing. This is
    presumably to exercise NaN handling in the reduction; confirm the intent.
    """
    data = mx.random.uniform(shape=(32, 1024, 1024))
    data[1, 1] = mx.nan
    mx.eval(data)
    reduce_axis = 0
    time_fn(mx.min, data, reduce_axis)
 | |
| 
 | |
| 
 | |
def time_negative():
    """Benchmark unary negation (``-a``) on a (10000, 1000) random array."""
    a = mx.random.uniform(shape=(10000, 1000))
    # Materialize the input once, before timing (MLX is lazy). A second
    # eval on the same, already-evaluated array would add nothing.
    mx.eval(a)

    # Named wrapper so the benchmark has a readable function name; `time_fn`
    # presumably reports it (confirm in time_utils).
    def negative(a):
        return -a

    time_fn(negative, a)
 | |
| 
 | |
| 
 | |
def time_exp():
    """Benchmark elementwise ``mx.exp`` on a (1000, 100) random array."""
    values = mx.random.uniform(shape=(1000, 100))
    # Materialize the input up front (MLX is lazy) so generation isn't timed.
    mx.eval(values)
    time_fn(mx.exp, values)
 | |
| 
 | |
| 
 | |
def time_logsumexp():
    """Benchmark ``mx.logsumexp`` over the last axis of a (64, 10, 10000) array."""
    values = mx.random.uniform(shape=(64, 10, 10000))
    # Materialize the input up front (MLX is lazy) so generation isn't timed.
    mx.eval(values)
    # `axis` is forwarded to mx.logsumexp as a keyword argument via time_fn.
    time_fn(mx.logsumexp, values, axis=-1)
 | |
| 
 | |
| 
 | |
def time_take():
    """Benchmark 20 independent ``mx.take`` row gathers of 10 indices each.

    The source is a (10000, 500) random array. Indices are drawn with
    ``mx.random.randint(low=0, high=10000)``, so each gather along axis 0
    selects rows of the source.
    """
    a = mx.random.uniform(shape=(10000, 500))
    ids = mx.random.randint(low=0, high=10000, shape=(20, 10))
    # Iterating the (20, 10) index array yields one row per gather; each row
    # is flattened to a 1-D index vector.
    ids = [mx.reshape(idx, (-1,)) for idx in ids]
    # Evaluate the source array as well as the indices. Every other benchmark
    # materializes all of its inputs before timing. Without `a` here, its
    # random generation would be folded into the first call of random_take.
    mx.eval(a, ids)

    def random_take():
        return [mx.take(a, idx, 0) for idx in ids]

    time_fn(random_take)
 | |
| 
 | |
| 
 | |
def time_reshape_transposed():
    """Benchmark flattening a (256, 256, 128) array after swapping its first two axes.

    The transposed layout is presumably non-contiguous, so the reshape to 1-D
    may require a copy. Confirm this against MLX's reshape rules.
    """
    src = mx.random.uniform(shape=(256, 256, 128))
    mx.eval(src)

    def reshape_transposed():
        swapped = mx.transpose(src, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Pass the text as `description=`. The first positional argument of
    # ArgumentParser is `prog`, which would replace the program name in the
    # usage line ("usage: MLX benchmarks. [-h] ...") instead of describing
    # the script.
    parser = argparse.ArgumentParser(description="MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    # Every benchmark below runs on this default device.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    time_add()
    time_matmul()
    time_min()
    time_max()
    time_maximum()
    time_exp()
    time_negative()
    time_logsumexp()
    time_take()
    time_reshape_transposed()
 | 
