| 
									
										
										
										
											2024-02-10 08:49:51 -08:00
										 |  |  | # Copyright © 2023-2024 Apple Inc. | 
					
						
							| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  | import time | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import mlx.core as mx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def time_fn(fn, *args, **kwargs): | 
					
						
							| 
									
										
										
										
											2024-02-19 21:43:54 -08:00
										 |  |  |     msg = kwargs.pop("msg", None) | 
					
						
							|  |  |  |     if msg: | 
					
						
							|  |  |  |         print(f"Timing {msg} ...", end=" ") | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         print(f"Timing {fn.__name__} ...", end=" ") | 
					
						
							| 
									
										
										
										
											2023-11-29 10:42:59 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # warmup | 
					
						
							|  |  |  |     for _ in range(5): | 
					
						
							|  |  |  |         mx.eval(fn(*args, **kwargs)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     num_iters = 100 | 
					
						
							|  |  |  |     tic = time.perf_counter() | 
					
						
							|  |  |  |     for _ in range(num_iters): | 
					
						
							|  |  |  |         x = mx.eval(fn(*args, **kwargs)) | 
					
						
							|  |  |  |     toc = time.perf_counter() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     msec = 1e3 * (toc - tic) / num_iters | 
					
						
							|  |  |  |     print(f"{msec:.5f} msec") | 
					
						
							| 
									
										
										
										
											2024-02-10 08:49:51 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def measure_runtime(fn, **kwargs): | 
					
						
							|  |  |  |     # Warmup | 
					
						
							|  |  |  |     for _ in range(5): | 
					
						
							|  |  |  |         fn(**kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tic = time.time() | 
					
						
							| 
									
										
										
										
											2024-02-13 17:47:41 -08:00
										 |  |  |     iters = 100 | 
					
						
							| 
									
										
										
										
											2024-02-10 08:49:51 -08:00
										 |  |  |     for _ in range(iters): | 
					
						
							|  |  |  |         fn(**kwargs) | 
					
						
							|  |  |  |     return (time.time() - tic) * 1000 / iters |