This commit is contained in:
Awni Hannun
2023-12-28 13:49:12 -08:00
parent ba01b969ce
commit d1ca6919af

View File

@@ -9,12 +9,23 @@ from whisper import audio, decoding, load_models, transcribe
audio_file = "whisper/assets/ls_test.flac"
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument("-all", action="store_true", help="Use all available models, i.e. tiny,small,medium,large-v1,large-v2,large-v3")
parser.add_argument("-m", "--models", type=str, help="Specify models as a comma-separated list (e.g., tiny,small,medium)")
parser.add_argument(
"--all",
action="store_true",
help="Use all available models, i.e. tiny,small,medium,large-v3",
)
parser.add_argument(
"-m",
"--models",
type=str,
help="Specify models as a comma-separated list (e.g., tiny,small,medium)",
)
return parser.parse_args()
def timer(fn, *args):
for _ in range(5):
fn(*args)
@@ -51,10 +62,9 @@ def everything():
if __name__ == "__main__":
# the command line argument to benchmark all models is "all"
args = parse_arguments()
if args.all:
models = ["tiny", "small", "medium", "large-v1", "large-v2", "large-v3"]
models = ["tiny", "small", "medium", "large-v3"]
elif args.models:
models = args.models.split(",")
else: