mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
mlx_whisper: add support for audio input from stdin (#1012)
* add support for audio and input name from stdin * refactored to stdin - arg, and output-name template * fix bugs, add test coverage * fix doc to match arg rename * some nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2,9 +2,11 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from . import audio
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
from .transcribe import transcribe
|
||||
from .writers import get_writer
|
||||
@@ -27,15 +29,24 @@ def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
|
||||
)
|
||||
|
||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/whisper-tiny",
|
||||
type=str,
|
||||
help="The model directory or hugging face repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of transcription/translation output files before "
|
||||
"--output-format extensions"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
@@ -200,6 +211,7 @@ def main():
|
||||
path_or_hf_repo: str = args.pop("model")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
output_name: str = args.pop("output_name")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@@ -219,17 +231,25 @@ def main():
|
||||
warnings.warn("--max-line-count has no effect without --max-line-width")
|
||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||
for audio_path in args.pop("audio"):
|
||||
|
||||
for audio_obj in args.pop("audio"):
|
||||
if audio_obj == "-":
|
||||
# receive the contents from stdin rather than read a file
|
||||
audio_obj = audio.load_audio(from_stdin=True)
|
||||
|
||||
output_name = output_name or "content"
|
||||
else:
|
||||
output_name = output_name or pathlib.Path(audio_obj).stem
|
||||
try:
|
||||
result = transcribe(
|
||||
audio_path,
|
||||
audio_obj,
|
||||
path_or_hf_repo=path_or_hf_repo,
|
||||
**args,
|
||||
)
|
||||
writer(result, audio_path, **writer_args)
|
||||
writer(result, output_name, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user