Fix TypeError in whisper benchmark script (#306)

* Add missing keyword to the decoding options

* Reverting last commit

* Fixing transcribe keyword in benckmark.py

* Add argument name to load_model

This is intended to avoid confusion
This commit is contained in:
Alexandre Boucaud 2024-01-12 22:07:15 +01:00 committed by GitHub
parent ef93979973
commit 3ac731dd4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -66,7 +66,7 @@ def decode(model, mels):
def everything(model_path):
return transcribe(audio_file, model_path=model_path)
return transcribe(audio_file, path_or_hf_repo=model_path)
if __name__ == "__main__":
@ -128,7 +128,7 @@ if __name__ == "__main__":
],
mx.int32,
)[None]
model = load_models.load_model(model_path, dtype=mx.float16)
model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16)
mels = feats(model.dims.n_mels)[None].astype(mx.float16)
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")