mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-04 21:48:09 +08:00 
			
		
		
		
	
		
			
	
	
		
			237 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			237 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								# Copyright © 2024 Apple Inc.
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								import argparse
							 | 
						|||
| 
								 | 
							
								import os
							 | 
						|||
| 
								 | 
							
								import traceback
							 | 
						|||
| 
								 | 
							
								import warnings
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
							 | 
						|||
| 
								 | 
							
								from .transcribe import transcribe
							 | 
						|||
| 
								 | 
							
								from .writers import get_writer
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								def build_parser():
								    """Create the argument parser for the Whisper transcription CLI.

								    Int/float/str options that use the ``optional_*`` converters below map
								    the literal string ``"None"`` to Python ``None``.

								    Returns:
								        argparse.ArgumentParser: the configured parser.
								    """

								    def optional_int(string):
								        """Parse an int, treating the literal ``"None"`` as ``None``."""
								        return None if string == "None" else int(string)

								    def optional_float(string):
								        """Parse a float, treating the literal ``"None"`` as ``None``."""
								        return None if string == "None" else float(string)

								    def optional_str(string):
								        """Return *string* unchanged, treating the literal ``"None"`` as ``None``."""
								        return None if string == "None" else string

								    def str2bool(string):
								        """Parse the exact literals ``"True"``/``"False"`` into a bool.

								        Raises ``argparse.ArgumentTypeError`` rather than ``ValueError``:
								        argparse replaces a ``ValueError`` raised by a ``type=`` callable with
								        a generic "invalid str2bool value" message, but reports an
								        ``ArgumentTypeError`` message verbatim.
								        """
								        str2val = {"True": True, "False": False}
								        if string in str2val:
								            return str2val[string]
								        raise argparse.ArgumentTypeError(
								            f"Expected one of {set(str2val.keys())}, got {string}"
								        )

								    parser = argparse.ArgumentParser(
								        formatter_class=argparse.ArgumentDefaultsHelpFormatter
								    )
								    parser.add_argument(
								        "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
								    )
								    parser.add_argument(
								        "--model",
								        default="mlx-community/whisper-tiny",
								        type=str,
								        help="The model directory or hugging face repo",
								    )
								    parser.add_argument(
								        "--output-dir",
								        "-o",
								        type=str,
								        default=".",
								        help="Directory to save the outputs",
								    )
								    parser.add_argument(
								        "--output-format",
								        "-f",
								        type=str,
								        default="txt",
								        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
								        help="Format of the output file",
								    )
								    parser.add_argument(
								        "--verbose",
								        type=str2bool,
								        default=True,
								        help="Whether to print out progress and debug messages",
								    )
								    parser.add_argument(
								        "--task",
								        type=str,
								        default="transcribe",
								        choices=["transcribe", "translate"],
								        help="Perform speech recognition ('transcribe') or speech translation ('translate')",
								    )
								    parser.add_argument(
								        "--language",
								        type=optional_str,
								        default=None,
								        # None is listed so that "--language None", which the help text
								        # advertises, passes argparse's choices check (the type converter
								        # runs before the choices check).
								        choices=[None]
								        + sorted(LANGUAGES.keys())
								        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
								        help="Language spoken in the audio, specify None to auto-detect",
								    )
								    parser.add_argument(
								        "--temperature", type=float, default=0, help="Temperature for sampling"
								    )
								    parser.add_argument(
								        "--best-of",
								        type=optional_int,
								        default=5,
								        help="Number of candidates when sampling with non-zero temperature",
								    )
								    parser.add_argument(
								        "--patience",
								        type=float,
								        default=None,
								        help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
								    )
								    parser.add_argument(
								        "--length-penalty",
								        type=float,
								        default=None,
								        help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
								    )
								    parser.add_argument(
								        "--suppress-tokens",
								        type=str,
								        default="-1",
								        help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
								    )
								    parser.add_argument(
								        "--initial-prompt",
								        type=str,
								        default=None,
								        help="Optional text to provide as a prompt for the first window.",
								    )
								    parser.add_argument(
								        "--condition-on-previous-text",
								        type=str2bool,
								        default=True,
								        help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
								    )
								    parser.add_argument(
								        "--fp16",
								        type=str2bool,
								        default=True,
								        help="Whether to perform inference in fp16",
								    )
								    parser.add_argument(
								        "--compression-ratio-threshold",
								        type=optional_float,
								        default=2.4,
								        help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
								    )
								    parser.add_argument(
								        "--logprob-threshold",
								        type=optional_float,
								        default=-1.0,
								        help="If the average log probability is lower than this value, treat the decoding as failed",
								    )
								    parser.add_argument(
								        "--no-speech-threshold",
								        type=optional_float,
								        default=0.6,
								        help="If the probability of the no-speech token is higher than this value and the decoding has failed due to `logprob_threshold`, consider the segment as silence",
								    )
								    parser.add_argument(
								        "--word-timestamps",
								        type=str2bool,
								        default=False,
								        help="Extract word-level timestamps and refine the results based on them",
								    )
								    parser.add_argument(
								        "--prepend-punctuations",
								        type=str,
								        default="\"'“¿([{-",
								        help="If word-timestamps is True, merge these punctuation symbols with the next word",
								    )
								    parser.add_argument(
								        "--append-punctuations",
								        type=str,
								        default="\"'.。,,!!??::”)]}、",
								        help="If word-timestamps is True, merge these punctuation symbols with the previous word",
								    )
								    parser.add_argument(
								        "--highlight-words",
								        type=str2bool,
								        default=False,
								        help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
								    )
								    parser.add_argument(
								        "--max-line-width",
								        type=int,
								        default=None,
								        help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
								    )
								    parser.add_argument(
								        "--max-line-count",
								        type=int,
								        default=None,
								        help="(requires --word-timestamps True) the maximum number of lines in a segment",
								    )
								    parser.add_argument(
								        "--max-words-per-line",
								        type=int,
								        default=None,
								        help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
								    )
								    parser.add_argument(
								        "--hallucination-silence-threshold",
								        type=optional_float,
								        help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
								    )
								    parser.add_argument(
								        "--clip-timestamps",
								        type=str,
								        default="0",
								        help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
								    )
								    return parser
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								def main():
								    """Run the CLI: transcribe every audio path and write the results.

								    The model location and output options are consumed here. The word-level
								    formatting options go to the writer, and every remaining parsed option
								    is forwarded to ``transcribe``. A failure on one file prints its
								    traceback and a "Skipping" line, and processing moves on to the next
								    file.
								    """
								    parser = build_parser()
								    opts = vars(parser.parse_args())
								    if opts["verbose"] is True:
								        print(f"Args: {opts}")

								    # Options handled locally rather than forwarded to transcribe().
								    path_or_hf_repo: str = opts.pop("model")
								    output_dir: str = opts.pop("output_dir")
								    output_format: str = opts.pop("output_format")
								    os.makedirs(output_dir, exist_ok=True)

								    writer = get_writer(output_format, output_dir)

								    # Writer-only formatting options; insertion order decides which one
								    # is reported first when validation fails below.
								    writer_option_names = (
								        "highlight_words",
								        "max_line_count",
								        "max_line_width",
								        "max_words_per_line",
								    )
								    writer_args = {}
								    for option_name in writer_option_names:
								        writer_args[option_name] = opts.pop(option_name)

								    if not opts["word_timestamps"]:
								        # parser.error() exits, so only the first offending option is reported.
								        offending = next(
								            (name for name, value in writer_args.items() if value), None
								        )
								        if offending is not None:
								            flag = offending.replace("_", "-")
								            parser.error(f"--{flag} requires --word-timestamps True")

								    line_count = writer_args["max_line_count"]
								    line_width = writer_args["max_line_width"]
								    words_per_line = writer_args["max_words_per_line"]
								    if line_count and not line_width:
								        warnings.warn("--max-line-count has no effect without --max-line-width")
								    if words_per_line and line_width:
								        warnings.warn("--max-words-per-line has no effect with --max-line-width")

								    audio_paths = opts.pop("audio")
								    for audio_path in audio_paths:
								        try:
								            result = transcribe(
								                audio_path,
								                path_or_hf_repo=path_or_hf_repo,
								                **opts,
								            )
								            writer(result, audio_path, **writer_args)
								        except Exception as exc:
								            # Best-effort batch: report the failure and continue with the rest.
								            traceback.print_exc()
								            print(f"Skipping {audio_path} due to {type(exc).__name__}: {str(exc)}")


								if __name__ == "__main__":
								    main()
							 |