mlx_whisper: add support for audio input from stdin (#1012)

* add support for audio and input name from stdin

* refactored to stdin - arg, and output-name template

* fix bugs, add test coverage

* fix doc to match arg rename

* some nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Anthony Wu 2024-11-04 14:02:13 -08:00 committed by GitHub
parent 3b526f0aa1
commit 4394633ce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 26 deletions

View File

@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest: At its simplest:
``` ```sh
mlx_whisper audio_file.mp3 mlx_whisper audio_file.mp3
``` ```
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run are many other supported command line options. To see them all, run
`mlx_whisper -h`. `mlx_whisper -h`.
You can also pipe audio from another program into `mlx_whisper` via stdin by
passing `-` as the audio argument:
```sh
some-process | mlx_whisper -
```
When reading from stdin, the default output file name is `content.*`. You can
set a different name with the `--output-name` flag.
#### API #### API
Transcribe audio with: Transcribe audio with:

View File

@ -3,7 +3,7 @@
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run
from typing import Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read as mono waveform, resampling as necessary
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
""" """
# This launches a subprocess to decode audio while down-mixing # This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH. # and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off # fmt: off
cmd = [ cmd.extend([
"ffmpeg",
"-nostdin",
"-threads", "0", "-threads", "0",
"-i", file,
"-f", "s16le", "-f", "s16le",
"-ac", "1", "-ac", "1",
"-acodec", "pcm_s16le", "-acodec", "pcm_s16le",
"-ar", str(sr), "-ar", str(sr),
"-" "-"
] ])
# fmt: on # fmt: on
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,9 +2,11 @@
import argparse import argparse
import os import os
import pathlib
import traceback import traceback
import warnings import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe from .transcribe import transcribe
from .writers import get_writer from .writers import get_writer
@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe" parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
)
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx-community/whisper-tiny", default="mlx-community/whisper-tiny",
type=str, type=str,
help="The model directory or hugging face repo", help="The model directory or hugging face repo",
) )
parser.add_argument(
"--output-name",
type=str,
default=None,
help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
"-o", "-o",
@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model") path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width") warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]: if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width") warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try: try:
result = transcribe( result = transcribe(
audio_path, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
writer(result, audio_path, **writer_args) writer(result, output_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import json import json
import os import pathlib
import re import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO from typing import Callable, List, Optional, TextIO
@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir self.output_dir = output_dir
def __call__( def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
): ):
audio_basename = os.path.basename(audio_path) output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
audio_basename = os.path.splitext(audio_basename)[0] f".{self.extension}"
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
) )
with open(output_path, "w", encoding="utf-8") as f: with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs) self.write_result(result, file=f, options=options, **kwargs)
def write_result( def write_result(