mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
mlx_whisper: add support for audio input from stdin (#1012)
* add support for audio and input name from stdin * refactored to stdin - arg, and output-name template * fix bugs, add test coverage * fix doc to match arg rename * some nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
3b526f0aa1
commit
4394633ce0
@ -25,7 +25,7 @@ pip install mlx-whisper
|
||||
|
||||
At its simplest:
|
||||
|
||||
```
|
||||
```sh
|
||||
mlx_whisper audio_file.mp3
|
||||
```
|
||||
|
||||
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
|
||||
are many other supported command line options. To see them all, run
|
||||
`mlx_whisper -h`.
|
||||
|
||||
You can also pipe the audio content of other programs via stdin:
|
||||
|
||||
```sh
|
||||
some-process | mlx_whisper -
|
||||
```
|
||||
|
||||
The default output file name will be `content.*`. You can specify the name with
|
||||
the `--output-name` flag.
|
||||
|
||||
#### API
|
||||
|
||||
Transcribe audio with:
|
||||
@ -103,7 +112,7 @@ python convert.py --help
|
||||
```
|
||||
|
||||
By default, the conversion script will make the directory `mlx_models`
|
||||
and save the converted `weights.npz` and `config.json` there.
|
||||
and save the converted `weights.npz` and `config.json` there.
|
||||
|
||||
Each time it is run, `convert.py` will overwrite any model in the provided
|
||||
path. To save different models, make sure to set `--mlx-path` to a unique
|
||||
|
@ -3,7 +3,7 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from subprocess import CalledProcessError, run
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
if from_stdin:
|
||||
cmd = ["ffmpeg", "-i", "pipe:0"]
|
||||
else:
|
||||
cmd = ["ffmpeg", "-nostdin", "-i", file]
|
||||
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
cmd.extend([
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
])
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
|
@ -2,9 +2,11 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from . import audio
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||
from .transcribe import transcribe
|
||||
from .writers import get_writer
|
||||
@ -27,15 +29,24 @@ def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
|
||||
)
|
||||
|
||||
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx-community/whisper-tiny",
|
||||
type=str,
|
||||
help="The model directory or hugging face repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of transcription/translation output files before "
|
||||
"--output-format extensions"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
"-o",
|
||||
@ -200,6 +211,7 @@ def main():
|
||||
path_or_hf_repo: str = args.pop("model")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
output_name: str = args.pop("output_name")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
@ -219,17 +231,25 @@ def main():
|
||||
warnings.warn("--max-line-count has no effect without --max-line-width")
|
||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||
for audio_path in args.pop("audio"):
|
||||
|
||||
for audio_obj in args.pop("audio"):
|
||||
if audio_obj == "-":
|
||||
# receive the contents from stdin rather than read a file
|
||||
audio_obj = audio.load_audio(from_stdin=True)
|
||||
|
||||
output_name = output_name or "content"
|
||||
else:
|
||||
output_name = output_name or pathlib.Path(audio_obj).stem
|
||||
try:
|
||||
result = transcribe(
|
||||
audio_path,
|
||||
audio_obj,
|
||||
path_or_hf_repo=path_or_hf_repo,
|
||||
**args,
|
||||
)
|
||||
writer(result, audio_path, **writer_args)
|
||||
writer(result, output_name, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,10 +1,8 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
|
||||
@ -43,15 +41,13 @@ class ResultWriter:
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
|
||||
f".{self.extension}"
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
with output_path.open("wt", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(
|
||||
|
Loading…
Reference in New Issue
Block a user