From 6850fce80ec91a99296e37edbc8c78e13cb5016f Mon Sep 17 00:00:00 2001 From: Cavit Erginsoy Date: Mon, 20 Jan 2025 21:59:11 +0000 Subject: [PATCH] whisper: add directory transcription support Add support for transcribing all files in a directory recursively. The implementation lets ffmpeg handle file validation instead of filtering by extension. Update README with minimal documentation for directory support. --- whisper/README.md | 5 +++++ whisper/mlx_whisper/cli.py | 44 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/whisper/README.md b/whisper/README.md index cd3bc684..5b6cb4a2 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -31,6 +31,11 @@ mlx_whisper audio_file.mp3 This will make a text file `audio_file.txt` with the results. +You can also transcribe a directory of audio files: +```sh +mlx_whisper path/to/directory/ +``` + Use `-f` to specify the output format and `--model` to specify the model. There are many other supported command line options. To see them all, run `mlx_whisper -h`. diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index 7d08a043..e40466d3 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -12,6 +12,18 @@ from .transcribe import transcribe from .writers import get_writer +def get_media_files(path): + """Get all files in a directory recursively, excluding hidden files.""" + path = pathlib.Path(path) + if not path.is_dir(): + return [] + + return [ + file_path for file_path in path.rglob("*") + if file_path.is_file() and not any(p.startswith(".") for p in file_path.parts) + ] + + def build_parser(): def optional_int(string): return None if string == "None" else int(string) @@ -30,7 +42,7 @@ def build_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") + parser.add_argument("audio", nargs="+", type=str, help="Path(s) to audio file(s) or directories to transcribe") parser.add_argument( "--model", @@ -239,7 +251,35 @@ def main(): output_name = output_name or "content" else: - output_name = output_name or pathlib.Path(audio_obj).stem + path = pathlib.Path(audio_obj) + if path.is_dir(): + media_files = get_media_files(path) + if not media_files: + print(f"No media files found in directory: {path}") + continue + + print(f"Found {len(media_files)} files in directory{path}") + response = input("Continue processing all files in target directory? [Y/n] ").strip() + if response.lower() in ['n', 'no']: + continue + + if args.get("verbose"): + print(f"Processing {len(media_files)} files in {path}...") + + for file_path in media_files: + try: + result = transcribe( + str(file_path), + path_or_hf_repo=path_or_hf_repo, + **args, + ) + writer(result, file_path.stem, **writer_args) + except Exception as e: + traceback.print_exc() + print(f"Skipping {file_path} due to {type(e).__name__}: {str(e)}") + continue + output_name = output_name or path.stem + try: result = transcribe( audio_obj,