Refactored stdin input to use the "-" audio argument, and added an --output-name template

This commit is contained in:
Anthony Wu 2024-10-03 15:03:31 -10:00
parent bb5d7db5d7
commit b6435dc9cc
3 changed files with 47 additions and 63 deletions

View File

@ -2,7 +2,6 @@
import argparse import argparse
import os import os
import sys
import traceback import traceback
import warnings import warnings
@ -12,7 +11,7 @@ from .transcribe import transcribe
from .writers import get_writer from .writers import get_writer
def build_parser(is_audio_from_stdin=False): def build_parser():
def optional_int(string): def optional_int(string):
return None if string == "None" else int(string) return None if string == "None" else int(string)
@ -30,8 +29,7 @@ def build_parser(is_audio_from_stdin=False):
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
if not is_audio_from_stdin: parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument( parser.add_argument(
"--model", "--model",
@ -40,10 +38,10 @@ def build_parser(is_audio_from_stdin=False):
help="The model directory or hugging face repo", help="The model directory or hugging face repo",
) )
parser.add_argument( parser.add_argument(
"--input-name", "--output-name",
type=str, type=str,
default="content", default="{basename}",
help="logical name of audio content received via stdin", help="logical name of transcription/translation output files, before --output-format extensions",
) )
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
@ -201,8 +199,7 @@ def build_parser(is_audio_from_stdin=False):
def main(): def main():
is_audio_from_stdin = not os.isatty(sys.stdin.fileno()) parser = build_parser()
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
if args["verbose"] is True: if args["verbose"] is True:
print(f"Args: {args}") print(f"Args: {args}")
@ -210,9 +207,10 @@ def main():
path_or_hf_repo: str = args.pop("model") path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
output_name_template: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir, output_name_template)
word_options = [ word_options = [
"highlight_words", "highlight_words",
"max_line_count", "max_line_count",
@ -230,23 +228,19 @@ def main():
if writer_args["max_words_per_line"] and writer_args["max_line_width"]: if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width") warnings.warn("--max-words-per-line has no effect with --max-line-width")
if is_audio_from_stdin: for audio_obj in args.pop("audio"):
audio_list = [audio.load_audio(from_stdin=True)] if audio_obj == "-":
input_name = args.pop("input_name") # receive the contents from stdin rather than read a file
else: audio_obj = audio.load_audio(from_stdin=True)
audio_list = args.pop("audio") output_name_template = "content"
args.pop("input_name")
for audio_obj in audio_list:
try: try:
result = transcribe( result = transcribe(
audio_obj, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
if not is_audio_from_stdin: writer(result, audio_obj, **writer_args)
input_name = audio_obj
writer(result, input_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import json import json
import os import pathlib
import re import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO from typing import Callable, List, Optional, TextIO
@ -39,19 +37,26 @@ def get_start(segments: List[dict]) -> Optional[float]:
class ResultWriter: class ResultWriter:
extension: str extension: str
def __init__(self, output_dir: str): def __init__(self, output_dir: str, output_name_template: str):
self.output_dir = output_dir self.output_dir = output_dir
self.output_name_template = output_name_template
def __call__( def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs
): ):
audio_basename = os.path.basename(audio_path) if isinstance(audio_obj, (str, pathlib.Path)):
audio_basename = os.path.splitext(audio_basename)[0] basename = pathlib.Path(audio_obj).stem
output_path = os.path.join( else:
self.output_dir, audio_basename + "." + self.extension # mx.array, np.ndarray, etc
basename = "content"
output_basename = self.output_name_template.format(basename=basename)
output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix(
f".{self.extension}"
) )
with open(output_path, "w", encoding="utf-8") as f: with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs) self.write_result(result, file=f, options=options, **kwargs)
def write_result( def write_result(
@ -248,7 +253,7 @@ class WriteJSON(ResultWriter):
def get_writer( def get_writer(
output_format: str, output_dir: str output_format: str, output_dir: str, output_name_template: str
) -> Callable[[dict, TextIO, dict], None]: ) -> Callable[[dict, TextIO, dict], None]:
writers = { writers = {
"txt": WriteTXT, "txt": WriteTXT,
@ -259,7 +264,9 @@ def get_writer(
} }
if output_format == "all": if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()] all_writers = [
writer(output_dir, output_name_template) for writer in writers.values()
]
def write_all( def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
@ -269,4 +276,4 @@ def get_writer(
return write_all return write_all
return writers[output_format](output_dir) return writers[output_format](output_dir, output_name_template)

View File

@ -3,40 +3,21 @@
set -o err_exit set -o err_exit
TEST_AUDIO="mlx_whisper/assets/ls_test.flac" TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
# when not receiving stdin, check audio arg is required
TEST_1="mlx_whisper requires audio position arg when not provided with stdin"
if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then
echo "[PASS] $TEST_1"
else
echo "[FAIL] $TEST_1"
fi
TEST_2="mlx_whisper does not require audio position arg when provided with stdin"
if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then
echo "[PASS] $TEST_2"
else
echo "[FAIL] $TEST_2"
fi
TEST_3="mlx_whisper accepts optional --input-name arg"
if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then
echo "[PASS] $TEST_3"
else
echo "[FAIL] $TEST_3"
fi
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test) TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
# the control output - cli called with audio position arg # the control output - cli called with audio position arg
# expected output file name is ls_test.json # expected output file name is ls_test.json
mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False mlx_whisper "$TEST_AUDIO" \
--output-dir "$TEST_OUTPUT_DIR" \
--output-format all \
--output-name '{basename}_transcribed' \
--temperature 0 \
--verbose=False
/bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} | sort
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content" TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
# method stdin - output file is content.json (default --input-name is content when not provided) /bin/cat "$TEST_AUDIO" | mlx_whisper - --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then
echo "[PASS] $TEST_STDIN_1" echo "[PASS] $TEST_STDIN_1"
else else
echo "[FAIL] $TEST_STDIN_1" echo "[FAIL] $TEST_STDIN_1"
@ -44,10 +25,12 @@ else
fi fi
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file" TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"
mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO" mlx_whisper - --output-name '{basename}_transcribed' --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then
echo "[PASS] $TEST_STDIN_2" echo "[PASS] $TEST_STDIN_2"
else else
echo "[FAIL] $TEST_STDIN_2" echo "[FAIL] $TEST_STDIN_2"
echo "Check unexpected output in ${TEST_OUTPUT_DIR}" echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
fi fi
echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"