Merge pull request #23 from adhishthite/main

feat: benchmark all models (with `--all` flag)
This commit is contained in:
Awni Hannun 2023-12-06 21:14:41 -08:00 committed by GitHub
commit 76c78aa486
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 43 deletions

2
.gitignore vendored
View File

@ -127,3 +127,5 @@ dmypy.json
# Pyre type checker
.pyre/
.idea/
.vscode/

View File

@ -1,5 +1,6 @@
# Copyright © 2023 Apple Inc.
import sys
import time
import mlx.core as mx
@ -48,8 +49,19 @@ def everything():
if __name__ == "__main__":
# get command line arguments without 3rd party libraries
# the command line argument to benchmark all models is "all"
models = ["tiny"]
if len(sys.argv) > 1:
if sys.argv[1] == "--all":
models = ["tiny", "small", "medium", "large"]
for model_name in models:
feat_time = timer(feats)
print(f"Feature time {feat_time:.3f}")
print(f"\nModel: {model_name.upper()}")
print(f"\nFeature time {feat_time:.3f}")
mels = feats()[None]
tokens = mx.array(
[
@ -84,10 +96,11 @@ if __name__ == "__main__":
],
mx.int32,
)[None]
model = load_models.load_model("tiny")
model = load_models.load_model(f"{model_name}")
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)
print(f"Decode time {decode_time:.3f}")
everything_time = timer(everything)
print(f"Everything time {everything_time:.3f}")
print(f"\n{'-----' * 10}\n")