From 9cf82a0d43858a956d99a5e3f7edf32ccfdbaec9 Mon Sep 17 00:00:00 2001 From: adhishthite Date: Thu, 7 Dec 2023 00:07:42 +0530 Subject: [PATCH] Benchmark all models if user allows. --- .gitignore | 2 + whisper/benchmark.py | 99 +++++++++++++++++++++++++------------------- 2 files changed, 58 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index b6e47617..51288c78 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,5 @@ dmypy.json # Pyre type checker .pyre/ +.idea/ +.vscode/ diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 880453a0..9df6b500 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -1,5 +1,6 @@ # Copyright © 2023 Apple Inc. +import sys import time import mlx.core as mx @@ -48,46 +49,58 @@ def everything(): if __name__ == "__main__": - feat_time = timer(feats) - print(f"Feature time {feat_time:.3f}") - mels = feats()[None] - tokens = mx.array( - [ - 50364, - 1396, - 264, - 665, - 5133, - 23109, - 25462, - 264, - 6582, - 293, - 750, - 632, - 42841, - 292, - 370, - 938, - 294, - 4054, - 293, - 12653, - 356, - 50620, - 50620, - 23563, - 322, - 3312, - 13, - 50680, - ], - mx.int32, - )[None] - model = load_models.load_model("tiny") - model_forward_time = timer(model_forward, model, mels, tokens) - print(f"Model forward time {model_forward_time:.3f}") - decode_time = timer(decode, model, mels) - print(f"Decode time {decode_time:.3f}") - everything_time = timer(everything) - print(f"Everything time {everything_time:.3f}") + + # get command line arguments without 3rd party libraries + # the command line argument to benchmark all models is "all" + models = ["tiny"] + if len(sys.argv) > 1: + if sys.argv[1] == "--all": + models = ["tiny", "small", "medium", "large"] + + for model_name in models: + feat_time = timer(feats) + + print(f"\nModel: {model_name.upper()}") + print(f"\nFeature time {feat_time:.3f}") + mels = feats()[None] + tokens = mx.array( + [ + 50364, + 1396, + 264, + 665, + 5133, + 23109, + 25462, + 264, + 6582, + 293, + 750, + 632, + 42841, + 292, + 370, + 938, + 294, + 4054, + 293, + 12653, + 356, + 50620, + 50620, + 23563, + 322, + 3312, + 13, + 50680, + ], + mx.int32, + )[None] + model = load_models.load_model(f"{model_name}") + model_forward_time = timer(model_forward, model, mels, tokens) + print(f"Model forward time {model_forward_time:.3f}") + decode_time = timer(decode, model, mels) + print(f"Decode time {decode_time:.3f}") + everything_time = timer(everything) + print(f"Everything time {everything_time:.3f}") + print(f"\n{'-----' * 10}\n")