extract correct model dimensions and use argparse

This commit is contained in:
dimopep
2023-12-27 23:57:32 +01:00
parent 000b15d563
commit ba01b969ce

View File

@@ -1,5 +1,5 @@
# Copyright © 2023 Apple Inc.
import argparse
import sys
import time
@@ -9,6 +9,11 @@ from whisper import audio, decoding, load_models, transcribe
audio_file = "whisper/assets/ls_test.flac"
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument("-all", action="store_true", help="Use all available models, i.e. tiny,small,medium,large-v1,large-v2,large-v3")
parser.add_argument("-m", "--models", type=str, help="Specify models as a comma-separated list (e.g., tiny,small,medium)")
return parser.parse_args()
def timer(fn, *args):
for _ in range(5):
@@ -46,27 +51,22 @@ def everything():
if __name__ == "__main__":
# get command line arguments without 3rd party libraries
# the command line argument to benchmark all models is "all"
models = ["tiny"]
for i, arg in enumerate(sys.argv):
if arg == "-all":
models = ["tiny", "small", "medium", "large-v1", "large-v3"]
break
elif arg in ("-m", "--models") and i + 1 < len(sys.argv):
models = sys.argv[i + 1].split(",")
args = parse_arguments()
if args.all:
models = ["tiny", "small", "medium", "large-v1", "large-v2", "large-v3"]
elif args.models:
models = args.models.split(",")
else:
models = ["tiny"]
# Rest of your code using the 'models' list
print("Selected models:", models)
feat_time = timer(feats)
print(f"\nFeature time {feat_time:.3f}")
for model_name in models:
# as long large "points" to "large-v3"
if model_name == "large" or model_name == "large-v3":
n_mels = 128
else:
n_mels = 80
mels = feats(n_mels)[None].astype(mx.float16)
print(f"\nModel: {model_name.upper()}")
tokens = mx.array(
[
@@ -102,6 +102,7 @@ if __name__ == "__main__":
mx.int32,
)[None]
model = load_models.load_model(f"{model_name}", dtype=mx.float16)
mels = feats(model.dims.n_mels)[None].astype(mx.float16)
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)