add support for audio and input name from stdin

This commit is contained in:
Anthony Wu
2024-10-03 00:34:47 -10:00
parent 9bc53fc210
commit bb5d7db5d7
4 changed files with 108 additions and 19 deletions

View File

@@ -3,7 +3,7 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Union
from typing import Optional, Union
import mlx.core as mx
import numpy as np
@@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
"""
Open an audio file and read as mono waveform, resampling as necessary
@@ -40,18 +40,20 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
cmd.extend([
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
])
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout

View File

@@ -2,15 +2,17 @@
import argparse
import os
import sys
import traceback
import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
def build_parser():
def build_parser(is_audio_from_stdin=False):
def optional_int(string):
return None if string == "None" else int(string)
@@ -27,15 +29,22 @@ def build_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
)
if not is_audio_from_stdin:
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument(
"--model",
default="mlx-community/whisper-tiny",
type=str,
help="The model directory or hugging face repo",
)
parser.add_argument(
"--input-name",
type=str,
default="content",
help="logical name of audio content received via stdin",
)
parser.add_argument(
"--output-dir",
"-o",
@@ -192,7 +201,8 @@ def build_parser():
def main():
parser = build_parser()
is_audio_from_stdin = not os.isatty(sys.stdin.fileno())
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
args = vars(parser.parse_args())
if args["verbose"] is True:
print(f"Args: {args}")
@@ -219,17 +229,27 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
if is_audio_from_stdin:
audio_list = [audio.load_audio(from_stdin=True)]
input_name = args.pop("input_name")
else:
audio_list = args.pop("audio")
args.pop("input_name")
for audio_obj in audio_list:
try:
result = transcribe(
audio_path,
audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
writer(result, audio_path, **writer_args)
if not is_audio_from_stdin:
input_name = audio_obj
writer(result, input_name, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":