# Copyright © 2024 Apple Inc.

import argparse
import os
import traceback
import warnings

from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer


def build_parser():
    """Build the argument parser for the ``mlx_whisper`` command line tool.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace holds the
        audio paths, the model location, the output options consumed by
        ``main`` and the remaining options forwarded to ``transcribe``.
    """

    def optional_int(string):
        # The literal string "None" lets a user clear a numeric option.
        return None if string == "None" else int(string)

    def optional_float(string):
        return None if string == "None" else float(string)

    def str2bool(string):
        str2val = {"True": True, "False": False}
        if string in str2val:
            return str2val[string]
        # ArgumentTypeError (unlike ValueError) makes argparse show this
        # message instead of a generic "invalid str2bool value".
        raise argparse.ArgumentTypeError(
            f"Expected one of {set(str2val.keys())}, got {string}"
        )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
    )
    parser.add_argument(
        "--model",
        default="mlx-community/whisper-tiny",
        type=str,
        help="The model directory or hugging face repo",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        type=str,
        default=".",
        help="Directory to save the outputs",
    )
    parser.add_argument(
        "--output-format",
        "-f",
        type=str,
        default="txt",
        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
        help="Format of the output file",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=True,
        help="Whether to print out progress and debug messages",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="Perform speech recognition ('transcribe') or speech translation ('translate')",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        choices=sorted(LANGUAGES.keys())
        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
        # "None" is not one of the choices, so auto-detection is selected by
        # leaving the option out rather than by passing the string "None".
        help="Language spoken in the audio, omit to auto-detect",
    )
    parser.add_argument(
        "--temperature", type=float, default=0, help="Temperature for sampling"
    )
    parser.add_argument(
        "--best-of",
        type=optional_int,
        default=5,
        help="Number of candidates when sampling with non-zero temperature",
    )
    parser.add_argument(
        "--patience",
        type=float,
        default=None,
        help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
    )
    parser.add_argument(
        "--length-penalty",
        type=float,
        default=None,
        help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
    )
    parser.add_argument(
        "--suppress-tokens",
        type=str,
        default="-1",
        help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
    )
    parser.add_argument(
        "--initial-prompt",
        type=str,
        default=None,
        help="Optional text to provide as a prompt for the first window.",
    )
    parser.add_argument(
        "--condition-on-previous-text",
        type=str2bool,
        default=True,
        help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
    )
    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=True,
        help="Whether to perform inference in fp16",
    )
    parser.add_argument(
        "--compression-ratio-threshold",
        type=optional_float,
        default=2.4,
        help="If the gzip compression ratio is higher than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--logprob-threshold",
        type=optional_float,
        default=-1.0,
        help="If the average log probability is lower than this value, treat the decoding as failed",
    )
    parser.add_argument(
        "--no-speech-threshold",
        type=optional_float,
        default=0.6,
        help="If the probability of the no-speech token is higher than this value AND the decoding has failed due to --logprob-threshold, consider the segment as silence",
    )
    parser.add_argument(
        "--word-timestamps",
        type=str2bool,
        default=False,
        help="Extract word-level timestamps and refine the results based on them",
    )
    parser.add_argument(
        "--prepend-punctuations",
        type=str,
        default="\"'“¿([{-",
        help="If --word-timestamps is True, merge these punctuation symbols with the next word",
    )
    parser.add_argument(
        "--append-punctuations",
        type=str,
        default="\"'.。,,!!??::”)]}、",
        help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
    )
    parser.add_argument(
        "--highlight-words",
        type=str2bool,
        default=False,
        help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
    )
    parser.add_argument(
        "--max-line-width",
        type=int,
        default=None,
        help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
    )
    parser.add_argument(
        "--max-line-count",
        type=int,
        default=None,
        help="(requires --word-timestamps True) the maximum number of lines in a segment",
    )
    parser.add_argument(
        "--max-words-per-line",
        type=int,
        default=None,
        help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
    )
    parser.add_argument(
        "--hallucination-silence-threshold",
        type=optional_float,
        default=None,
        help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
    )
    parser.add_argument(
        "--clip-timestamps",
        type=str,
        default="0",
        help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
    )
    return parser


def main():
    """Entry point of the ``mlx_whisper`` console script.

    Parses the command line, creates the output directory and the writer for
    the requested format, then transcribes every audio path in turn. A file
    whose transcription or writing raises is reported and skipped so that
    the remaining files are still processed.
    """
    parser = build_parser()
    args = vars(parser.parse_args())
    if args["verbose"] is True:
        print(f"Args: {args}")

    # Options consumed here are popped so that what is left in ``args`` can
    # be forwarded verbatim to ``transcribe``.
    path_or_hf_repo: str = args.pop("model")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    os.makedirs(output_dir, exist_ok=True)

    writer = get_writer(output_format, output_dir)
    word_options = [
        "highlight_words",
        "max_line_count",
        "max_line_width",
        "max_words_per_line",
    ]
    writer_args = {arg: args.pop(arg) for arg in word_options}
    if not args["word_timestamps"]:
        # Every writer option depends on word-level timings.
        for k, v in writer_args.items():
            if v:
                argop = k.replace("_", "-")
                parser.error(f"--{argop} requires --word-timestamps True")
    if writer_args["max_line_count"] and not writer_args["max_line_width"]:
        warnings.warn("--max-line-count has no effect without --max-line-width")
    if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
        warnings.warn("--max-words-per-line has no effect with --max-line-width")
    for audio_path in args.pop("audio"):
        try:
            result = transcribe(
                audio_path,
                path_or_hf_repo=path_or_hf_repo,
                **args,
            )
            writer(result, audio_path, **writer_args)
        except Exception as e:
            # Deliberate best effort: report the failure and move on to the
            # next input instead of aborting the whole batch.
            traceback.print_exc()
            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")


if __name__ == "__main__":
    main()
# Copyright © 2024 Apple Inc.

import json
import os
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO


def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: non-negative time offset in seconds.
        always_include_hours: emit the ``HH:`` field even when it is zero.
        decimal_marker: separator between seconds and milliseconds.

    Returns:
        The formatted timestamp string, rounded to the nearest millisecond.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    secs = milliseconds // 1_000
    milliseconds -= secs * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{secs:02d}{decimal_marker}{milliseconds:03d}"


def get_start(segments: List[dict]) -> Optional[float]:
    """Return the start time of the first word, falling back to the first
    segment's start, or ``None`` when there are no segments."""
    return next(
        (w["start"] for s in segments for w in s["words"]),
        segments[0]["start"] if segments else None,
    )


class ResultWriter:
    """Base class: writes a transcription result next to ``output_dir``.

    Subclasses set ``extension`` and implement ``write_result``.
    """

    # File extension (without the dot) used for the output file.
    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(
        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
    ):
        """Write ``result`` to ``<output_dir>/<audio basename>.<extension>``."""
        audio_basename = os.path.basename(audio_path)
        audio_basename = os.path.splitext(audio_basename)[0]
        output_path = os.path.join(
            self.output_dir, audio_basename + "." + self.extension
        )

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Serialize ``result`` to the open text stream ``file``."""
        raise NotImplementedError


class WriteTXT(ResultWriter):
    """Plain text: one stripped segment text per line."""

    extension: str = "txt"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)


class SubtitlesWriter(ResultWriter):
    """Shared cue generation for the subtitle formats (VTT and SRT)."""

    # Both are supplied by the concrete subtitle format.
    always_include_hours: bool
    decimal_marker: str

    def iterate_result(
        self,
        result: dict,
        options: Optional[dict] = None,
        *,
        max_line_width: Optional[int] = None,
        max_line_count: Optional[int] = None,
        highlight_words: bool = False,
        max_words_per_line: Optional[int] = None,
    ):
        """Yield ``(start, end, text)`` cues for ``result``.

        Keyword arguments take precedence over the same keys in ``options``.
        When the first segment carries a ``"words"`` list, cues are built
        from word timings (honouring the line/word limits and, with
        ``highlight_words``, emitting one cue per word with that word
        wrapped in ``<u>...</u>``). Otherwise one cue is produced per
        segment.
        """
        options = options or {}
        max_line_width = max_line_width or options.get("max_line_width")
        max_line_count = max_line_count or options.get("max_line_count")
        highlight_words = highlight_words or options.get("highlight_words", False)
        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
        # Without both limits, subtitles follow the original segmentation.
        preserve_segments = max_line_count is None or max_line_width is None
        max_line_width = max_line_width or 1000
        max_words_per_line = max_words_per_line or 1000

        def iterate_subtitles():
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: List[dict] = []
            last: float = get_start(result["segments"]) or 0.0
            for segment in result["segments"]:
                chunk_index = 0
                words_count = max_words_per_line
                while chunk_index < len(segment["words"]):
                    remaining_words = len(segment["words"]) - chunk_index
                    if max_words_per_line > len(segment["words"]) - chunk_index:
                        words_count = remaining_words
                    for i, original_timing in enumerate(
                        segment["words"][chunk_index : chunk_index + words_count]
                    ):
                        # Copy so the caller's result is never mutated.
                        timing = original_timing.copy()
                        long_pause = (
                            not preserve_segments and timing["start"] - last > 3.0
                        )
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if (
                            line_len > 0
                            and has_room
                            and not long_pause
                            and not seg_break
                        ):
                            # line continuation
                            line_len += len(timing["word"])
                        else:
                            # new line
                            timing["word"] = timing["word"].strip()
                            if (
                                len(subtitle) > 0
                                and max_line_count is not None
                                and (long_pause or line_count >= max_line_count)
                                or seg_break
                            ):
                                # subtitle break
                                yield subtitle
                                subtitle = []
                                line_count = 1
                            elif line_len > 0:
                                # line break
                                line_count += 1
                                timing["word"] = "\n" + timing["word"]
                            line_len = len(timing["word"].strip())
                        subtitle.append(timing)
                        last = timing["start"]
                    chunk_index += max_words_per_line
            if len(subtitle) > 0:
                yield subtitle

        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
            for subtitle in iterate_subtitles():
                subtitle_start = self.format_timestamp(subtitle[0]["start"])
                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
                subtitle_text = "".join([word["word"] for word in subtitle])
                if highlight_words:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        start = self.format_timestamp(this_word["start"])
                        end = self.format_timestamp(this_word["end"])
                        if last != start:
                            # Gap before this word: show the plain text.
                            yield last, start, subtitle_text

                        # Underline only the word being spoken; leading
                        # whitespace/newline stays outside the tag.
                        yield start, end, "".join(
                            [
                                (
                                    re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                    if j == i
                                    else word
                                )
                                for j, word in enumerate(all_words)
                            ]
                        )
                        last = end
                else:
                    yield subtitle_start, subtitle_end, subtitle_text
        else:
            for segment in result["segments"]:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                # "-->" is the cue timing separator, so it must not appear
                # inside the cue text.
                segment_text = segment["text"].strip().replace("-->", "->")
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format ``seconds`` using this format's hour/decimal conventions."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )


class WriteVTT(SubtitlesWriter):
    """WebVTT output: ``WEBVTT`` header followed by one block per cue."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("WEBVTT\n", file=file)
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteSRT(SubtitlesWriter):
    """SubRip output: cues numbered from 1, comma as decimal marker."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for i, (start, end, text) in enumerate(
            self.iterate_result(result, options, **kwargs), start=1
        ):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteTSV(ResultWriter):
    """
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\\t<end time in integer milliseconds>\\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            print(round(1000 * segment["start"]), file=file, end="\t")
            print(round(1000 * segment["end"]), file=file, end="\t")
            # Tabs inside the text would break the column layout.
            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)


class WriteJSON(ResultWriter):
    """JSON output: the result dictionary dumped as-is (non-ASCII kept)."""

    extension: str = "json"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        json.dump(result, file, ensure_ascii=False)


def get_writer(output_format: str, output_dir: str) -> Callable[..., None]:
    """Return a callable ``writer(result, audio_path, options=None, **kwargs)``.

    Args:
        output_format: one of ``txt``, ``vtt``, ``srt``, ``tsv``, ``json``,
            or ``all`` to write every format.
        output_dir: directory in which the output file(s) are created.

    Raises:
        KeyError: if ``output_format`` is not a known format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        # NOTE: ``file`` is the audio path string handed on to each writer
        # (not an open stream); the name is kept for existing callers.
        def write_all(
            result: dict, file: str, options: Optional[dict] = None, **kwargs
        ):
            for writer in all_writers:
                writer(result, file, options, **kwargs)

        return write_all

    return writers[output_format](output_dir)