Fix TypeError in whisper benchmark script (#306)

* Add missing keyword to the decoding options

* Reverting last commit

* Fixing transcribe keyword in benckmark.py

* Add argument name to load_model

This is intended to avoid confusion
This commit is contained in:
Alexandre Boucaud 2024-01-12 22:07:15 +01:00 committed by GitHub
parent ef93979973
commit 3ac731dd4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -66,7 +66,7 @@ def decode(model, mels):
def everything(model_path): def everything(model_path):
return transcribe(audio_file, model_path=model_path) return transcribe(audio_file, path_or_hf_repo=model_path)
if __name__ == "__main__": if __name__ == "__main__":
@ -128,7 +128,7 @@ if __name__ == "__main__":
], ],
mx.int32, mx.int32,
)[None] )[None]
model = load_models.load_model(model_path, dtype=mx.float16) model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16)
mels = feats(model.dims.n_mels)[None].astype(mx.float16) mels = feats(model.dims.n_mels)[None].astype(mx.float16)
model_forward_time = timer(model_forward, model, mels, tokens) model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}") print(f"Model forward time {model_forward_time:.3f}")