diff --git a/llms/setup.py b/llms/setup.py
index 26e1a3b8..c4e5d075 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 import sys
 from pathlib import Path
diff --git a/whisper/MANIFEST.in b/whisper/MANIFEST.in
new file mode 100644
index 00000000..05db35a0
--- /dev/null
+++ b/whisper/MANIFEST.in
@@ -0,0 +1,4 @@
+include mlx_whisper/requirements.txt
+include mlx_whisper/assets/mel_filters.npz
+include mlx_whisper/assets/multilingual.tiktoken
+include mlx_whisper/assets/gpt2.tiktoken
diff --git a/whisper/README.md b/whisper/README.md
index 4d41a69f..2805e899 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -6,12 +6,6 @@ parameters.[^1]
 
 ### Setup
 
-First, install the dependencies:
-
-```
-pip install -r requirements.txt
-```
-
 Install [`ffmpeg`](https://ffmpeg.org/):
 
 ```
@@ -19,19 +13,72 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 brew install ffmpeg
 ```
 
+Install the `mlx-whisper` package with:
+
+```
+pip install mlx-whisper
+```
+
+### Run
+
+Transcribe audio with:
+
+```python
+import mlx_whisper
+
+text = mlx_whisper.transcribe(speech_file)["text"]
+```
+
+The default model is "mlx-community/whisper-tiny". Choose the model by
+setting `path_or_hf_repo`. For example:
+
+```python
+result = mlx_whisper.transcribe(speech_file, path_or_hf_repo="models/large")
+```
+
+This will load the model contained in `models/large`. The `path_or_hf_repo` can
+also point to an MLX-style Whisper model on the Hugging Face Hub. In this case,
+the model will be automatically downloaded. A [collection of pre-converted
+Whisper models](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc)
+is available in the Hugging Face MLX Community.
+
+The `transcribe` function also supports word-level timestamps. You can generate
+these with:
+
+```python
+output = mlx_whisper.transcribe(speech_file, word_timestamps=True)
+print(output["segments"][0]["words"])
+```
+
+To see more transcription options use:
+
+```
+>>> help(mlx_whisper.transcribe)
+```
+
+### Converting models
+
 > [!TIP]
 > Skip the conversion step by using pre-converted checkpoints from the Hugging
 > Face Hub. There are a few available in the [MLX
 > Community](https://huggingface.co/mlx-community) organization.
 
-To convert a model, first download the Whisper PyTorch checkpoint and convert
-the weights to the MLX format. For example, to convert the `tiny` model use:
+To convert a model, first clone the MLX Examples repo:
+
+```
+git clone https://github.com/ml-explore/mlx-examples.git
+```
+
+Then run `convert.py` from `mlx-examples/whisper`. For example, to convert the
+`tiny` model use:
 
 ```
 python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
 ```
 
-Note you can also convert a local PyTorch checkpoint which is in the original OpenAI format.
+Note that you can also convert a local PyTorch checkpoint which is in the
+original OpenAI format.
 
 To generate a 4-bit quantized model, use `-q`. For a full list of options:
@@ -53,38 +100,4 @@ python convert.py --torch-name-or-path ${model} --dtype float32 --mlx-path mlx_m
 python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits
 ```
 
-### Run
-
-Transcribe audio with:
-
-```python
-import whisper
-
-text = whisper.transcribe(speech_file)["text"]
-```
-
-Choose the model by setting `path_or_hf_repo`. For example:
-
-```python
-result = whisper.transcribe(speech_file, path_or_hf_repo="models/large")
-```
-
-This will load the model contained in `models/large`. The `path_or_hf_repo`
-can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
-case, the model will be automatically downloaded.
-
-The `transcribe` function also supports word-level timestamps. You can generate
-these with:
-
-```python
-output = whisper.transcribe(speech_file, word_timestamps=True)
-print(output["segments"][0]["words"])
-```
-
-To see more transcription options use:
-
-```
->>> help(whisper.transcribe)
-```
-
 [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
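Taken together, the README changes above move installation to pip and document the new import path. As a quick reference, a minimal end-to-end sketch of the documented flow (the Hub repo id is illustrative; any checkpoint from the linked MLX Community collection should work the same way):

```python
import mlx_whisper

# The default model is mlx-community/whisper-tiny; path_or_hf_repo selects a
# local folder or a pre-converted checkpoint on the Hugging Face Hub.
result = mlx_whisper.transcribe(
    "speech.wav",  # any audio file ffmpeg can decode
    path_or_hf_repo="mlx-community/whisper-tiny-mlx",  # illustrative repo id
)
print(result["text"])
```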
diff --git a/whisper/benchmark.py b/whisper/benchmark.py
index 46fe4bd8..8101e343 100644
--- a/whisper/benchmark.py
+++ b/whisper/benchmark.py
@@ -1,14 +1,12 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import argparse
 import os
-import subprocess
 import time
 
 import mlx.core as mx
+from mlx_whisper import audio, decoding, load_models, transcribe
 
-from whisper import audio, decoding, load_models, transcribe
-
-audio_file = "whisper/assets/ls_test.flac"
+audio_file = "mlx_whisper/assets/ls_test.flac"
 
 def parse_arguments():
@@ -83,16 +81,7 @@ if __name__ == "__main__":
     print(f"\nFeature time {feat_time:.3f}")
 
     for model_name in models:
-        model_path = f"{args.mlx_dir}/{model_name}"
-        if not os.path.exists(model_path):
-            print(
-                f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
-            )
-            subprocess.run(
-                f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
-                shell=True,
-            )
-
+        model_path = f"mlx-community/whisper-{model_name}-mlx"
         print(f"\nModel: {model_name.upper()}")
         tokens = mx.array(
             [
diff --git a/whisper/convert.py b/whisper/convert.py
index fd208184..37825d6c 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import argparse
 import copy
@@ -16,11 +16,10 @@ import mlx.nn as nn
 import numpy as np
 import torch
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
+from mlx_whisper import torch_whisper
+from mlx_whisper.whisper import ModelDimensions, Whisper
 from tqdm import tqdm
 
-from whisper import torch_whisper
-from whisper.whisper import ModelDimensions, Whisper
-
 _VALID_DTYPES = {"float16", "float32"}
 
 _MODELS = {
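With the changes above, benchmark.py no longer converts checkpoints on the fly; it resolves each model as `mlx-community/whisper-{name}-mlx` and lets the loader fetch it from the Hub. A hedged sketch of what that resolution amounts to, assuming `load_models.load_model` accepts a local path or Hub repo id (the signature is inferred from the imports above, not confirmed by this diff):

```python
import mlx.core as mx
from mlx_whisper import load_models

# Downloads and caches the pre-converted checkpoint when the argument is not
# a local directory; the dtype choice mirrors the fp16 benchmark runs.
model = load_models.load_model("mlx-community/whisper-tiny-mlx", dtype=mx.float16)
```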
diff --git a/whisper/whisper/__init__.py b/whisper/mlx_whisper/__init__.py
similarity index 53%
rename from whisper/whisper/__init__.py
rename to whisper/mlx_whisper/__init__.py
index e234711c..e6de0858 100644
--- a/whisper/whisper/__init__.py
+++ b/whisper/mlx_whisper/__init__.py
@@ -1,4 +1,5 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 from . import audio, decoding, load_models
 from .transcribe import transcribe
+from .version import __version__
diff --git a/whisper/whisper/assets/download_alice.sh b/whisper/mlx_whisper/assets/download_alice.sh
similarity index 100%
rename from whisper/whisper/assets/download_alice.sh
rename to whisper/mlx_whisper/assets/download_alice.sh
diff --git a/whisper/whisper/assets/gpt2.tiktoken b/whisper/mlx_whisper/assets/gpt2.tiktoken
similarity index 100%
rename from whisper/whisper/assets/gpt2.tiktoken
rename to whisper/mlx_whisper/assets/gpt2.tiktoken
diff --git a/whisper/whisper/assets/ls_test.flac b/whisper/mlx_whisper/assets/ls_test.flac
similarity index 100%
rename from whisper/whisper/assets/ls_test.flac
rename to whisper/mlx_whisper/assets/ls_test.flac
diff --git a/whisper/whisper/assets/mel_filters.npz b/whisper/mlx_whisper/assets/mel_filters.npz
similarity index 100%
rename from whisper/whisper/assets/mel_filters.npz
rename to whisper/mlx_whisper/assets/mel_filters.npz
diff --git a/whisper/whisper/assets/multilingual.tiktoken b/whisper/mlx_whisper/assets/multilingual.tiktoken
similarity index 100%
rename from whisper/whisper/assets/multilingual.tiktoken
rename to whisper/mlx_whisper/assets/multilingual.tiktoken
diff --git a/whisper/whisper/audio.py b/whisper/mlx_whisper/audio.py
similarity index 100%
rename from whisper/whisper/audio.py
rename to whisper/mlx_whisper/audio.py
diff --git a/whisper/whisper/decoding.py b/whisper/mlx_whisper/decoding.py
similarity index 100%
rename from whisper/whisper/decoding.py
rename to whisper/mlx_whisper/decoding.py
diff --git a/whisper/whisper/load_models.py b/whisper/mlx_whisper/load_models.py
similarity index 100%
rename from whisper/whisper/load_models.py
rename to whisper/mlx_whisper/load_models.py
diff --git a/whisper/requirements.txt b/whisper/mlx_whisper/requirements.txt
similarity index 81%
rename from whisper/requirements.txt
rename to whisper/mlx_whisper/requirements.txt
index cf9c92aa..ca9c148c 100644
--- a/whisper/requirements.txt
+++ b/whisper/mlx_whisper/requirements.txt
@@ -4,6 +4,6 @@ numpy
 torch
 tqdm
 more-itertools
-tiktoken==0.3.3
+tiktoken
 huggingface_hub
 scipy
diff --git a/whisper/whisper/timing.py b/whisper/mlx_whisper/timing.py
similarity index 100%
rename from whisper/whisper/timing.py
rename to whisper/mlx_whisper/timing.py
diff --git a/whisper/whisper/tokenizer.py b/whisper/mlx_whisper/tokenizer.py
similarity index 100%
rename from whisper/whisper/tokenizer.py
rename to whisper/mlx_whisper/tokenizer.py
diff --git a/whisper/whisper/torch_whisper.py b/whisper/mlx_whisper/torch_whisper.py
similarity index 100%
rename from whisper/whisper/torch_whisper.py
rename to whisper/mlx_whisper/torch_whisper.py
diff --git a/whisper/whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py
similarity index 100%
rename from whisper/whisper/transcribe.py
rename to whisper/mlx_whisper/transcribe.py
diff --git a/whisper/mlx_whisper/version.py b/whisper/mlx_whisper/version.py
new file mode 100644
index 00000000..87ee07a7
--- /dev/null
+++ b/whisper/mlx_whisper/version.py
@@ -0,0 +1,3 @@
+# Copyright © 2023-2024 Apple Inc.
+
+__version__ = "0.1.0"
diff --git a/whisper/whisper/whisper.py b/whisper/mlx_whisper/whisper.py
similarity index 100%
rename from whisper/whisper/whisper.py
rename to whisper/mlx_whisper/whisper.py
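Note that the tokenizer tables and mel filter bank move with the package, which is what the new MANIFEST.in (earlier in this diff) and `include_package_data=True` (in setup.py below) exist to ship. A sketch of the runtime lookup this supports, assuming the package resolves assets relative to `__file__` as is typical (this diff does not show the lookup itself):

```python
from pathlib import Path

# If the assets were not packaged into the wheel, this lookup would fail
# after `pip install mlx-whisper` even though it works from a source checkout.
ASSETS_DIR = Path(__file__).parent / "assets"
mel_filters_path = ASSETS_DIR / "mel_filters.npz"
```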
diff --git a/whisper/setup.py b/whisper/setup.py
new file mode 100644
index 00000000..eaab22d4
--- /dev/null
+++ b/whisper/setup.py
@@ -0,0 +1,32 @@
+# Copyright © 2024 Apple Inc.
+
+import sys
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+package_dir = Path(__file__).parent / "mlx_whisper"
+
+with open(package_dir / "requirements.txt") as fid:
+    requirements = [l.strip() for l in fid.readlines()]
+
+sys.path.append(str(package_dir))
+
+from version import __version__
+
+setup(
+    name="mlx-whisper",
+    version=__version__,
+    description="OpenAI Whisper on Apple silicon with MLX and the Hugging Face Hub",
+    long_description=open("README.md", encoding="utf-8").read(),
+    long_description_content_type="text/markdown",
+    readme="README.md",
+    author_email="mlx@group.apple.com",
+    author="MLX Contributors",
+    url="https://github.com/ml-explore/mlx-examples",
+    license="MIT",
+    install_requires=requirements,
+    packages=find_packages(),
+    include_package_data=True,
+    python_requires=">=3.8",
+)
diff --git a/whisper/test.py b/whisper/test.py
index 3be1c27d..ce559251 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import json
 import os
@@ -7,21 +7,20 @@ from dataclasses import asdict
 from pathlib import Path
 
 import mlx.core as mx
+import mlx_whisper
+import mlx_whisper.audio as audio
+import mlx_whisper.decoding as decoding
+import mlx_whisper.load_models as load_models
 import numpy as np
 import torch
 from convert import load_torch_model, quantize, torch_to_mlx
 from mlx.utils import tree_flatten
 
-import whisper
-import whisper.audio as audio
-import whisper.decoding as decoding
-import whisper.load_models as load_models
-
 MODEL_NAME = "tiny"
 MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
 MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
 MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
-TEST_AUDIO = "whisper/assets/ls_test.flac"
+TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"
 
 
 def _save_model(save_dir, weights, config):
@@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(
+        result = mlx_whisper.transcribe(
             TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(
@@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(
+        result = mlx_whisper.transcribe(
             audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(len(result["text"]), 10920)
@@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase):
         check_segment(result["segments"][73], expected_73)
 
     def test_transcribe_word_level_timestamps_confidence_scores(self):
-        result = whisper.transcribe(
+        result = mlx_whisper.transcribe(
             TEST_AUDIO,
             path_or_hf_repo=MLX_FP16_MODEL_PATH,
             word_timestamps=True,
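For context on the word-timestamp test above, a sketch of what `transcribe` returns when `word_timestamps=True`, based on the README example earlier in this diff (the exact set of per-word keys beyond `word`, `start`, and `end` is an assumption):

```python
import mlx_whisper

result = mlx_whisper.transcribe(
    "mlx_whisper/assets/ls_test.flac",
    path_or_hf_repo="mlx-community/whisper-tiny-mlx",
    word_timestamps=True,
)
for word in result["segments"][0]["words"]:
    # Each entry pairs a decoded word with start/end times in seconds.
    print(word["word"], word["start"], word["end"])
```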