diff --git a/whisper/benchmark.py b/whisper/benchmark.py index c5ff3e2a..b87a55aa 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -66,7 +66,7 @@ def decode(model, mels): def everything(model_path): - return transcribe(audio_file, model_path=model_path) + return transcribe(audio_file, path_or_hf_repo=model_path) if __name__ == "__main__": @@ -128,7 +128,7 @@ if __name__ == "__main__": ], mx.int32, )[None] - model = load_models.load_model(model_path, dtype=mx.float16) + model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16) mels = feats(model.dims.n_mels)[None].astype(mx.float16) model_forward_time = timer(model_forward, model, mels, tokens) print(f"Model forward time {model_forward_time:.3f}")