mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 08:43:26 +08:00

Add support for transcribing all files in a directory recursively. The implementation lets ffmpeg handle file validation instead of filtering by extension. Update README with minimal documentation for directory support.
297 lines
10 KiB
Python
297 lines
10 KiB
Python
# Copyright © 2024 Apple Inc.
|
||
|
||
import argparse
|
||
import os
|
||
import pathlib
|
||
import traceback
|
||
import warnings
|
||
|
||
from . import audio
|
||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||
from .transcribe import transcribe
|
||
from .writers import get_writer
|
||
|
||
|
||
def get_media_files(path):
    """Get all files in a directory recursively, excluding hidden files.

    A file counts as hidden when its own name, or the name of any directory
    between ``path`` and the file, starts with a dot. Components of ``path``
    itself are deliberately ignored, so targets such as ``../recordings`` or
    ``/home/user/.cache/audio`` are still searched.

    Args:
        path: Directory to search (``str`` or ``pathlib.Path``).

    Returns:
        A sorted list of ``pathlib.Path`` objects, one per non-hidden regular
        file found under ``path``. The list is empty when ``path`` is not a
        directory. No filtering by extension is performed.
    """
    root = pathlib.Path(path)
    if not root.is_dir():
        return []

    def is_hidden(candidate):
        # Only look at the components *below* the search root; the root's own
        # components (e.g. "..", ".cache") say nothing about the file itself.
        relative_parts = candidate.relative_to(root).parts
        return any(part.startswith(".") for part in relative_parts)

    # Sorting makes the processing order deterministic across platforms.
    return sorted(
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and not is_hidden(candidate)
    )
|
||
|
||
|
||
def build_parser():
    """Create the argument parser for the command-line transcription tool.

    Returns:
        An ``argparse.ArgumentParser`` whose help output shows each option's
        default value (``ArgumentDefaultsHelpFormatter``).
    """

    # NOTE: the converter names below are kept as-is because argparse uses a
    # converter's ``__name__`` when it reports an invalid value.
    def optional_int(string):
        # The literal text "None" selects "no value".
        if string == "None":
            return None
        return int(string)

    def optional_float(string):
        # The literal text "None" selects "no value".
        if string == "None":
            return None
        return float(string)

    def str2bool(string):
        # Only the exact spellings "True" and "False" are accepted.
        str2val = {"True": True, "False": False}
        if string not in str2val:
            raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
        return str2val[string]

    # Accept both the language codes and the title-cased language names.
    language_codes = sorted(LANGUAGES.keys())
    language_names = sorted([name.title() for name in TO_LANGUAGE_CODE.keys()])

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = parser.add_argument

    # --- Inputs and model -------------------------------------------------
    add(
        "audio",
        nargs="+",
        type=str,
        help="Path(s) to audio file(s) or directories to transcribe",
    )
    add(
        "--model",
        type=str,
        default="mlx-community/whisper-tiny",
        help="The model directory or hugging face repo",
    )

    # --- Output location and format ---------------------------------------
    add(
        "--output-name",
        type=str,
        default=None,
        help=(
            "The name of transcription/translation output files before "
            "--output-format extensions"
        ),
    )
    add(
        "--output-dir",
        "-o",
        type=str,
        default=".",
        help="Directory to save the outputs",
    )
    add(
        "--output-format",
        "-f",
        type=str,
        default="txt",
        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
        help="Format of the output file",
    )
    add(
        "--verbose",
        type=str2bool,
        default=True,
        help="Whether to print out progress and debug messages",
    )

    # --- Task and language ------------------------------------------------
    add(
        "--task",
        type=str,
        default="transcribe",
        choices=["transcribe", "translate"],
        help="Perform speech recognition ('transcribe') or speech translation ('translate')",
    )
    add(
        "--language",
        type=str,
        default=None,
        choices=language_codes + language_names,
        help="Language spoken in the audio, specify None to auto-detect",
    )

    # --- Decoding options -------------------------------------------------
    add("--temperature", type=float, default=0, help="Temperature for sampling")
    add(
        "--best-of",
        type=optional_int,
        default=5,
        help="Number of candidates when sampling with non-zero temperature",
    )
    add(
        "--patience",
        type=float,
        default=None,
        help="Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
    )
    add(
        "--length-penalty",
        type=float,
        default=None,
        help="Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.",
    )
    add(
        "--suppress-tokens",
        type=str,
        default="-1",
        help="Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
    )
    add(
        "--initial-prompt",
        type=str,
        default=None,
        help="Optional text to provide as a prompt for the first window.",
    )
    add(
        "--condition-on-previous-text",
        type=str2bool,
        default=True,
        help="If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
    )
    add(
        "--fp16",
        type=str2bool,
        default=True,
        help="Whether to perform inference in fp16",
    )

    # --- Failure-detection thresholds -------------------------------------
    add(
        "--compression-ratio-threshold",
        type=optional_float,
        default=2.4,
        help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
    )
    add(
        "--logprob-threshold",
        type=optional_float,
        default=-1.0,
        help="If the average log probability is lower than this value, treat the decoding as failed",
    )
    add(
        "--no-speech-threshold",
        type=optional_float,
        default=0.6,
        help="If the probability of the token is higher than this value the decoding has failed due to `logprob_threshold`, consider the segment as silence",
    )

    # --- Word-level timestamps and subtitle layout ------------------------
    add(
        "--word-timestamps",
        type=str2bool,
        default=False,
        help="Extract word-level timestamps and refine the results based on them",
    )
    add(
        "--prepend-punctuations",
        type=str,
        default="\"'“¿([{-",
        help="If word-timestamps is True, merge these punctuation symbols with the next word",
    )
    add(
        "--append-punctuations",
        type=str,
        default="\"'.。,,!!??::”)]}、",
        help="If word_timestamps is True, merge these punctuation symbols with the previous word",
    )
    add(
        "--highlight-words",
        type=str2bool,
        default=False,
        help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
    )
    add(
        "--max-line-width",
        type=int,
        default=None,
        help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
    )
    add(
        "--max-line-count",
        type=int,
        default=None,
        help="(requires --word_timestamps True) the maximum number of lines in a segment",
    )
    add(
        "--max-words-per-line",
        type=int,
        default=None,
        help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
    )
    add(
        "--hallucination-silence-threshold",
        type=optional_float,
        help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
    )
    add(
        "--clip-timestamps",
        type=str,
        default="0",
        help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
    )
    return parser
|
||
|
||
|
||
def main():
    """Command-line entry point: transcribe every requested input.

    Each positional ``audio`` argument is handled as one of:

    * ``-``: audio is loaded from stdin and written under ``--output-name``
      (or ``content`` when no name was given);
    * a directory: after a confirmation prompt, every file returned by
      ``get_media_files`` is transcribed and written under its own stem;
    * anything else: passed to ``transcribe`` as a single input and written
      under ``--output-name`` (or the path's stem when no name was given).

    A failure while transcribing or writing one input prints the traceback
    and a "Skipping ..." line, then processing continues with the next input.
    """
    parser = build_parser()
    args = vars(parser.parse_args())
    if args["verbose"] is True:
        print(f"Args: {args}")

    # Options consumed here are popped so that whatever remains in ``args``
    # can be forwarded to ``transcribe`` as keyword arguments.
    path_or_hf_repo: str = args.pop("model")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    output_name: str = args.pop("output_name")  # None when not given
    os.makedirs(output_dir, exist_ok=True)

    writer = get_writer(output_format, output_dir)
    word_options = [
        "highlight_words",
        "max_line_count",
        "max_line_width",
        "max_words_per_line",
    ]
    writer_args = {arg: args.pop(arg) for arg in word_options}
    if not args["word_timestamps"]:
        for k, v in writer_args.items():
            if v:
                argop = k.replace("_", "-")
                parser.error(f"--{argop} requires --word-timestamps True")
    if writer_args["max_line_count"] and not writer_args["max_line_width"]:
        warnings.warn("--max-line-count has no effect without --max-line-width")
    if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
        warnings.warn("--max-words-per-line has no effect with --max-line-width")

    # Must be popped before any ``transcribe(**args)`` call below.
    audio_inputs = args.pop("audio")

    def transcribe_and_write(source, name, label):
        # Transcribe one input and write the result as ``name``; a failure is
        # reported (using ``label`` to identify the input) and swallowed so
        # the remaining inputs are still processed.
        try:
            result = transcribe(
                source,
                path_or_hf_repo=path_or_hf_repo,
                **args,
            )
            writer(result, name, **writer_args)
        except Exception as e:
            traceback.print_exc()
            print(f"Skipping {label} due to {type(e).__name__}: {str(e)}")

    for audio_obj in audio_inputs:
        if audio_obj == "-":
            # receive the contents from stdin rather than read a file
            samples = audio.load_audio(from_stdin=True)
            transcribe_and_write(samples, output_name or "content", "<stdin>")
            continue

        path = pathlib.Path(audio_obj)
        if not path.is_dir():
            # The fallback name is resolved per input; reusing the first
            # input's stem would make later inputs overwrite its output.
            transcribe_and_write(audio_obj, output_name or path.stem, audio_obj)
            continue

        media_files = get_media_files(path)
        if not media_files:
            print(f"No media files found in directory: {path}")
            continue

        print(f"Found {len(media_files)} files in directory {path}")
        try:
            response = input(
                "Continue processing all files in target directory? [Y/n] "
            ).strip()
        except EOFError:
            # stdin is closed / not interactive: fall back to the default (Y).
            response = ""
        if response.lower() in ("n", "no"):
            continue

        if args.get("verbose"):
            print(f"Processing {len(media_files)} files in {path}...")

        for file_path in media_files:
            transcribe_and_write(str(file_path), file_path.stem, file_path)
        # The directory itself is never handed to ``transcribe``; the loop
        # simply moves on to the next command-line input.
|
||
|
||
|
||
# Run the command-line interface only when this module is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|