mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 03:19:23 +08:00
refactored to stdin - arg, and output-name template
This commit is contained in:
parent
bb5d7db5d7
commit
b6435dc9cc
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -12,7 +11,7 @@ from .transcribe import transcribe
|
|||||||
from .writers import get_writer
|
from .writers import get_writer
|
||||||
|
|
||||||
|
|
||||||
def build_parser(is_audio_from_stdin=False):
|
def build_parser():
|
||||||
def optional_int(string):
|
def optional_int(string):
|
||||||
return None if string == "None" else int(string)
|
return None if string == "None" else int(string)
|
||||||
|
|
||||||
@ -30,8 +29,7 @@ def build_parser(is_audio_from_stdin=False):
|
|||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_audio_from_stdin:
|
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
@ -40,10 +38,10 @@ def build_parser(is_audio_from_stdin=False):
|
|||||||
help="The model directory or hugging face repo",
|
help="The model directory or hugging face repo",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-name",
|
"--output-name",
|
||||||
type=str,
|
type=str,
|
||||||
default="content",
|
default="{basename}",
|
||||||
help="logical name of audio content received via stdin",
|
help="logical name of transcription/translation output files, before --output-format extensions",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
@ -201,8 +199,7 @@ def build_parser(is_audio_from_stdin=False):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
is_audio_from_stdin = not os.isatty(sys.stdin.fileno())
|
parser = build_parser()
|
||||||
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
|
|
||||||
args = vars(parser.parse_args())
|
args = vars(parser.parse_args())
|
||||||
if args["verbose"] is True:
|
if args["verbose"] is True:
|
||||||
print(f"Args: {args}")
|
print(f"Args: {args}")
|
||||||
@ -210,9 +207,10 @@ def main():
|
|||||||
path_or_hf_repo: str = args.pop("model")
|
path_or_hf_repo: str = args.pop("model")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
|
output_name_template: str = args.pop("output_name")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir, output_name_template)
|
||||||
word_options = [
|
word_options = [
|
||||||
"highlight_words",
|
"highlight_words",
|
||||||
"max_line_count",
|
"max_line_count",
|
||||||
@ -230,23 +228,19 @@ def main():
|
|||||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||||
|
|
||||||
if is_audio_from_stdin:
|
for audio_obj in args.pop("audio"):
|
||||||
audio_list = [audio.load_audio(from_stdin=True)]
|
if audio_obj == "-":
|
||||||
input_name = args.pop("input_name")
|
# receive the contents from stdin rather than read a file
|
||||||
else:
|
audio_obj = audio.load_audio(from_stdin=True)
|
||||||
audio_list = args.pop("audio")
|
output_name_template = "content"
|
||||||
args.pop("input_name")
|
|
||||||
|
|
||||||
for audio_obj in audio_list:
|
|
||||||
try:
|
try:
|
||||||
result = transcribe(
|
result = transcribe(
|
||||||
audio_obj,
|
audio_obj,
|
||||||
path_or_hf_repo=path_or_hf_repo,
|
path_or_hf_repo=path_or_hf_repo,
|
||||||
**args,
|
**args,
|
||||||
)
|
)
|
||||||
if not is_audio_from_stdin:
|
writer(result, audio_obj, **writer_args)
|
||||||
input_name = audio_obj
|
|
||||||
writer(result, input_name, **writer_args)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import zlib
|
|
||||||
from typing import Callable, List, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
|
|
||||||
@ -39,19 +37,26 @@ def get_start(segments: List[dict]) -> Optional[float]:
|
|||||||
class ResultWriter:
|
class ResultWriter:
|
||||||
extension: str
|
extension: str
|
||||||
|
|
||||||
def __init__(self, output_dir: str):
|
def __init__(self, output_dir: str, output_name_template: str):
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
self.output_name_template = output_name_template
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
if isinstance(audio_obj, (str, pathlib.Path)):
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
basename = pathlib.Path(audio_obj).stem
|
||||||
output_path = os.path.join(
|
else:
|
||||||
self.output_dir, audio_basename + "." + self.extension
|
# mx.array, np.ndarray, etc
|
||||||
|
basename = "content"
|
||||||
|
|
||||||
|
output_basename = self.output_name_template.format(basename=basename)
|
||||||
|
|
||||||
|
output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix(
|
||||||
|
f".{self.extension}"
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with output_path.open("wt", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options, **kwargs)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
@ -248,7 +253,7 @@ class WriteJSON(ResultWriter):
|
|||||||
|
|
||||||
|
|
||||||
def get_writer(
|
def get_writer(
|
||||||
output_format: str, output_dir: str
|
output_format: str, output_dir: str, output_name_template: str
|
||||||
) -> Callable[[dict, TextIO, dict], None]:
|
) -> Callable[[dict, TextIO, dict], None]:
|
||||||
writers = {
|
writers = {
|
||||||
"txt": WriteTXT,
|
"txt": WriteTXT,
|
||||||
@ -259,7 +264,9 @@ def get_writer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if output_format == "all":
|
if output_format == "all":
|
||||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
all_writers = [
|
||||||
|
writer(output_dir, output_name_template) for writer in writers.values()
|
||||||
|
]
|
||||||
|
|
||||||
def write_all(
|
def write_all(
|
||||||
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
@ -269,4 +276,4 @@ def get_writer(
|
|||||||
|
|
||||||
return write_all
|
return write_all
|
||||||
|
|
||||||
return writers[output_format](output_dir)
|
return writers[output_format](output_dir, output_name_template)
|
||||||
|
@ -3,40 +3,21 @@
|
|||||||
set -o err_exit
|
set -o err_exit
|
||||||
|
|
||||||
TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
|
TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
|
||||||
|
|
||||||
# when not receiving stdin, check audio arg is required
|
|
||||||
TEST_1="mlx_whisper requires audio position arg when not provided with stdin"
|
|
||||||
if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then
|
|
||||||
echo "[PASS] $TEST_1"
|
|
||||||
else
|
|
||||||
echo "[FAIL] $TEST_1"
|
|
||||||
fi
|
|
||||||
|
|
||||||
TEST_2="mlx_whisper does not require audio position arg when provided with stdin"
|
|
||||||
if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then
|
|
||||||
echo "[PASS] $TEST_2"
|
|
||||||
else
|
|
||||||
echo "[FAIL] $TEST_2"
|
|
||||||
fi
|
|
||||||
|
|
||||||
TEST_3="mlx_whisper accepts optional --input-name arg"
|
|
||||||
if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then
|
|
||||||
echo "[PASS] $TEST_3"
|
|
||||||
else
|
|
||||||
echo "[FAIL] $TEST_3"
|
|
||||||
fi
|
|
||||||
|
|
||||||
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
|
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
|
||||||
|
|
||||||
# the control output - cli called with audio position arg
|
# the control output - cli called with audio position arg
|
||||||
# expected output file name is ls_test.json
|
# expected output file name is ls_test.json
|
||||||
mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False
|
mlx_whisper "$TEST_AUDIO" \
|
||||||
|
--output-dir "$TEST_OUTPUT_DIR" \
|
||||||
|
--output-format all \
|
||||||
|
--output-name '{basename}_transcribed' \
|
||||||
|
--temperature 0 \
|
||||||
|
--verbose=False
|
||||||
|
/bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} | sort
|
||||||
|
|
||||||
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
|
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
|
||||||
# method stdin - output file is content.json (default --input-name is content when not provided)
|
/bin/cat "$TEST_AUDIO" | mlx_whisper - --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
|
||||||
/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
|
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then
|
||||||
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then
|
|
||||||
echo "[PASS] $TEST_STDIN_1"
|
echo "[PASS] $TEST_STDIN_1"
|
||||||
else
|
else
|
||||||
echo "[FAIL] $TEST_STDIN_1"
|
echo "[FAIL] $TEST_STDIN_1"
|
||||||
@ -44,10 +25,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"、
|
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"、
|
||||||
mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
|
mlx_whisper - --output-name '{basename}_transcribed' --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
|
||||||
if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then
|
if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then
|
||||||
echo "[PASS] $TEST_STDIN_2"
|
echo "[PASS] $TEST_STDIN_2"
|
||||||
else
|
else
|
||||||
echo "[FAIL] $TEST_STDIN_2"
|
echo "[FAIL] $TEST_STDIN_2"
|
||||||
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
|
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"
|
||||||
|
Loading…
Reference in New Issue
Block a user