Benchmark all models if user allows.

This commit is contained in:
adhishthite 2023-12-07 00:07:42 +05:30
parent 0bf5d0e3bc
commit 9cf82a0d43
2 changed files with 58 additions and 43 deletions

2
.gitignore vendored
View File

@ -127,3 +127,5 @@ dmypy.json
# Pyre type checker # Pyre type checker
.pyre/ .pyre/
.idea/
.vscode/

View File

@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import sys
import time import time
import mlx.core as mx import mlx.core as mx
@ -48,46 +49,58 @@ def everything():
if __name__ == "__main__": if __name__ == "__main__":
feat_time = timer(feats)
print(f"Feature time {feat_time:.3f}") # get command line arguments without 3rd party libraries
mels = feats()[None] # the command line argument to benchmark all models is "all"
tokens = mx.array( models = ["tiny"]
[ if len(sys.argv) > 1:
50364, if sys.argv[1] == "--all":
1396, models = ["tiny", "small", "medium", "large"]
264,
665, for model_name in models:
5133, feat_time = timer(feats)
23109,
25462, print(f"\nModel: {model_name.upper()}")
264, print(f"\nFeature time {feat_time:.3f}")
6582, mels = feats()[None]
293, tokens = mx.array(
750, [
632, 50364,
42841, 1396,
292, 264,
370, 665,
938, 5133,
294, 23109,
4054, 25462,
293, 264,
12653, 6582,
356, 293,
50620, 750,
50620, 632,
23563, 42841,
322, 292,
3312, 370,
13, 938,
50680, 294,
], 4054,
mx.int32, 293,
)[None] 12653,
model = load_models.load_model("tiny") 356,
model_forward_time = timer(model_forward, model, mels, tokens) 50620,
print(f"Model forward time {model_forward_time:.3f}") 50620,
decode_time = timer(decode, model, mels) 23563,
print(f"Decode time {decode_time:.3f}") 322,
everything_time = timer(everything) 3312,
print(f"Everything time {everything_time:.3f}") 13,
50680,
],
mx.int32,
)[None]
model = load_models.load_model(f"{model_name}")
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)
print(f"Decode time {decode_time:.3f}")
everything_time = timer(everything)
print(f"Everything time {everything_time:.3f}")
print(f"\n{'-----' * 10}\n")