From 07c163d9d9d4d283c157e7bec7a87b576322d238 Mon Sep 17 00:00:00 2001 From: Dimo Date: Thu, 28 Dec 2023 22:50:35 +0100 Subject: [PATCH] [Whisper] Large-v3 requires 128 Mel frequency bins (#193) * Large-v3 requires 128 Mel frequency bins * extract correct model dimensions and use argparse * format * format --------- Co-authored-by: Awni Hannun --- whisper/benchmark.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 877bb4f0..2b4c237f 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -1,5 +1,5 @@ # Copyright © 2023 Apple Inc. - +import argparse import sys import time @@ -10,6 +10,22 @@ from whisper import audio, decoding, load_models, transcribe audio_file = "whisper/assets/ls_test.flac" +def parse_arguments(): + parser = argparse.ArgumentParser(description="Benchmark script.") + parser.add_argument( + "--all", + action="store_true", + help="Use all available models, i.e. tiny,small,medium,large-v3", + ) + parser.add_argument( + "-m", + "--models", + type=str, + help="Specify models as a comma-separated list (e.g., tiny,small,medium)", + ) + return parser.parse_args() + + def timer(fn, *args): for _ in range(5): fn(*args) @@ -23,10 +39,10 @@ def timer(fn, *args): return (toc - tic) / num_its -def feats(): +def feats(n_mels: int = 80): data = audio.load_audio(audio_file) data = audio.pad_or_trim(data) - mels = audio.log_mel_spectrogram(data) + mels = audio.log_mel_spectrogram(data, n_mels) mx.eval(mels) return mels @@ -46,20 +62,20 @@ def everything(model_name): if __name__ == "__main__": + args = parse_arguments() + if args.all: + models = ["tiny", "small", "medium", "large-v3"] + elif args.models: + models = args.models.split(",") + else: + models = ["tiny"] - # get command line arguments without 3rd party libraries - # the command line argument to benchmark all models is "all" - models = ["tiny"] - if len(sys.argv) > 1: - if sys.argv[1] == "--all": - models = ["tiny", "small", "medium", "large"] + print("Selected models:", models) feat_time = timer(feats) print(f"\nFeature time {feat_time:.3f}") - mels = feats()[None].astype(mx.float16) for model_name in models: - print(f"\nModel: {model_name.upper()}") tokens = mx.array( [ @@ -95,6 +111,7 @@ if __name__ == "__main__": mx.int32, )[None] model = load_models.load_model(f"{model_name}", dtype=mx.float16) + mels = feats(model.dims.n_mels)[None].astype(mx.float16) model_forward_time = timer(model_forward, model, mels, tokens) print(f"Model forward time {model_forward_time:.3f}") decode_time = timer(decode, model, mels)