mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Fix TypeError in whisper benchmark script (#306)
* Add missing keyword to the decoding options * Reverting last commit * Fixing transcribe keyword in benckmark.py * Add argument name to load_model This is intended to avoid confusion
This commit is contained in:
parent
ef93979973
commit
3ac731dd4f
@ -66,7 +66,7 @@ def decode(model, mels):
|
|||||||
|
|
||||||
|
|
||||||
def everything(model_path):
|
def everything(model_path):
|
||||||
return transcribe(audio_file, model_path=model_path)
|
return transcribe(audio_file, path_or_hf_repo=model_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -128,7 +128,7 @@ if __name__ == "__main__":
|
|||||||
],
|
],
|
||||||
mx.int32,
|
mx.int32,
|
||||||
)[None]
|
)[None]
|
||||||
model = load_models.load_model(model_path, dtype=mx.float16)
|
model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16)
|
||||||
mels = feats(model.dims.n_mels)[None].astype(mx.float16)
|
mels = feats(model.dims.n_mels)[None].astype(mx.float16)
|
||||||
model_forward_time = timer(model_forward, model, mels, tokens)
|
model_forward_time = timer(model_forward, model, mels, tokens)
|
||||||
print(f"Model forward time {model_forward_time:.3f}")
|
print(f"Model forward time {model_forward_time:.3f}")
|
||||||
|
Loading…
Reference in New Issue
Block a user