From 3ac731dd4f5aad8fb106d26cb0bb8cd6698e16ac Mon Sep 17 00:00:00 2001 From: Alexandre Boucaud <3065310+aboucaud@users.noreply.github.com> Date: Fri, 12 Jan 2024 22:07:15 +0100 Subject: [PATCH] Fix TypeError in whisper benchmark script (#306) * Add missing keyword to the decoding options * Reverting last commit * Fixing transcribe keyword in benckmark.py * Add argument name to load_model This is intended to avoid confusion --- whisper/benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisper/benchmark.py b/whisper/benchmark.py index c5ff3e2a..b87a55aa 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -66,7 +66,7 @@ def decode(model, mels): def everything(model_path): - return transcribe(audio_file, model_path=model_path) + return transcribe(audio_file, path_or_hf_repo=model_path) if __name__ == "__main__": @@ -128,7 +128,7 @@ if __name__ == "__main__": ], mx.int32, )[None] - model = load_models.load_model(model_path, dtype=mx.float16) + model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16) mels = feats(model.dims.n_mels)[None].astype(mx.float16) model_forward_time = timer(model_forward, model, mels, tokens) print(f"Model forward time {model_forward_time:.3f}")