| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | # Copyright © 2023 Apple Inc. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | import argparse | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-08 20:31:47 +01:00
										 |  |  | import mlx.core as mx | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | from time_utils import time_fn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_add(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     b = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     mx.eval(a, b) | 
					
						
							|  |  |  |     time_fn(mx.add, a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     aT = mx.transpose(a, [0, 2, 1]) | 
					
						
							|  |  |  |     mx.eval(aT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def transpose_add(a, b): | 
					
						
							|  |  |  |         return mx.add(a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(transpose_add, aT, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     b = mx.random.uniform(shape=(1024,)) | 
					
						
							|  |  |  |     mx.eval(b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def slice_add(a, b): | 
					
						
							|  |  |  |         return mx.add(a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(slice_add, a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     b = mx.reshape(b, (1, 1024, 1)) | 
					
						
							|  |  |  |     mx.eval(b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def mid_slice_add(a, b): | 
					
						
							|  |  |  |         return mx.add(a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(mid_slice_add, a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_matmul(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(1024, 1024)) | 
					
						
							|  |  |  |     b = mx.random.uniform(shape=(1024, 1024)) | 
					
						
							|  |  |  |     mx.eval(a, b) | 
					
						
							|  |  |  |     time_fn(mx.matmul, a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-29 11:19:38 -08:00
										 |  |  | def time_maximum(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     b = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     mx.eval(a, b) | 
					
						
							|  |  |  |     time_fn(mx.maximum, a, b) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-09 11:26:27 -07:00
										 |  |  | def time_max(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     a[1, 1] = mx.nan | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  |     time_fn(mx.max, a, 0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-10 06:20:43 -07:00
										 |  |  | def time_min(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(32, 1024, 1024)) | 
					
						
							|  |  |  |     a[1, 1] = mx.nan | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  |     time_fn(mx.min, a, 0) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  | def time_negative(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(10000, 1000)) | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def negative(a): | 
					
						
							|  |  |  |         return -a | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(negative, a) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_exp(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(1000, 100)) | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  |     time_fn(mx.exp, a) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_logsumexp(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(64, 10, 10000)) | 
					
						
							|  |  |  |     mx.eval(a) | 
					
						
							|  |  |  |     time_fn(mx.logsumexp, a, axis=-1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_take(): | 
					
						
							|  |  |  |     a = mx.random.uniform(shape=(10000, 500)) | 
					
						
							|  |  |  |     ids = mx.random.randint(low=0, high=10000, shape=(20, 10)) | 
					
						
							|  |  |  |     ids = [mx.reshape(idx, (-1,)) for idx in ids] | 
					
						
							|  |  |  |     mx.eval(ids) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def random_take(): | 
					
						
							|  |  |  |         return [mx.take(a, idx, 0) for idx in ids] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(random_take) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_reshape_transposed(): | 
					
						
							|  |  |  |     x = mx.random.uniform(shape=(256, 256, 128)) | 
					
						
							|  |  |  |     mx.eval(x) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reshape_transposed(): | 
					
						
							|  |  |  |         return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_fn(reshape_transposed) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser("MLX benchmarks.") | 
					
						
							|  |  |  |     parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |     if args.gpu: | 
					
						
							|  |  |  |         mx.set_default_device(mx.gpu) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         mx.set_default_device(mx.cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     time_add() | 
					
						
							|  |  |  |     time_matmul() | 
					
						
							| 
									
										
										
										
											2025-07-10 06:20:43 -07:00
										 |  |  |     time_min() | 
					
						
							| 
									
										
										
										
											2025-07-09 11:26:27 -07:00
										 |  |  |     time_max() | 
					
						
							| 
									
										
										
										
											2024-01-29 11:19:38 -08:00
										 |  |  |     time_maximum() | 
					
						
							| 
									
										
										
										
											2023-11-29 10:30:41 -08:00
										 |  |  |     time_exp() | 
					
						
							|  |  |  |     time_negative() | 
					
						
							|  |  |  |     time_logsumexp() | 
					
						
							|  |  |  |     time_take() | 
					
						
							|  |  |  |     time_reshape_transposed() |