| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  | # Copyright © 2023-2024 Apple Inc. | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  | import argparse | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | import time | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import mlx.core as mx | 
					
						
							| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  | from mlx_whisper import audio, decoding, load_models, transcribe | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  | audio_file = "mlx_whisper/assets/ls_test.flac" | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  | def parse_arguments(): | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser(description="Benchmark script.") | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--mlx-dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="mlx_models", | 
					
						
							|  |  |  |         help="The folder of MLX models", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--all", | 
					
						
							|  |  |  |         action="store_true", | 
					
						
							|  |  |  |         help="Use all available models, i.e. tiny,small,medium,large-v3", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "-m", | 
					
						
							|  |  |  |         "--models", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         help="Specify models as a comma-separated list (e.g., tiny,small,medium)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | def timer(fn, *args): | 
					
						
							|  |  |  |     for _ in range(5): | 
					
						
							|  |  |  |         fn(*args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     num_its = 10 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tic = time.perf_counter() | 
					
						
							|  |  |  |     for _ in range(num_its): | 
					
						
							|  |  |  |         fn(*args) | 
					
						
							|  |  |  |     toc = time.perf_counter() | 
					
						
							|  |  |  |     return (toc - tic) / num_its | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  | def feats(n_mels: int = 80): | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |     data = audio.load_audio(audio_file) | 
					
						
							|  |  |  |     data = audio.pad_or_trim(data) | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  |     mels = audio.log_mel_spectrogram(data, n_mels) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  |     mx.eval(mels) | 
					
						
							|  |  |  |     return mels | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def model_forward(model, mels, tokens): | 
					
						
							|  |  |  |     logits = model(mels, tokens) | 
					
						
							|  |  |  |     mx.eval(logits) | 
					
						
							|  |  |  |     return logits | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def decode(model, mels): | 
					
						
							|  |  |  |     return decoding.decode(model, mels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  | def everything(model_path): | 
					
						
							| 
									
										
										
										
											2024-01-12 22:07:15 +01:00
										 |  |  |     return transcribe(audio_file, path_or_hf_repo=model_path) | 
					
						
							| 
									
										
										
										
											2023-11-29 08:17:26 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  |     args = parse_arguments() | 
					
						
							|  |  |  |     if args.all: | 
					
						
							|  |  |  |         models = ["tiny", "small", "medium", "large-v3"] | 
					
						
							|  |  |  |     elif args.models: | 
					
						
							|  |  |  |         models = args.models.split(",") | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         models = ["tiny"] | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  |     print("Selected models:", models) | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-12 07:37:35 -08:00
										 |  |  |     feat_time = timer(feats) | 
					
						
							|  |  |  |     print(f"\nFeature time {feat_time:.3f}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  |     for model_name in models: | 
					
						
							| 
									
										
										
										
											2024-05-02 00:00:02 +08:00
										 |  |  |         model_path = f"mlx-community/whisper-{model_name}-mlx" | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  |         print(f"\nModel: {model_name.upper()}") | 
					
						
							|  |  |  |         tokens = mx.array( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 50364, | 
					
						
							|  |  |  |                 1396, | 
					
						
							|  |  |  |                 264, | 
					
						
							|  |  |  |                 665, | 
					
						
							|  |  |  |                 5133, | 
					
						
							|  |  |  |                 23109, | 
					
						
							|  |  |  |                 25462, | 
					
						
							|  |  |  |                 264, | 
					
						
							|  |  |  |                 6582, | 
					
						
							|  |  |  |                 293, | 
					
						
							|  |  |  |                 750, | 
					
						
							|  |  |  |                 632, | 
					
						
							|  |  |  |                 42841, | 
					
						
							|  |  |  |                 292, | 
					
						
							|  |  |  |                 370, | 
					
						
							|  |  |  |                 938, | 
					
						
							|  |  |  |                 294, | 
					
						
							|  |  |  |                 4054, | 
					
						
							|  |  |  |                 293, | 
					
						
							|  |  |  |                 12653, | 
					
						
							|  |  |  |                 356, | 
					
						
							|  |  |  |                 50620, | 
					
						
							|  |  |  |                 50620, | 
					
						
							|  |  |  |                 23563, | 
					
						
							|  |  |  |                 322, | 
					
						
							|  |  |  |                 3312, | 
					
						
							|  |  |  |                 13, | 
					
						
							|  |  |  |                 50680, | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |             mx.int32, | 
					
						
							|  |  |  |         )[None] | 
					
						
							| 
									
										
										
										
											2024-01-12 22:07:15 +01:00
										 |  |  |         model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16) | 
					
						
							| 
									
										
										
										
											2023-12-28 22:50:35 +01:00
										 |  |  |         mels = feats(model.dims.n_mels)[None].astype(mx.float16) | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  |         model_forward_time = timer(model_forward, model, mels, tokens) | 
					
						
							|  |  |  |         print(f"Model forward time {model_forward_time:.3f}") | 
					
						
							|  |  |  |         decode_time = timer(decode, model, mels) | 
					
						
							|  |  |  |         print(f"Decode time {decode_time:.3f}") | 
					
						
							| 
									
										
										
										
											2023-12-29 19:22:15 +01:00
										 |  |  |         everything_time = timer(everything, model_path) | 
					
						
							| 
									
										
										
										
											2023-12-07 00:07:42 +05:30
										 |  |  |         print(f"Everything time {everything_time:.3f}") | 
					
						
							|  |  |  |         print(f"\n{'-----' * 10}\n") |