From 6775d6cb3f9426c13219ccd0dd854bbe899630a9 Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 2 May 2024 00:00:02 +0800 Subject: [PATCH] Whisper: Add pip distribution configuration to support pip installations. (#739) * Whisper: rename whisper to mlx_whisper * Whisper: add setup.py config for publish * Whisper: add assets data to setup config * Whisper: pre-commit for setup.py * Whisper: Update README.md * Whisper: Update README.md * nits * fix package data * nit in readme --------- Co-authored-by: Awni Hannun --- llms/setup.py | 2 + whisper/MANIFEST.in | 4 + whisper/README.md | 99 ++++++++++-------- whisper/benchmark.py | 19 +--- whisper/convert.py | 7 +- whisper/{whisper => mlx_whisper}/__init__.py | 3 +- .../assets/download_alice.sh | 0 .../assets/gpt2.tiktoken | 0 .../assets/ls_test.flac | Bin .../assets/mel_filters.npz | Bin .../assets/multilingual.tiktoken | 0 whisper/{whisper => mlx_whisper}/audio.py | 0 whisper/{whisper => mlx_whisper}/decoding.py | 0 .../{whisper => mlx_whisper}/load_models.py | 0 whisper/{ => mlx_whisper}/requirements.txt | 2 +- whisper/{whisper => mlx_whisper}/timing.py | 0 whisper/{whisper => mlx_whisper}/tokenizer.py | 0 .../{whisper => mlx_whisper}/torch_whisper.py | 0 .../{whisper => mlx_whisper}/transcribe.py | 0 whisper/mlx_whisper/version.py | 3 + whisper/{whisper => mlx_whisper}/whisper.py | 0 whisper/setup.py | 32 ++++++ whisper/test.py | 19 ++-- 23 files changed, 116 insertions(+), 74 deletions(-) create mode 100644 whisper/MANIFEST.in rename whisper/{whisper => mlx_whisper}/__init__.py (53%) rename whisper/{whisper => mlx_whisper}/assets/download_alice.sh (100%) rename whisper/{whisper => mlx_whisper}/assets/gpt2.tiktoken (100%) rename whisper/{whisper => mlx_whisper}/assets/ls_test.flac (100%) rename whisper/{whisper => mlx_whisper}/assets/mel_filters.npz (100%) rename whisper/{whisper => mlx_whisper}/assets/multilingual.tiktoken (100%) rename whisper/{whisper => mlx_whisper}/audio.py (100%) rename whisper/{whisper => 
mlx_whisper}/decoding.py (100%) rename whisper/{whisper => mlx_whisper}/load_models.py (100%) rename whisper/{ => mlx_whisper}/requirements.txt (81%) rename whisper/{whisper => mlx_whisper}/timing.py (100%) rename whisper/{whisper => mlx_whisper}/tokenizer.py (100%) rename whisper/{whisper => mlx_whisper}/torch_whisper.py (100%) rename whisper/{whisper => mlx_whisper}/transcribe.py (100%) create mode 100644 whisper/mlx_whisper/version.py rename whisper/{whisper => mlx_whisper}/whisper.py (100%) create mode 100644 whisper/setup.py diff --git a/llms/setup.py b/llms/setup.py index 26e1a3b8..c4e5d075 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + import sys from pathlib import Path diff --git a/whisper/MANIFEST.in b/whisper/MANIFEST.in new file mode 100644 index 00000000..05db35a0 --- /dev/null +++ b/whisper/MANIFEST.in @@ -0,0 +1,4 @@ +include mlx_whisper/requirements.txt +include mlx_whisper/assets/mel_filters.npz +include mlx_whisper/assets/multilingual.tiktoken +include mlx_whisper/assets/gpt2.tiktoken diff --git a/whisper/README.md b/whisper/README.md index 4d41a69f..2805e899 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -6,12 +6,6 @@ parameters.[^1] ### Setup -First, install the dependencies: - -``` -pip install -r requirements.txt -``` - Install [`ffmpeg`](https://ffmpeg.org/): ``` @@ -19,19 +13,72 @@ Install [`ffmpeg`](https://ffmpeg.org/): brew install ffmpeg ``` +Install the `mlx-whisper` package with: + +``` +pip install mlx-whisper +``` + +### Run + +Transcribe audio with: + +```python +import mlx_whisper + +text = mlx_whisper.transcribe(speech_file)["text"] +``` + +The default model is "mlx-community/whisper-tiny". Choose the model by +setting `path_or_hf_repo`. For example: + +```python +result = mlx_whisper.transcribe(speech_file, path_or_hf_repo="models/large") +``` + +This will load the model contained in `models/large`. 
The `path_or_hf_repo` can +also point to an MLX-style Whisper model on the Hugging Face Hub. In this case, +the model will be automatically downloaded. A [collection of pre-converted +Whisper +models](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc) +is in the Hugging Face MLX Community. + +The `transcribe` function also supports word-level timestamps. You can generate +these with: + +```python +output = mlx_whisper.transcribe(speech_file, word_timestamps=True) +print(output["segments"][0]["words"]) +``` + +To see more transcription options use: + +``` +>>> help(mlx_whisper.transcribe) +``` + +### Converting models + > [!TIP] > Skip the conversion step by using pre-converted checkpoints from the Hugging > Face Hub. There are a few available in the [MLX > Community](https://huggingface.co/mlx-community) organization. -To convert a model, first download the Whisper PyTorch checkpoint and convert -the weights to the MLX format. For example, to convert the `tiny` model use: +To convert a model, first clone the MLX Examples repo: + +``` +git clone https://github.com/ml-explore/mlx-examples.git +``` + +Then run `convert.py` from `mlx-examples/whisper`. For example, to convert the +`tiny` model use: ``` python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny ``` -Note you can also convert a local PyTorch checkpoint which is in the original OpenAI format. +Note you can also convert a local PyTorch checkpoint which is in the original +OpenAI format. To generate a 4-bit quantized model, use `-q`. For a full list of options: @@ -53,38 +100,4 @@ python convert.py --torch-name-or-path ${model} --dtype float32 --mlx-path mlx_m python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits ``` -### Run - -Transcribe audio with: - -```python -import whisper - -text = whisper.transcribe(speech_file)["text"] -``` - -Choose the model by setting `path_or_hf_repo`. 
For example: - -```python -result = whisper.transcribe(speech_file, path_or_hf_repo="models/large") -``` - -This will load the model contained in `models/large`. The `path_or_hf_repo` -can also point to an MLX-style Whisper model on the Hugging Face Hub. In this -case, the model will be automatically downloaded. - -The `transcribe` function also supports word-level timestamps. You can generate -these with: - -```python -output = whisper.transcribe(speech_file, word_timestamps=True) -print(output["segments"][0]["words"]) -``` - -To see more transcription options use: - -``` ->>> help(whisper.transcribe) -``` - [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details. diff --git a/whisper/benchmark.py b/whisper/benchmark.py index 46fe4bd8..8101e343 100644 --- a/whisper/benchmark.py +++ b/whisper/benchmark.py @@ -1,14 +1,12 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import argparse import os -import subprocess import time import mlx.core as mx +from mlx_whisper import audio, decoding, load_models, transcribe -from whisper import audio, decoding, load_models, transcribe - -audio_file = "whisper/assets/ls_test.flac" +audio_file = "mlx_whisper/assets/ls_test.flac" def parse_arguments(): @@ -83,16 +81,7 @@ if __name__ == "__main__": print(f"\nFeature time {feat_time:.3f}") for model_name in models: - model_path = f"{args.mlx_dir}/{model_name}" - if not os.path.exists(model_path): - print( - f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. 
Lauching conversion" - ) - subprocess.run( - f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", - shell=True, - ) - + model_path = f"mlx-community/whisper-{model_name}-mlx" print(f"\nModel: {model_name.upper()}") tokens = mx.array( [ diff --git a/whisper/convert.py b/whisper/convert.py index fd208184..37825d6c 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import argparse import copy @@ -16,11 +16,10 @@ import mlx.nn as nn import numpy as np import torch from mlx.utils import tree_flatten, tree_map, tree_unflatten +from mlx_whisper import torch_whisper +from mlx_whisper.whisper import ModelDimensions, Whisper from tqdm import tqdm -from whisper import torch_whisper -from whisper.whisper import ModelDimensions, Whisper - _VALID_DTYPES = {"float16", "float32"} _MODELS = { diff --git a/whisper/whisper/__init__.py b/whisper/mlx_whisper/__init__.py similarity index 53% rename from whisper/whisper/__init__.py rename to whisper/mlx_whisper/__init__.py index e234711c..e6de0858 100644 --- a/whisper/whisper/__init__.py +++ b/whisper/mlx_whisper/__init__.py @@ -1,4 +1,5 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. from . 
import audio, decoding, load_models from .transcribe import transcribe +from .version import __version__ diff --git a/whisper/whisper/assets/download_alice.sh b/whisper/mlx_whisper/assets/download_alice.sh similarity index 100% rename from whisper/whisper/assets/download_alice.sh rename to whisper/mlx_whisper/assets/download_alice.sh diff --git a/whisper/whisper/assets/gpt2.tiktoken b/whisper/mlx_whisper/assets/gpt2.tiktoken similarity index 100% rename from whisper/whisper/assets/gpt2.tiktoken rename to whisper/mlx_whisper/assets/gpt2.tiktoken diff --git a/whisper/whisper/assets/ls_test.flac b/whisper/mlx_whisper/assets/ls_test.flac similarity index 100% rename from whisper/whisper/assets/ls_test.flac rename to whisper/mlx_whisper/assets/ls_test.flac diff --git a/whisper/whisper/assets/mel_filters.npz b/whisper/mlx_whisper/assets/mel_filters.npz similarity index 100% rename from whisper/whisper/assets/mel_filters.npz rename to whisper/mlx_whisper/assets/mel_filters.npz diff --git a/whisper/whisper/assets/multilingual.tiktoken b/whisper/mlx_whisper/assets/multilingual.tiktoken similarity index 100% rename from whisper/whisper/assets/multilingual.tiktoken rename to whisper/mlx_whisper/assets/multilingual.tiktoken diff --git a/whisper/whisper/audio.py b/whisper/mlx_whisper/audio.py similarity index 100% rename from whisper/whisper/audio.py rename to whisper/mlx_whisper/audio.py diff --git a/whisper/whisper/decoding.py b/whisper/mlx_whisper/decoding.py similarity index 100% rename from whisper/whisper/decoding.py rename to whisper/mlx_whisper/decoding.py diff --git a/whisper/whisper/load_models.py b/whisper/mlx_whisper/load_models.py similarity index 100% rename from whisper/whisper/load_models.py rename to whisper/mlx_whisper/load_models.py diff --git a/whisper/requirements.txt b/whisper/mlx_whisper/requirements.txt similarity index 81% rename from whisper/requirements.txt rename to whisper/mlx_whisper/requirements.txt index cf9c92aa..ca9c148c 100644 --- 
a/whisper/requirements.txt +++ b/whisper/mlx_whisper/requirements.txt @@ -4,6 +4,6 @@ numpy torch tqdm more-itertools -tiktoken==0.3.3 +tiktoken huggingface_hub scipy diff --git a/whisper/whisper/timing.py b/whisper/mlx_whisper/timing.py similarity index 100% rename from whisper/whisper/timing.py rename to whisper/mlx_whisper/timing.py diff --git a/whisper/whisper/tokenizer.py b/whisper/mlx_whisper/tokenizer.py similarity index 100% rename from whisper/whisper/tokenizer.py rename to whisper/mlx_whisper/tokenizer.py diff --git a/whisper/whisper/torch_whisper.py b/whisper/mlx_whisper/torch_whisper.py similarity index 100% rename from whisper/whisper/torch_whisper.py rename to whisper/mlx_whisper/torch_whisper.py diff --git a/whisper/whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py similarity index 100% rename from whisper/whisper/transcribe.py rename to whisper/mlx_whisper/transcribe.py diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py new file mode 100644 index 00000000..87ee07a7 --- /dev/null +++ b/whisper/mlx_whisper/version.py @@ -0,0 +1,3 @@ +# Copyright © 2023-2024 Apple Inc. + +__version__ = "0.1.0" diff --git a/whisper/whisper/whisper.py b/whisper/mlx_whisper/whisper.py similarity index 100% rename from whisper/whisper/whisper.py rename to whisper/mlx_whisper/whisper.py diff --git a/whisper/setup.py b/whisper/setup.py new file mode 100644 index 00000000..eaab22d4 --- /dev/null +++ b/whisper/setup.py @@ -0,0 +1,32 @@ +# Copyright © 2024 Apple Inc. 
+ +import sys +from pathlib import Path + +from setuptools import find_packages, setup + +package_dir = Path(__file__).parent / "mlx_whisper" + +with open(package_dir / "requirements.txt") as fid: + requirements = [l.strip() for l in fid.readlines()] + +sys.path.append(str(package_dir)) + +from version import __version__ + +setup( + name="mlx-whisper", + version=__version__, + description="OpenAI Whisper on Apple silicon with MLX and the Hugging Face Hub", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + readme="README.md", + author_email="mlx@group.apple.com", + author="MLX Contributors", + url="https://github.com/ml-explore/mlx-examples", + license="MIT", + install_requires=requirements, + packages=find_packages(), + include_package_data=True, + python_requires=">=3.8", +) diff --git a/whisper/test.py b/whisper/test.py index 3be1c27d..ce559251 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. 
import json import os @@ -7,21 +7,20 @@ from dataclasses import asdict from pathlib import Path import mlx.core as mx +import mlx_whisper +import mlx_whisper.audio as audio +import mlx_whisper.decoding as decoding +import mlx_whisper.load_models as load_models import numpy as np import torch from convert import load_torch_model, quantize, torch_to_mlx from mlx.utils import tree_flatten -import whisper -import whisper.audio as audio -import whisper.decoding as decoding -import whisper.load_models as load_models - MODEL_NAME = "tiny" MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32" MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16" MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits" -TEST_AUDIO = "whisper/assets/ls_test.flac" +TEST_AUDIO = "mlx_whisper/assets/ls_test.flac" def _save_model(save_dir, weights, config): @@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase): self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752) def test_transcribe(self): - result = whisper.transcribe( + result = mlx_whisper.transcribe( TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False ) self.assertEqual( @@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase): print("bash path_to_whisper_repo/whisper/assets/download_alice.sh") return - result = whisper.transcribe( + result = mlx_whisper.transcribe( audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False ) self.assertEqual(len(result["text"]), 10920) @@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase): check_segment(result["segments"][73], expected_73) def test_transcribe_word_level_timestamps_confidence_scores(self): - result = whisper.transcribe( + result = mlx_whisper.transcribe( TEST_AUDIO, path_or_hf_repo=MLX_FP16_MODEL_PATH, word_timestamps=True,