diff --git a/whisper/benchmark.py b/whisper/benchmark.py
index 2b4c237f..2e69ec45 100644
--- a/whisper/benchmark.py
+++ b/whisper/benchmark.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import argparse
+import os
+import subprocess
 import sys
 import time
@@ -12,6 +14,12 @@ audio_file = "whisper/assets/ls_test.flac"
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Benchmark script.")
+    parser.add_argument(
+        "--mlx-dir",
+        type=str,
+        default="mlx_models",
+        help="The folder of MLX models",
+    )
     parser.add_argument(
         "--all",
         action="store_true",
@@ -57,8 +65,8 @@ def decode(model, mels):
     return decoding.decode(model, mels)
 
 
-def everything(model_name):
-    return transcribe(audio_file, model=model_name)
+def everything(model_path):
+    return transcribe(audio_file, model_path=model_path)
 
 
 if __name__ == "__main__":
@@ -76,6 +84,16 @@ if __name__ == "__main__":
     print(f"\nFeature time {feat_time:.3f}")
 
     for model_name in models:
+        model_path = f"{args.mlx_dir}/{model_name}"
+        if not os.path.exists(model_path):
+            print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Launching conversion")
+            # Pass an argument list (no shell) so paths containing spaces or shell
+            # metacharacters reach convert.py intact; check=True aborts on failure.
+            subprocess.run(
+                [sys.executable, "convert.py", "--torch-name-or-path", model_name, "--mlx-path", model_path],
+                check=True,
+            )
+
         print(f"\nModel: {model_name.upper()}")
         tokens = mx.array(
             [
@@ -110,12 +128,12 @@ if __name__ == "__main__":
             ],
             mx.int32,
         )[None]
-        model = load_models.load_model(f"{model_name}", dtype=mx.float16)
+        model = load_models.load_model(model_path, dtype=mx.float16)
         mels = feats(model.dims.n_mels)[None].astype(mx.float16)
         model_forward_time = timer(model_forward, model, mels, tokens)
         print(f"Model forward time {model_forward_time:.3f}")
         decode_time = timer(decode, model, mels)
         print(f"Decode time {decode_time:.3f}")
-        everything_time = timer(everything, model_name)
+        everything_time = timer(everything, model_path)
         print(f"Everything time {everything_time:.3f}")
         print(f"\n{'-----' * 10}\n")