add support for audio and input name from stdin

This commit is contained in:
Anthony Wu 2024-10-03 00:34:47 -10:00
parent 9bc53fc210
commit bb5d7db5d7
4 changed files with 108 additions and 19 deletions

View File

@ -25,8 +25,8 @@ pip install mlx-whisper
At its simplest: At its simplest:
``` ```sh
mlx_whisper audio_file.mp3 mlx_whisper audio_file.mp3 # output name will re-use basename of audio file path
``` ```
This will make a text file `audio_file.txt` with the results. This will make a text file `audio_file.txt` with the results.
@ -35,6 +35,20 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run are many other supported command line options. To see them all, run
`mlx_whisper -h`. `mlx_whisper -h`.
Alternatively, you can pipe in the audio content of other programs via stdin,
useful when `mlx_whisper` acts as a composable command line utility.
```sh
# hypothetical demo of audio content via stdin
# default output file name will be content.*
some-process | mlx_whisper
# hypothetical demo of media content via stdin
# use --input-name to name your output artifacts
some-downloader https://some.url/media?id=lecture42 | mlx_whisper --input-name mlx-demo
```
#### API #### API
Transcribe audio with: Transcribe audio with:
@ -103,7 +117,7 @@ python convert.py --help
``` ```
By default, the conversion script will make the directory `mlx_models` By default, the conversion script will make the directory `mlx_models`
and save the converted `weights.npz` and `config.json` there. and save the converted `weights.npz` and `config.json` there.
Each time it is run, `convert.py` will overwrite any model in the provided Each time it is run, `convert.py` will overwrite any model in the provided
path. To save different models, make sure to set `--mlx-path` to a unique path. To save different models, make sure to set `--mlx-path` to a unique

View File

@ -3,7 +3,7 @@
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run
from typing import Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read as mono waveform, resampling as necessary
@ -40,18 +40,20 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
# This launches a subprocess to decode audio while down-mixing # This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH. # and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off # fmt: off
cmd = [ cmd.extend([
"ffmpeg",
"-nostdin",
"-threads", "0", "-threads", "0",
"-i", file,
"-f", "s16le", "-f", "s16le",
"-ac", "1", "-ac", "1",
"-acodec", "pcm_s16le", "-acodec", "pcm_s16le",
"-ar", str(sr), "-ar", str(sr),
"-" "-"
] ])
# fmt: on # fmt: on
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,15 +2,17 @@
import argparse import argparse
import os import os
import sys
import traceback import traceback
import warnings import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe from .transcribe import transcribe
from .writers import get_writer from .writers import get_writer
def build_parser(): def build_parser(is_audio_from_stdin=False):
def optional_int(string): def optional_int(string):
return None if string == "None" else int(string) return None if string == "None" else int(string)
@ -27,15 +29,22 @@ def build_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe" if not is_audio_from_stdin:
) parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx-community/whisper-tiny", default="mlx-community/whisper-tiny",
type=str, type=str,
help="The model directory or hugging face repo", help="The model directory or hugging face repo",
) )
parser.add_argument(
"--input-name",
type=str,
default="content",
help="logical name of audio content received via stdin",
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
"-o", "-o",
@ -192,7 +201,8 @@ def build_parser():
def main(): def main():
parser = build_parser() is_audio_from_stdin = not os.isatty(sys.stdin.fileno())
parser = build_parser(is_audio_from_stdin=is_audio_from_stdin)
args = vars(parser.parse_args()) args = vars(parser.parse_args())
if args["verbose"] is True: if args["verbose"] is True:
print(f"Args: {args}") print(f"Args: {args}")
@ -219,17 +229,27 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width") warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]: if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width") warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
if is_audio_from_stdin:
audio_list = [audio.load_audio(from_stdin=True)]
input_name = args.pop("input_name")
else:
audio_list = args.pop("audio")
args.pop("input_name")
for audio_obj in audio_list:
try: try:
result = transcribe( result = transcribe(
audio_path, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
writer(result, audio_path, **writer_args) if not is_audio_from_stdin:
input_name = audio_obj
writer(result, input_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

53
whisper/test_cli.sh Executable file
View File

@ -0,0 +1,53 @@
#!/bin/zsh -e
set -o err_exit
TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
# when not receiving stdin, check audio arg is required
TEST_1="mlx_whisper requires audio position arg when not provided with stdin"
if mlx_whisper 2>&1 | grep "the following arguments are required: audio" > /dev/null; then
echo "[PASS] $TEST_1"
else
echo "[FAIL] $TEST_1"
fi
TEST_2="mlx_whisper does not require audio position arg when provided with stdin"
if ! (/bin/cat "$TEST_AUDIO" | mlx_whisper --help | /usr/bin/grep "Audio file(s) to transcribe") > /dev/null; then
echo "[PASS] $TEST_2"
else
echo "[FAIL] $TEST_2"
fi
TEST_3="mlx_whisper accepts optional --input-name arg"
if (mlx_whisper --help | /usr/bin/grep "\-\-input-name") > /dev/null; then
echo "[PASS] $TEST_3"
else
echo "[FAIL] $TEST_3"
fi
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)
# the control output - cli called with audio position arg
# expected output file name is ls_test.json
mlx_whisper "$TEST_AUDIO" --output-dir "$TEST_OUTPUT_DIR" --output-format all --temperature 0 --verbose=False
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
# method stdin - output file is content.json (default --input-name is content when not provided)
/bin/cat "$TEST_AUDIO" | mlx_whisper --output-dir "$TEST_OUTPUT_DIR" --output-format json --temperature 0 --verbose=False
if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test.json"; then
echo "[PASS] $TEST_STDIN_1"
else
echo "[FAIL] $TEST_STDIN_1"
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
fi
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"
mlx_whisper --input-name stdin_test_2 --output-dir "$TEST_OUTPUT_DIR" --output-format tsv --temperature 0 --verbose=False < "$TEST_AUDIO"
if diff "${TEST_OUTPUT_DIR}/stdin_test_2.tsv" "${TEST_OUTPUT_DIR}/ls_test.tsv"; then
echo "[PASS] $TEST_STDIN_2"
else
echo "[FAIL] $TEST_STDIN_2"
echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
fi