From b6435dc9cca44ee46769d789cb779b55ec83da6c Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:03:31 -1000 Subject: [PATCH] refactored to stdin - arg, and output-name template --- whisper/mlx_whisper/cli.py | 34 +++++++++++---------------- whisper/mlx_whisper/writers.py | 33 ++++++++++++++++---------- whisper/test_cli.sh | 43 ++++++++++------------------------ 3 files changed, 47 insertions(+), 63 deletions(-) diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index 98c7fd80..f5304e5d 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -2,7 +2,6 @@ import argparse import os -import sys import traceback import warnings @@ -12,7 +11,7 @@ from .transcribe import transcribe from .writers import get_writer -def build_parser(is_audio_from_stdin=False): +def build_parser(): def optional_int(string): return None if string == "None" else int(string) @@ -30,8 +29,7 @@ def build_parser(is_audio_from_stdin=False): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - if not is_audio_from_stdin: - parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") + parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") parser.add_argument( "--model", @@ -40,10 +38,10 @@ def build_parser(is_audio_from_stdin=False): help="The model directory or hugging face repo", ) parser.add_argument( - "--input-name", + "--output-name", type=str, - default="content", - help="logical name of audio content received via stdin", + default="{basename}", + help="logical name of transcription/translation output files, before --output-format extensions", ) parser.add_argument( "--output-dir", @@ -201,8 +199,7 @@ def build_parser(is_audio_from_stdin=False): def main(): - is_audio_from_stdin = not os.isatty(sys.stdin.fileno()) - parser = build_parser(is_audio_from_stdin=is_audio_from_stdin) + parser = build_parser() args = vars(parser.parse_args()) if args["verbose"] is True: print(f"Args: {args}") @@ -210,9 +207,10 @@ def main(): path_or_hf_repo: str = args.pop("model") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") + output_name_template: str = args.pop("output_name") os.makedirs(output_dir, exist_ok=True) - writer = get_writer(output_format, output_dir) + writer = get_writer(output_format, output_dir, output_name_template) word_options = [ "highlight_words", "max_line_count", @@ -230,23 +228,19 @@ def main(): if writer_args["max_words_per_line"] and writer_args["max_line_width"]: warnings.warn("--max-words-per-line has no effect with --max-line-width") - if is_audio_from_stdin: - audio_list = [audio.load_audio(from_stdin=True)] - input_name = args.pop("input_name") - else: - audio_list = args.pop("audio") - args.pop("input_name") + for audio_obj in args.pop("audio"): + if audio_obj == "-": + # receive the contents from stdin rather than read a file + audio_obj = audio.load_audio(from_stdin=True) + output_name_template = "content" - for audio_obj in audio_list: try: result = transcribe( audio_obj, path_or_hf_repo=path_or_hf_repo, **args, ) - if not is_audio_from_stdin: - input_name = audio_obj - writer(result, input_name, **writer_args) + writer(result, audio_obj, **writer_args) except Exception as e: traceback.print_exc() print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py index 464ead18..cbfe1f66 100644 --- a/whisper/mlx_whisper/writers.py +++ b/whisper/mlx_whisper/writers.py @@ -1,10 +1,8 @@ # Copyright © 2024 Apple Inc. import json -import os +import pathlib import re -import sys -import zlib from typing import Callable, List, Optional, TextIO @@ -39,19 +37,26 @@ def get_start(segments: List[dict]) -> Optional[float]: class ResultWriter: extension: str - def __init__(self, output_dir: str): + def __init__(self, output_dir: str, output_name_template: str): self.output_dir = output_dir + self.output_name_template = output_name_template def __call__( - self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs ): - audio_basename = os.path.basename(audio_path) - audio_basename = os.path.splitext(audio_basename)[0] - output_path = os.path.join( - self.output_dir, audio_basename + "." + self.extension + if isinstance(audio_obj, (str, pathlib.Path)): + basename = pathlib.Path(audio_obj).stem + else: + # mx.array, np.ndarray, etc + basename = "content" + + output_basename = self.output_name_template.format(basename=basename) + + output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix( + f".{self.extension}" ) - with open(output_path, "w", encoding="utf-8") as f: + with output_path.open("wt", encoding="utf-8") as f: self.write_result(result, file=f, options=options, **kwargs) def write_result( @@ -248,7 +253,7 @@ class WriteJSON(ResultWriter): def get_writer( - output_format: str, output_dir: str + output_format: str, output_dir: str, output_name_template: str ) -> Callable[[dict, TextIO, dict], None]: writers = { "txt": WriteTXT, @@ -259,7 +264,9 @@ def get_writer( } if output_format == "all": - all_writers = [writer(output_dir) for writer in writers.values()] + all_writers = [ + writer(output_dir, output_name_template) for writer in writers.values() + ] def write_all( result: dict, file: TextIO, options: Optional[dict] = None, **kwargs @@ -269,4 +276,4 @@ def get_writer( return write_all - return writers[output_format](output_dir) + return writers[output_format](output_dir, output_name_template) diff --git a/whisper/test_cli.sh b/whisper/test_cli.sh index 78a12a3f..f3b9b7c9 100755 --- a/whisper/test_cli.sh +++ b/whisper/test_cli.sh @@ -3,40 +3,21 @@ set -o err_exit TEST_AUDIO="mlx_whisper/assets/ls_test.flac" - -# when not receiving stdin, check audio arg is required -TEST_1="mlx_whisper requires audio position arg when not provided with stdin" -if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then - echo "[PASS] $TEST_1" -else - echo "[FAIL] $TEST_1" -fi - -TEST_2="mlx_whisper does not require audio position arg when provided with stdin" -if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then - echo "[PASS] $TEST_2" -else - echo "[FAIL] $TEST_2" -fi - -TEST_3="mlx_whisper accepts optional --input-name arg" -if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then - echo "[PASS] $TEST_3" -else - echo "[FAIL] $TEST_3" -fi - TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test) # the control output - cli called with audio position arg # expected output file name is ls_test.json -mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False - +mlx_whisper "$TEST_AUDIO" \ + --output-dir "$TEST_OUTPUT_DIR" \ + --output-format all \ + --output-name '{basename}_transcribed' \ + --temperature 0 \ + --verbose=False +/bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} | sort TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content" -# method stdin - output file is content.json (default --input-name is content when not provided) -/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False -if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then +/bin/cat "$TEST_AUDIO" | mlx_whisper - --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False +if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then echo "[PASS] $TEST_STDIN_1" else echo "[FAIL] $TEST_STDIN_1" @@ -44,10 +25,12 @@ else fi TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"、 -mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO" -if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then +mlx_whisper - --output-name '{basename}_transcribed' --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO" +if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then echo "[PASS] $TEST_STDIN_2" else echo "[FAIL] $TEST_STDIN_2" echo "Check unexpected output in ${TEST_OUTPUT_DIR}" fi + +echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"