mlx-examples/whisper/mlx_whisper/cli.py
madroid e196fa3208
Whisper: Support command line (#746)
* Whisper: Add CLI command

* Whisper: Prevent precision loss when converting to words dictionary

* Whisper: disable json ensure_ascii

* Whisper: add cli setup config

* Whisper: pre-commit

* Whisper: Adjust the _ in the command line arguments to -

* nits

* version + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-16 10:35:44 -07:00

237 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright © 2024 Apple Inc.
import argparse
import os
import traceback
import warnings
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
def build_parser():
    """Build the argument parser for the Whisper command-line interface.

    Every option is registered with a dash-separated flag (for example
    ``--word-timestamps``); argparse exposes it on the parsed namespace under
    the underscore-separated name (``word_timestamps``).

    Boolean options take the literal strings ``True`` or ``False``. Options
    declared with an "optional" numeric type additionally accept the literal
    string ``None``.

    Returns:
        argparse.ArgumentParser: The configured parser. Help output includes
        each option's default (``ArgumentDefaultsHelpFormatter``).
    """

    def optional_int(string):
        """Convert ``string`` to int, mapping the literal "None" to None."""
        return None if string == "None" else int(string)

    def optional_float(string):
        """Convert ``string`` to float, mapping the literal "None" to None."""
        return None if string == "None" else float(string)

    def str2bool(string):
        """Convert the literal "True"/"False" to the matching bool.

        Raises ``argparse.ArgumentTypeError`` rather than ``ValueError``:
        argparse replaces the message of a ``ValueError`` raised by a type
        callable with a generic "invalid ... value" text, so the explanation
        below would otherwise never reach the user.
        """
        str2val = {"True": True, "False": False}
        try:
            return str2val[string]
        except KeyError:
            # sorted() keeps the message deterministic (set order is not).
            raise argparse.ArgumentTypeError(
                f"Expected one of {sorted(str2val)}, got {string}"
            ) from None

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- Input, model and output location ---------------------------------
    parser.add_argument(
        "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
    )
    parser.add_argument(
        "--model",
        default="mlx-community/whisper-tiny",
        type=str,
        help="The model directory or hugging face repo",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default=".",
        help="Directory to save the outputs",
    )
    parser.add_argument(
        "--output-format",
        "-f",
        type=str,
        default="txt",
        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
        help="Format of the output file",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="Whether to print out progress and debug messages",
    )

    # --- Task and decoding options ----------------------------------------
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="Perform speech recognition ('transcribe') or speech translation ('translate')",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        # Accept both language codes and title-cased language names.
        choices=sorted(LANGUAGES.keys())
        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
        # "None" is not one of the choices, so auto-detection is requested by
        # leaving the option out (the default is None).
        help="Language spoken in the audio; omit this option to auto-detect",
    )
    parser.add_argument(
        "--temperature", type=float, default=0, help="Temperature for sampling"
    )
    parser.add_argument(
        "--best-of",
        type=optional_int,
        default=5,
        help="Number of candidates when sampling with non-zero temperature",
    )
    parser.add_argument(
        "--patience",
        type=float,
        default=None,
        help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
    )
    parser.add_argument(
        "--length-penalty",
        type=float,
        default=None,
        help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
    )
    parser.add_argument(
        "--suppress-tokens",
        type=str,
        default="-1",
        help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
    )
    parser.add_argument(
        "--initial-prompt",
        type=str,
        default=None,
        help="Optional text to provide as a prompt for the first window.",
    )
    parser.add_argument(
        "--condition-on-previous-text",
        type=str2bool,
        default=True,
        help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
    )
    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=True,
        help="Whether to perform inference in fp16",
    )

    # --- Decoding failure / silence thresholds ----------------------------
    parser.add_argument(
        "--compression-ratio-threshold",
        type=optional_float,
        default=2.4,
        help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--logprob-threshold",
        type=optional_float,
        default=-1.0,
        help="If the average log probability is lower than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--no-speech-threshold",
        type=optional_float,
        default=0.6,
        help="If the probability of the no-speech token is higher than this value AND the decoding has failed due to --logprob-threshold, consider the segment as silence",
    )

    # --- Word-level timestamps and subtitle layout ------------------------
    parser.add_argument(
        "--word-timestamps",
        type=str2bool,
        default=False,
        help="Extract word-level timestamps and refine the results based on them",
    )
    parser.add_argument(
        "--prepend-punctuations",
        type=str,
        default="\"'“¿([{-",
        help="If --word-timestamps is True, merge these punctuation symbols with the next word",
    )
    parser.add_argument(
        "--append-punctuations",
        type=str,
        default="\"'.。,!?::”)]}、",
        help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
    )
    parser.add_argument(
        "--highlight-words",
        type=str2bool,
        default=False,
        help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
    )
    parser.add_argument(
        "--max-line-width",
        type=int,
        default=None,
        help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
    )
    parser.add_argument(
        "--max-line-count",
        type=int,
        default=None,
        help="(requires --word-timestamps True) the maximum number of lines in a segment",
    )
    parser.add_argument(
        "--max-words-per-line",
        type=int,
        default=None,
        help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
    )
    parser.add_argument(
        "--hallucination-silence-threshold",
        type=optional_float,
        help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
    )
    parser.add_argument(
        "--clip-timestamps",
        type=str,
        default="0",
        help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
    )
    return parser
def main():
    """Run the command-line interface.

    Parses the command line, creates the output directory and the writer for
    the requested format, validates the subtitle-layout options, then
    transcribes each audio file in turn and hands the result to the writer.

    A failure on one file is reported (traceback plus a one-line summary) and
    the remaining files are still processed.
    """
    parser = build_parser()
    options = vars(parser.parse_args())
    if options["verbose"] is True:
        print(f"Args: {options}")

    # Remove everything that is not forwarded to transcribe(); what is left
    # in `options` afterwards is passed through as keyword arguments.
    model_location: str = options.pop("model")
    destination: str = options.pop("output_dir")
    fmt: str = options.pop("output_format")

    os.makedirs(destination, exist_ok=True)
    writer = get_writer(fmt, destination)

    # Layout options are consumed by the writer, not by transcribe().
    # The key order also decides which option is reported first below.
    writer_args = {
        "highlight_words": options.pop("highlight_words"),
        "max_line_count": options.pop("max_line_count"),
        "max_line_width": options.pop("max_line_width"),
        "max_words_per_line": options.pop("max_words_per_line"),
    }

    if not options["word_timestamps"]:
        # parser.error() terminates the program, so only the first layout
        # option that was set without word timestamps is ever reported.
        offending = next(
            (name for name, value in writer_args.items() if value), None
        )
        if offending is not None:
            flag = offending.replace("_", "-")
            parser.error(f"--{flag} requires --word-timestamps True")

    line_width = writer_args["max_line_width"]
    if writer_args["max_line_count"] and not line_width:
        warnings.warn("--max-line-count has no effect without --max-line-width")
    if writer_args["max_words_per_line"] and line_width:
        warnings.warn("--max-words-per-line has no effect with --max-line-width")

    audio_files = options.pop("audio")
    for audio_path in audio_files:
        try:
            result = transcribe(
                audio_path, path_or_hf_repo=model_location, **options
            )
            writer(result, audio_path, **writer_args)
        except Exception as err:
            # Best effort: report the failure and carry on with the next file.
            traceback.print_exc()
            print(f"Skipping {audio_path} due to {type(err).__name__}: {str(err)}")
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()