diff --git a/whisper/README.md b/whisper/README.md
index 2805e899..ac6e95f6 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -21,6 +21,22 @@ pip install mlx-whisper
### Run
+#### CLI
+
+At its simplest:
+
+```
+mlx_whisper audio_file.mp3
+```
+
+This will make a text file `audio_file.txt` with the results.
+
+Use `-f` to specify the output format and `--model` to specify the model. There
+are many other supported command line options. To see them all, run
+`mlx_whisper -h`.
+
+#### API
+
Transcribe audio with:
```python
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
new file mode 100644
index 00000000..c2813338
--- /dev/null
+++ b/whisper/mlx_whisper/cli.py
@@ -0,0 +1,236 @@
+# Copyright © 2024 Apple Inc.
+
+import argparse
+import os
+import traceback
+import warnings
+
+from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
+from .transcribe import transcribe
+from .writers import get_writer
+
+
+def build_parser():
+ def optional_int(string):
+ return None if string == "None" else int(string)
+
+ def optional_float(string):
+ return None if string == "None" else float(string)
+
+ def str2bool(string):
+ str2val = {"True": True, "False": False}
+ if string in str2val:
+ return str2val[string]
+ else:
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
+ )
+ parser.add_argument(
+ "--model",
+ default="mlx-community/whisper-tiny",
+ type=str,
+ help="The model directory or hugging face repo",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ type=str,
+ default=".",
+ help="Directory to save the outputs",
+ )
+ parser.add_argument(
+ "--output-format",
+ "-f",
+ type=str,
+ default="txt",
+ choices=["txt", "vtt", "srt", "tsv", "json", "all"],
+ help="Format of the output file",
+ )
+ parser.add_argument(
+ "--verbose",
+ type=str2bool,
+ default=True,
+ help="Whether to print out progress and debug messages",
+ )
+ parser.add_argument(
+ "--task",
+ type=str,
+ default="transcribe",
+ choices=["transcribe", "translate"],
+ help="Perform speech recognition ('transcribe') or speech translation ('translate')",
+ )
+ parser.add_argument(
+ "--language",
+ type=str,
+ default=None,
+ choices=sorted(LANGUAGES.keys())
+ + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
+ help="Language spoken in the audio, specify None to auto-detect",
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0, help="Temperature for sampling"
+ )
+ parser.add_argument(
+ "--best-of",
+ type=optional_int,
+ default=5,
+ help="Number of candidates when sampling with non-zero temperature",
+ )
+ parser.add_argument(
+ "--patience",
+ type=float,
+ default=None,
+ help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
+ )
+ parser.add_argument(
+ "--length-penalty",
+ type=float,
+ default=None,
+ help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
+ )
+ parser.add_argument(
+ "--suppress-tokens",
+ type=str,
+ default="-1",
+ help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
+ )
+ parser.add_argument(
+ "--initial-prompt",
+ type=str,
+ default=None,
+ help="Optional text to provide as a prompt for the first window.",
+ )
+ parser.add_argument(
+ "--condition-on-previous-text",
+ type=str2bool,
+ default=True,
+ help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
+ )
+ parser.add_argument(
+ "--fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to perform inference in fp16",
+ )
+ parser.add_argument(
+ "--compression-ratio-threshold",
+ type=optional_float,
+ default=2.4,
+ help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
+ )
+ parser.add_argument(
+ "--logprob-threshold",
+ type=optional_float,
+ default=-1.0,
+ help="If the average log probability is lower than this value, treat the decoding as failed",
+ )
+ parser.add_argument(
+ "--no-speech-threshold",
+ type=optional_float,
+ default=0.6,
+ help="If the probability of the token is higher than this value the decoding has failed due to `logprob_threshold`, consider the segment as silence",
+ )
+ parser.add_argument(
+ "--word-timestamps",
+ type=str2bool,
+ default=False,
+ help="Extract word-level timestamps and refine the results based on them",
+ )
+ parser.add_argument(
+ "--prepend-punctuations",
+ type=str,
+ default="\"'“¿([{-",
+ help="If word-timestamps is True, merge these punctuation symbols with the next word",
+ )
+ parser.add_argument(
+ "--append-punctuations",
+ type=str,
+ default="\"'.。,,!!??::”)]}、",
+ help="If word_timestamps is True, merge these punctuation symbols with the previous word",
+ )
+ parser.add_argument(
+ "--highlight-words",
+ type=str2bool,
+ default=False,
+ help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
+ )
+ parser.add_argument(
+ "--max-line-width",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
+ )
+ parser.add_argument(
+ "--max-line-count",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True) the maximum number of lines in a segment",
+ )
+ parser.add_argument(
+ "--max-words-per-line",
+ type=int,
+ default=None,
+ help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
+ )
+ parser.add_argument(
+ "--hallucination-silence-threshold",
+ type=optional_float,
+ help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
+ )
+ parser.add_argument(
+ "--clip-timestamps",
+ type=str,
+ default="0",
+ help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
+ )
+ return parser
+
+
+def main():
+ parser = build_parser()
+ args = vars(parser.parse_args())
+ if args["verbose"] is True:
+ print(f"Args: {args}")
+
+ path_or_hf_repo: str = args.pop("model")
+ output_dir: str = args.pop("output_dir")
+ output_format: str = args.pop("output_format")
+ os.makedirs(output_dir, exist_ok=True)
+
+ writer = get_writer(output_format, output_dir)
+ word_options = [
+ "highlight_words",
+ "max_line_count",
+ "max_line_width",
+ "max_words_per_line",
+ ]
+ writer_args = {arg: args.pop(arg) for arg in word_options}
+ if not args["word_timestamps"]:
+ for k, v in writer_args.items():
+ if v:
+ argop = k.replace("_", "-")
+ parser.error(f"--{argop} requires --word-timestamps True")
+ if writer_args["max_line_count"] and not writer_args["max_line_width"]:
+ warnings.warn("--max-line-count has no effect without --max-line-width")
+ if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
+ warnings.warn("--max-words-per-line has no effect with --max-line-width")
+ for audio_path in args.pop("audio"):
+ try:
+ result = transcribe(
+ audio_path,
+ path_or_hf_repo=path_or_hf_repo,
+ **args,
+ )
+ writer(result, audio_path, **writer_args)
+ except Exception as e:
+ traceback.print_exc()
+ print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py
index 13c36315..04915deb 100644
--- a/whisper/mlx_whisper/timing.py
+++ b/whisper/mlx_whisper/timing.py
@@ -276,7 +276,7 @@ def add_word_timestamps(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
- probability=timing.probability,
+ probability=float(timing.probability),
)
)
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py
index ae3cfb71..67c7397c 100644
--- a/whisper/mlx_whisper/version.py
+++ b/whisper/mlx_whisper/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.2.0"
+__version__ = "0.3.0"
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
new file mode 100644
index 00000000..464ead18
--- /dev/null
+++ b/whisper/mlx_whisper/writers.py
@@ -0,0 +1,272 @@
+# Copyright © 2024 Apple Inc.
+
+import json
+import os
+import re
+import sys
+import zlib
+from typing import Callable, List, Optional, TextIO
+
+
+def format_timestamp(
+ seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return (
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+ )
+
+
+def get_start(segments: List[dict]) -> Optional[float]:
+ return next(
+ (w["start"] for s in segments for w in s["words"]),
+ segments[0]["start"] if segments else None,
+ )
+
+
+class ResultWriter:
+ extension: str
+
+ def __init__(self, output_dir: str):
+ self.output_dir = output_dir
+
+ def __call__(
+ self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+ ):
+ audio_basename = os.path.basename(audio_path)
+ audio_basename = os.path.splitext(audio_basename)[0]
+ output_path = os.path.join(
+ self.output_dir, audio_basename + "." + self.extension
+ )
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ self.write_result(result, file=f, options=options, **kwargs)
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ raise NotImplementedError
+
+
+class WriteTXT(ResultWriter):
+ extension: str = "txt"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for segment in result["segments"]:
+ print(segment["text"].strip(), file=file, flush=True)
+
+
+class SubtitlesWriter(ResultWriter):
+ always_include_hours: bool
+ decimal_marker: str
+
+ def iterate_result(
+ self,
+ result: dict,
+ options: Optional[dict] = None,
+ *,
+ max_line_width: Optional[int] = None,
+ max_line_count: Optional[int] = None,
+ highlight_words: bool = False,
+ max_words_per_line: Optional[int] = None,
+ ):
+ options = options or {}
+ max_line_width = max_line_width or options.get("max_line_width")
+ max_line_count = max_line_count or options.get("max_line_count")
+ highlight_words = highlight_words or options.get("highlight_words", False)
+ max_words_per_line = max_words_per_line or options.get("max_words_per_line")
+ preserve_segments = max_line_count is None or max_line_width is None
+ max_line_width = max_line_width or 1000
+ max_words_per_line = max_words_per_line or 1000
+
+ def iterate_subtitles():
+ line_len = 0
+ line_count = 1
+ # the next subtitle to yield (a list of word timings with whitespace)
+ subtitle: List[dict] = []
+ last: float = get_start(result["segments"]) or 0.0
+ for segment in result["segments"]:
+ chunk_index = 0
+ words_count = max_words_per_line
+ while chunk_index < len(segment["words"]):
+ remaining_words = len(segment["words"]) - chunk_index
+ if max_words_per_line > len(segment["words"]) - chunk_index:
+ words_count = remaining_words
+ for i, original_timing in enumerate(
+ segment["words"][chunk_index : chunk_index + words_count]
+ ):
+ timing = original_timing.copy()
+ long_pause = (
+ not preserve_segments and timing["start"] - last > 3.0
+ )
+ has_room = line_len + len(timing["word"]) <= max_line_width
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+ if (
+ line_len > 0
+ and has_room
+ and not long_pause
+ and not seg_break
+ ):
+ # line continuation
+ line_len += len(timing["word"])
+ else:
+ # new line
+ timing["word"] = timing["word"].strip()
+ if (
+ len(subtitle) > 0
+ and max_line_count is not None
+ and (long_pause or line_count >= max_line_count)
+ or seg_break
+ ):
+ # subtitle break
+ yield subtitle
+ subtitle = []
+ line_count = 1
+ elif line_len > 0:
+ # line break
+ line_count += 1
+ timing["word"] = "\n" + timing["word"]
+ line_len = len(timing["word"].strip())
+ subtitle.append(timing)
+ last = timing["start"]
+ chunk_index += max_words_per_line
+ if len(subtitle) > 0:
+ yield subtitle
+
+ if len(result["segments"]) > 0 and "words" in result["segments"][0]:
+ for subtitle in iterate_subtitles():
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+ subtitle_text = "".join([word["word"] for word in subtitle])
+ if highlight_words:
+ last = subtitle_start
+ all_words = [timing["word"] for timing in subtitle]
+ for i, this_word in enumerate(subtitle):
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
+
+ yield start, end, "".join(
+ [
+ (
+ re.sub(r"^(\s*)(.*)$", r"\1\2", word)
+ if j == i
+ else word
+ )
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
+ else:
+ yield subtitle_start, subtitle_end, subtitle_text
+ else:
+ for segment in result["segments"]:
+ segment_start = self.format_timestamp(segment["start"])
+ segment_end = self.format_timestamp(segment["end"])
+ segment_text = segment["text"].strip().replace("-->", "->")
+ yield segment_start, segment_end, segment_text
+
+ def format_timestamp(self, seconds: float):
+ return format_timestamp(
+ seconds=seconds,
+ always_include_hours=self.always_include_hours,
+ decimal_marker=self.decimal_marker,
+ )
+
+
+class WriteVTT(SubtitlesWriter):
+ extension: str = "vtt"
+ always_include_hours: bool = False
+ decimal_marker: str = "."
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("WEBVTT\n", file=file)
+ for start, end, text in self.iterate_result(result, options, **kwargs):
+ print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteSRT(SubtitlesWriter):
+ extension: str = "srt"
+ always_include_hours: bool = True
+ decimal_marker: str = ","
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options, **kwargs), start=1
+ ):
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+
+class WriteTSV(ResultWriter):
+ """
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+ \t\t
+
+ Using integer milliseconds as start and end times means there's no chance of interference from
+ an environment setting a language encoding that causes the decimal in a floating point number
+ to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+ """
+
+ extension: str = "tsv"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in result["segments"]:
+ print(round(1000 * segment["start"]), file=file, end="\t")
+ print(round(1000 * segment["end"]), file=file, end="\t")
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+ extension: str = "json"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ json.dump(result, file, ensure_ascii=False)
+
+
+def get_writer(
+ output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+ writers = {
+ "txt": WriteTXT,
+ "vtt": WriteVTT,
+ "srt": WriteSRT,
+ "tsv": WriteTSV,
+ "json": WriteJSON,
+ }
+
+ if output_format == "all":
+ all_writers = [writer(output_dir) for writer in writers.values()]
+
+ def write_all(
+ result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for writer in all_writers:
+ writer(result, file, options, **kwargs)
+
+ return write_all
+
+ return writers[output_format](output_dir)
diff --git a/whisper/setup.py b/whisper/setup.py
index c400a547..086f6471 100644
--- a/whisper/setup.py
+++ b/whisper/setup.py
@@ -29,4 +29,9 @@ setup(
packages=find_namespace_packages(),
include_package_data=True,
python_requires=">=3.8",
+ entry_points={
+ "console_scripts": [
+ "mlx_whisper = mlx_whisper.cli:main",
+ ]
+ },
)