mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
refactored to stdin - arg, and output-name template
This commit is contained in:
parent
bb5d7db5d7
commit
b6435dc9cc
@ -2,7 +2,6 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
@ -12,7 +11,7 @@ from .transcribe import transcribe
|
||||
from .writers import get_writer
|
||||
|
||||
|
||||
def build_parser(is_audio_from_stdin=False):
|
||||
def build_parser():
|
||||
def optional_int(string):
|
||||
return None if string == "None" else int(string)
|
||||
|
||||
@ -30,8 +29,7 @@ def build_parser(is_audio_from_stdin=False):
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
if not is_audio_from_stdin:
|
||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
@ -40,10 +38,10 @@ def build_parser(is_audio_from_stdin=False):
|
||||
help="The model directory or hugging face repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-name",
|
||||
"--output-name",
|
||||
type=str,
|
||||
default="content",
|
||||
help="logical name of audio content received via stdin",
|
||||
default="{basename}",
|
||||
help="logical name of transcription/translation output files, before --output-format extensions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
@ -201,8 +199,7 @@ def build_parser(is_audio_from_stdin=False):
|
||||
|
||||
|
||||
def main():
|
||||
is_audio_from_stdin = not os.isatty(sys.stdin.fileno())
|
||||
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
|
||||
parser = build_parser()
|
||||
args = vars(parser.parse_args())
|
||||
if args["verbose"] is True:
|
||||
print(f"Args: {args}")
|
||||
@ -210,9 +207,10 @@ def main():
|
||||
path_or_hf_repo: str = args.pop("model")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
output_name_template: str = args.pop("output_name")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
writer = get_writer(output_format, output_dir, output_name_template)
|
||||
word_options = [
|
||||
"highlight_words",
|
||||
"max_line_count",
|
||||
@ -230,23 +228,19 @@ def main():
|
||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||
|
||||
if is_audio_from_stdin:
|
||||
audio_list = [audio.load_audio(from_stdin=True)]
|
||||
input_name = args.pop("input_name")
|
||||
else:
|
||||
audio_list = args.pop("audio")
|
||||
args.pop("input_name")
|
||||
for audio_obj in args.pop("audio"):
|
||||
if audio_obj == "-":
|
||||
# receive the contents from stdin rather than read a file
|
||||
audio_obj = audio.load_audio(from_stdin=True)
|
||||
output_name_template = "content"
|
||||
|
||||
for audio_obj in audio_list:
|
||||
try:
|
||||
result = transcribe(
|
||||
audio_obj,
|
||||
path_or_hf_repo=path_or_hf_repo,
|
||||
**args,
|
||||
)
|
||||
if not is_audio_from_stdin:
|
||||
input_name = audio_obj
|
||||
writer(result, input_name, **writer_args)
|
||||
writer(result, audio_obj, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||
|
@ -1,10 +1,8 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
|
||||
@ -39,19 +37,26 @@ def get_start(segments: List[dict]) -> Optional[float]:
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
def __init__(self, output_dir: str, output_name_template: str):
|
||||
self.output_dir = output_dir
|
||||
self.output_name_template = output_name_template
|
||||
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
if isinstance(audio_obj, (str, pathlib.Path)):
|
||||
basename = pathlib.Path(audio_obj).stem
|
||||
else:
|
||||
# mx.array, np.ndarray, etc
|
||||
basename = "content"
|
||||
|
||||
output_basename = self.output_name_template.format(basename=basename)
|
||||
|
||||
output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix(
|
||||
f".{self.extension}"
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
with output_path.open("wt", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(
|
||||
@ -248,7 +253,7 @@ class WriteJSON(ResultWriter):
|
||||
|
||||
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
output_format: str, output_dir: str, output_name_template: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
@ -259,7 +264,9 @@ def get_writer(
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
all_writers = [
|
||||
writer(output_dir, output_name_template) for writer in writers.values()
|
||||
]
|
||||
|
||||
def write_all(
|
||||
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
@ -269,4 +276,4 @@ def get_writer(
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
return writers[output_format](output_dir, output_name_template)
|
||||
|
@ -3,40 +3,21 @@
|
||||
set -o err_exit
|
||||
|
||||
TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
|
||||
|
||||
# when not receiving stdin, check audio arg is required
|
||||
TEST_1="mlx_whisper requires audio position arg when not provided with stdin"
|
||||
if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then
|
||||
echo "[PASS] $TEST_1"
|
||||
else
|
||||
echo "[FAIL] $TEST_1"
|
||||
fi
|
||||
|
||||
TEST_2="mlx_whisper does not require audio position arg when provided with stdin"
|
||||
if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then
|
||||
echo "[PASS] $TEST_2"
|
||||
else
|
||||
echo "[FAIL] $TEST_2"
|
||||
fi
|
||||
|
||||
TEST_3="mlx_whisper accepts optional --input-name arg"
|
||||
if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then
|
||||
echo "[PASS] $TEST_3"
|
||||
else
|
||||
echo "[FAIL] $TEST_3"
|
||||
fi
|
||||
|
||||
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
|
||||
|
||||
# the control output - cli called with audio position arg
|
||||
# expected output file name is ls_test.json
|
||||
mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False
|
||||
|
||||
mlx_whisper "$TEST_AUDIO" \
|
||||
--output-dir "$TEST_OUTPUT_DIR" \
|
||||
--output-format all \
|
||||
--output-name '{basename}_transcribed' \
|
||||
--temperature 0 \
|
||||
--verbose=False
|
||||
/bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} | sort
|
||||
|
||||
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
|
||||
# method stdin - output file is content.json (default --input-name is content when not provided)
|
||||
/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
|
||||
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then
|
||||
/bin/cat "$TEST_AUDIO" | mlx_whisper - --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
|
||||
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then
|
||||
echo "[PASS] $TEST_STDIN_1"
|
||||
else
|
||||
echo "[FAIL] $TEST_STDIN_1"
|
||||
@ -44,10 +25,12 @@ else
|
||||
fi
|
||||
|
||||
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"、
|
||||
mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
|
||||
if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then
|
||||
mlx_whisper - --output-name '{basename}_transcribed' --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
|
||||
if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then
|
||||
echo "[PASS] $TEST_STDIN_2"
|
||||
else
|
||||
echo "[FAIL] $TEST_STDIN_2"
|
||||
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
|
||||
fi
|
||||
|
||||
echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"
|
||||
|
Loading…
Reference in New Issue
Block a user