some nits

This commit is contained in:
Awni Hannun 2024-11-04 13:59:53 -08:00
parent 266f99a1e7
commit bd6b08e813
5 changed files with 24 additions and 102 deletions

View File

@ -26,7 +26,7 @@ pip install mlx-whisper
At its simplest: At its simplest:
```sh ```sh
mlx_whisper audio_file.mp3 # output name will re-use basename of audio file path mlx_whisper audio_file.mp3
``` ```
This will make a text file `audio_file.txt` with the results. This will make a text file `audio_file.txt` with the results.
@ -35,19 +35,14 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run are many other supported command line options. To see them all, run
`mlx_whisper -h`. `mlx_whisper -h`.
Alternatively, you can pipe in the audio content of other programs via stdin, You can also pipe the audio content of other programs via stdin:
useful when `mlx_whisper` acts as a composable command line utility.
```sh ```sh
# hypothetical demo of audio content via stdin some-process | mlx_whisper -
# default output file name will be content.*
some-process | mlx_whisper
# hypothetical demo of media content via stdin
# use --output-name to name your output artifacts
some-downloader https://some.url/media?id=lecture42 | mlx_whisper --output-name mlx-demo
``` ```
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.
#### API #### API

View File

@ -2,6 +2,7 @@
import argparse import argparse
import os import os
import pathlib
import traceback import traceback
import warnings import warnings
@ -40,8 +41,11 @@ def build_parser():
parser.add_argument( parser.add_argument(
"--output-name", "--output-name",
type=str, type=str,
default="{basename}", default=None,
help="logical name of transcription/translation output files, before --output-format extensions", help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
) )
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
@ -207,10 +211,10 @@ def main():
path_or_hf_repo: str = args.pop("model") path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
output_name_template: str = args.pop("output_name") output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir, output_name_template) writer = get_writer(output_format, output_dir)
word_options = [ word_options = [
"highlight_words", "highlight_words",
"max_line_count", "max_line_count",
@ -233,13 +237,16 @@ def main():
# receive the contents from stdin rather than read a file # receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True) audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try: try:
result = transcribe( result = transcribe(
audio_obj, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
writer(result, audio_obj, **writer_args) writer(result, output_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")

View File

@ -37,22 +37,13 @@ def get_start(segments: List[dict]) -> Optional[float]:
class ResultWriter: class ResultWriter:
extension: str extension: str
def __init__(self, output_dir: str, output_name_template: str): def __init__(self, output_dir: str):
self.output_dir = output_dir self.output_dir = output_dir
self.output_name_template = output_name_template
def __call__( def __call__(
self, result: dict, audio_obj: str, options: Optional[dict] = None, **kwargs self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
): ):
if isinstance(audio_obj, (str, pathlib.Path)): output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
basename = pathlib.Path(audio_obj).stem
else:
# mx.array, np.ndarray, etc
basename = "content"
output_basename = self.output_name_template.format(basename=basename)
output_path = (pathlib.Path(self.output_dir) / output_basename).with_suffix(
f".{self.extension}" f".{self.extension}"
) )
@ -253,7 +244,7 @@ class WriteJSON(ResultWriter):
def get_writer( def get_writer(
output_format: str, output_dir: str, output_name_template: str output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]: ) -> Callable[[dict, TextIO, dict], None]:
writers = { writers = {
"txt": WriteTXT, "txt": WriteTXT,
@ -264,9 +255,7 @@ def get_writer(
} }
if output_format == "all": if output_format == "all":
all_writers = [ all_writers = [writer(output_dir) for writer in writers.values()]
writer(output_dir, output_name_template) for writer in writers.values()
]
def write_all( def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
@ -276,4 +265,4 @@ def get_writer(
return write_all return write_all
return writers[output_format](output_dir, output_name_template) return writers[output_format](output_dir)

View File

@ -1,69 +0,0 @@
#!/bin/zsh -e
#
# Smoke tests for the `mlx_whisper` command line interface.
#
# Each test invokes `mlx_whisper` on the bundled sample audio and then checks
# either that the expected output files exist or that two outputs are
# identical. A `[PASS]` / `[FAIL]` line is printed per check.
#
# Exit status: 0 when every check passes, 1 when at least one check printed
# `[FAIL]`. Because ERR_EXIT is enabled, a failing `mlx_whisper` invocation
# (outside of an `if` condition) aborts the script immediately.
#
# NOTE: the tests are order dependent. The two stdin tests at the bottom diff
# their output against the `ls_test_transcribed.*` files written by the first
# test, so the first test must run before them.
set -o err_exit

TEST_AUDIO="mlx_whisper/assets/ls_test.flac"
# All artifacts go to a fresh temporary directory, which is intentionally left
# in place at the end so the outputs can be inspected by hand.
TEST_OUTPUT_DIR=$(mktemp -d -t mlx_whisper_cli_test)

# Number of checks that reported [FAIL]; drives the final exit status.
# Incremented with an assignment rather than `(( FAILURES++ ))` because the
# latter returns a non-zero status when the old value is 0, which would trip
# ERR_EXIT.
FAILURES=0

# ---------------------------------------------------------------------------
# Test: --output-name applies to every format when --output-format is "all".
# This is also the control output (audio passed as a positional argument) that
# the stdin tests below compare against.
# ---------------------------------------------------------------------------
TEST_OUTPUT_NAME_FOR_ALL="--output-name arg is used for all output formats"
mlx_whisper "$TEST_AUDIO" \
  --output-dir "$TEST_OUTPUT_DIR" \
  --output-format all \
  --output-name '{basename}_transcribed' \
  --temperature 0 \
  --verbose=False

# `ls` fails if any one of the five expected files is missing.
if /bin/ls ${TEST_OUTPUT_DIR}/ls_test_transcribed.{json,srt,tsv,txt,vtt} > /dev/null; then
  echo "[PASS] $TEST_OUTPUT_NAME_FOR_ALL"
else
  echo "[FAIL] $TEST_OUTPUT_NAME_FOR_ALL"
  FAILURES=$(( FAILURES + 1 ))
fi

# ---------------------------------------------------------------------------
# Test: the output name template can encode a varying option value, producing
# one distinct .srt file per --max-words-per-line setting (10, 20, ..., 60).
# ---------------------------------------------------------------------------
TEST_OUTPUT_NAME_TEMPLATE="testing the output name template usage scenario"
for test_val in $(seq 10 10 60); do
  mlx_whisper "$TEST_AUDIO" \
    --output-name "{basename}_mwpl_${test_val}" \
    --output-dir "$TEST_OUTPUT_DIR" \
    --output-format srt \
    --max-words-per-line $test_val \
    --word-timestamps True \
    --verbose=False

  TEST_DESC="testing output name template while varying --max-words-per-line=${test_val}"
  if /bin/ls $TEST_OUTPUT_DIR/ls_test_mwpl_${test_val}.srt > /dev/null; then
    echo "[PASS] $TEST_DESC"
  else
    echo "[FAIL] $TEST_DESC"
    FAILURES=$(( FAILURES + 1 ))
  fi
done

# ---------------------------------------------------------------------------
# Test: audio piped on stdin (`-`) yields the same JSON as the positional-arg
# run above. No --output-name is given, so the file is named content.json.
# ---------------------------------------------------------------------------
TEST_STDIN_1="mlx_whisper produces identical output whether provided audio arg or stdin of same content"
/bin/cat "$TEST_AUDIO" | mlx_whisper - \
  --output-dir "$TEST_OUTPUT_DIR" \
  --output-format json \
  --temperature 0 \
  --verbose=False

if diff "${TEST_OUTPUT_DIR}/content.json" "${TEST_OUTPUT_DIR}/ls_test_transcribed.json"; then
  echo "[PASS] $TEST_STDIN_1"
else
  echo "[FAIL] $TEST_STDIN_1"
  echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
  FAILURES=$(( FAILURES + 1 ))
fi

# ---------------------------------------------------------------------------
# Test: audio redirected on stdin (`cmd < file`) yields the same TSV as the
# positional-arg run above. The template expands to content_transcribed.tsv.
# ---------------------------------------------------------------------------
TEST_STDIN_2="mlx_whisper produces identical output when stdin comes via: cmd < file"
mlx_whisper - \
  --output-name '{basename}_transcribed' \
  --output-dir "$TEST_OUTPUT_DIR" \
  --output-format tsv \
  --temperature 0 \
  --verbose=False < "$TEST_AUDIO"

if diff "${TEST_OUTPUT_DIR}/content_transcribed.tsv" "${TEST_OUTPUT_DIR}/ls_test_transcribed.tsv"; then
  echo "[PASS] $TEST_STDIN_2"
else
  echo "[FAIL] $TEST_STDIN_2"
  echo "Check unexpected output in ${TEST_OUTPUT_DIR}"
  FAILURES=$(( FAILURES + 1 ))
fi

echo "Outputs can be verified in ${TEST_OUTPUT_DIR}"

# Surface failures to the caller; previously the script always exited 0
# because its last command was a successful `echo`.
if [[ $FAILURES -gt 0 ]]; then
  echo "${FAILURES} check(s) failed"
  exit 1
fi