refactored to stdin - arg, and output-name template

This commit is contained in:
Anthony Wu 2024-10-03 15:03:31 -10:00
parent bb5d7db5d7
commit b6435dc9cc
3 changed files with 47 additions and 63 deletions

View File

@ -2,7 +2,6 @@
import argparse
import os
import sys
import traceback
import warnings
@ -12,7 +11,7 @@ from .transcribe import transcribe
from .writers import get_writer
def build_parser(is_audio_from_stdin=False):
def build_parser():
def optional_int(string):
return None if string == "None" else int(string)
@ -30,8 +29,7 @@ def build_parser(is_audio_from_stdin=False):
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
if not is_audio_from_stdin:
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument(
"--model",
@ -40,10 +38,10 @@ def build_parser(is_audio_from_stdin=False):
help="The model directory or hugging face repo",
)
parser.add_argument(
"--input-name",
"--output-name",
type=str,
default="content",
help="logical name of audio content received via stdin",
default="{basename}",
help="logical name of transcription/translation output files, before --output-format extensions",
)
parser.add_argument(
"--output-dir",
@ -201,8 +199,7 @@ def build_parser(is_audio_from_stdin=False):
def main():
is_audio_from_stdin = not os.isatty(sys.stdin.fileno())
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
parser = build_parser()
args = vars(parser.parse_args())
if args["verbose"] is True:
print(f"Args: {args}")
@ -210,9 +207,10 @@ def main():
path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
output_name_template: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir)
writer = get_writer(output_format, output_dir, output_name_template)
word_options = [
"highlight_words",
"max_line_count",
@ -230,23 +228,19 @@ def main():
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
if is_audio_from_stdin:
audio_list = [audio.load_audio(from_stdin=True)]
input_name = args.pop("input_name")
else:
audio_list = args.pop("audio")
args.pop("input_name")
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name_template = "content"
for audio_obj in audio_list:
try:
result = transcribe(
audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
if not is_audio_from_stdin:
input_name = audio_obj
writer(result, input_name, **writer_args)
writer(result, audio_obj, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc.
import json
import os
import pathlib
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO
@ -39,19 +37,26 @@ def get_start(segments: List[dict]) -> Optional[float]:
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
def __init__(self, output_dir: str, output_name_template: str):
self.output_dir = output_dir
self.output_name_template = output_name_template
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
if isinstance(audio_obj, (str, pathlib.Path)):
basename = pathlib.Path(audio_obj).stem
else:
# mx.array, np.ndarray, etc
basename = "content"
output_basename = self.output_name_template.format(basename=basename)
output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix(
f".{self.extension}"
)
with open(output_path, "w", encoding="utf-8") as f:
with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(
@ -248,7 +253,7 @@ class WriteJSON(ResultWriter):
def get_writer(
output_format: str, output_dir: str
output_format: str, output_dir: str, output_name_template: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
@ -259,7 +264,9 @@ def get_writer(
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
all_writers = [
writer(output_dir, output_name_template) for writer in writers.values()
]
def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
@ -269,4 +276,4 @@ def get_writer(
return write_all
return writers[output_format](output_dir)
return writers[output_format](output_dir, output_name_template)

View File

@ -3,40 +3,21 @@
set -o err_exit
TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
# when not receiving stdin, check audio arg is required
TEST_1="mlx_whisper requires audio position arg when not provided with stdin"
if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then
echo "[PASS] $TEST_1"
else
echo "[FAIL] $TEST_1"
fi
TEST_2="mlx_whisper does not require audio position arg when provided with stdin"
if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then
echo "[PASS] $TEST_2"
else
echo "[FAIL] $TEST_2"
fi
TEST_3="mlx_whisper accepts optional --input-name arg"
if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then
echo "[PASS] $TEST_3"
else
echo "[FAIL] $TEST_3"
fi
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
# the control output - cli called with audio position arg
# expected output file name is ls_test.json
mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False
mlx_whisper "$TEST_AUDIO" \
--output-dir "$TEST_OUTPUT_DIR" \
--output-format all \
--output-name '{basename}_transcribed' \
--temperature 0 \
--verbose=False
/bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} | sort
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
# method stdin - output file is content.json (default --input-name is content when not provided)
/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then
/bin/cat "$TEST_AUDIO" | mlx_whisper - --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then
echo "[PASS] $TEST_STDIN_1"
else
echo "[FAIL] $TEST_STDIN_1"
@ -44,10 +25,12 @@ else
fi
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"
mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then
mlx_whisper - --output-name '{basename}_transcribed' --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then
echo "[PASS] $TEST_STDIN_2"
else
echo "[FAIL] $TEST_STDIN_2"
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
fi
echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"