Whisper: Add pip distribution configuration to support pip installations. (#739)

* Whisper: rename whisper to mlx_whisper

* Whisper: add setup.py config for publish

* Whisper: add assets data to setup config

* Whisper: pre-commit for setup.py

* Whisper: Update README.md

* Whisper: Update README.md

* nits

* fix package data

* nit in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
madroid 2024-05-02 00:00:02 +08:00 committed by GitHub
parent 4bf2eb17f2
commit 6775d6cb3f
23 changed files with 116 additions and 74 deletions


@@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.
import sys
from pathlib import Path

whisper/MANIFEST.in (new file)

@@ -0,0 +1,4 @@
include mlx_whisper/requirements.txt
include mlx_whisper/assets/mel_filters.npz
include mlx_whisper/assets/multilingual.tiktoken
include mlx_whisper/assets/gpt2.tiktoken
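
These entries, together with `include_package_data=True` in `setup.py` further down, ship the tokenizer files and mel filter bank inside the built package. A minimal sketch of how such packaged assets can be located at runtime, assuming only the directory layout above (the actual loading code inside `mlx_whisper` may differ):

```python
# Illustrative only: resolve the packaged assets relative to the installed module.
from pathlib import Path

import numpy as np

import mlx_whisper

assets_dir = Path(mlx_whisper.__file__).parent / "assets"
mel_filters = np.load(assets_dir / "mel_filters.npz")  # shipped via MANIFEST.in
print(sorted(mel_filters.files))
```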


@@ -6,12 +6,6 @@ parameters.[^1]
### Setup
First, install the dependencies:
```
pip install -r requirements.txt
```
Install [`ffmpeg`](https://ffmpeg.org/):
```
@@ -19,19 +13,72 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```
Install the `mlx-whisper` package with:
```
pip install mlx-whisper
```
### Run
Transcribe audio with:
```python
import mlx_whisper
text = mlx_whisper.transcribe(speech_file)["text"]
```
The default model is "mlx-community/whisper-tiny". Choose the model by
setting `path_or_hf_repo`. For example:
```python
result = mlx_whisper.transcribe(speech_file, path_or_hf_repo="models/large")
```
This will load the model contained in `models/large`. The `path_or_hf_repo` can
also point to an MLX-style Whisper model on the Hugging Face Hub. In this case,
the model will be automatically downloaded. A [collection of pre-converted
Whisper models](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc)
is available in the Hugging Face MLX Community.
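For example, a model from that collection can be referenced directly by its repo id (the repo name below is only illustrative; substitute any converted checkpoint):
```python
import mlx_whisper

# Example repo id from the MLX Community collection (assumed; pick any
# converted Whisper checkpoint you prefer).
result = mlx_whisper.transcribe(
    speech_file, path_or_hf_repo="mlx-community/whisper-large-v3-mlx"
)
print(result["text"])
```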
The `transcribe` function also supports word-level timestamps. You can generate
these with:
```python
output = mlx_whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
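Continuing the example above, each word entry carries its text along with timing information; a small sketch that prints them, assuming the usual `word`, `start`, and `end` keys:
```python
for segment in output["segments"]:
    for word in segment["words"]:
        # start/end are expected to be in seconds.
        print(f"{word['start']:6.2f} -> {word['end']:6.2f} {word['word']}")
```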
To see more transcription options use:
```
>>> help(mlx_whisper.transcribe)
```
### Converting models
> [!TIP]
> Skip the conversion step by using pre-converted checkpoints from the Hugging
> Face Hub. There are a few available in the [MLX
> Community](https://huggingface.co/mlx-community) organization.
To convert a model, first download the Whisper PyTorch checkpoint and convert
the weights to the MLX format. For example, to convert the `tiny` model use:
To convert a model, first clone the MLX Examples repo:
```
git clone https://github.com/ml-explore/mlx-examples.git
```
Then run `convert.py` from `mlx-examples/whisper`. For example, to convert the
`tiny` model use:
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
```
Note you can also convert a local PyTorch checkpoint which is in the original OpenAI format.
Note you can also convert a local PyTorch checkpoint which is in the original
OpenAI format.
To generate a 4-bit quantized model, use `-q`. For a full list of options:
@@ -53,38 +100,4 @@ python convert.py --torch-name-or-path ${model} --dtype float32 --mlx-path mlx_m
python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits
```
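A converted (or quantized) model directory can then be passed to `transcribe` the same way as in the Run section above; a brief sketch, assuming the path matches whatever was given to `--mlx-path`:
```python
import mlx_whisper

# "mlx_models/tiny" and "audio.mp3" are placeholders for the --mlx-path used
# during conversion and your input file, respectively.
result = mlx_whisper.transcribe("audio.mp3", path_or_hf_repo="mlx_models/tiny")
print(result["text"])
```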
### Run
Transcribe audio with:
```python
import whisper
text = whisper.transcribe(speech_file)["text"]
```
Choose the model by setting `path_or_hf_repo`. For example:
```python
result = whisper.transcribe(speech_file, path_or_hf_repo="models/large")
```
This will load the model contained in `models/large`. The `path_or_hf_repo`
can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
case, the model will be automatically downloaded.
The `transcribe` function also supports word-level timestamps. You can generate
these with:
```python
output = whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
To see more transcription options use:
```
>>> help(whisper.transcribe)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.


@@ -1,14 +1,12 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import argparse
import os
import subprocess
import time
import mlx.core as mx
from mlx_whisper import audio, decoding, load_models, transcribe
from whisper import audio, decoding, load_models, transcribe
audio_file = "whisper/assets/ls_test.flac"
audio_file = "mlx_whisper/assets/ls_test.flac"
def parse_arguments():
@@ -83,16 +81,7 @@ if __name__ == "__main__":
print(f"\nFeature time {feat_time:.3f}")
for model_name in models:
model_path = f"{args.mlx_dir}/{model_name}"
if not os.path.exists(model_path):
print(
f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
)
subprocess.run(
f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
shell=True,
)
model_path = f"mlx-community/whisper-{model_name}-mlx"
print(f"\nModel: {model_name.upper()}")
tokens = mx.array(
[


@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import argparse
import copy
@@ -16,11 +16,10 @@ import mlx.nn as nn
import numpy as np
import torch
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx_whisper import torch_whisper
from mlx_whisper.whisper import ModelDimensions, Whisper
from tqdm import tqdm
from whisper import torch_whisper
from whisper.whisper import ModelDimensions, Whisper
_VALID_DTYPES = {"float16", "float32"}
_MODELS = {


@@ -1,4 +1,5 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from . import audio, decoding, load_models
from .transcribe import transcribe
from .version import __version__


@@ -4,6 +4,6 @@ numpy
torch
tqdm
more-itertools
tiktoken==0.3.3
tiktoken
huggingface_hub
scipy


@@ -0,0 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.1.0"

whisper/setup.py (new file)

@@ -0,0 +1,32 @@
# Copyright © 2024 Apple Inc.

import sys
from pathlib import Path

from setuptools import find_packages, setup

package_dir = Path(__file__).parent / "mlx_whisper"

with open(package_dir / "requirements.txt") as fid:
    requirements = [l.strip() for l in fid.readlines()]

sys.path.append(str(package_dir))

from version import __version__

setup(
    name="mlx-whisper",
    version=__version__,
    description="OpenAI Whisper on Apple silicon with MLX and the Hugging Face Hub",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    readme="README.md",
    author_email="mlx@group.apple.com",
    author="MLX Contributors",
    url="https://github.com/ml-explore/mlx-examples",
    license="MIT",
    install_requires=requirements,
    packages=find_packages(),
    include_package_data=True,
    python_requires=">=3.8",
)
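
The version string comes from `mlx_whisper/version.py` and is re-exported by the package's `__init__.py`, so a quick post-install sanity check might look like:

```python
# After `pip install mlx-whisper` (or an editable install from this directory),
# the package exposes the version defined in mlx_whisper/version.py.
import mlx_whisper

print(mlx_whisper.__version__)  # "0.1.0" as of this commit
```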


@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import json
import os
@@ -7,21 +7,20 @@ from dataclasses import asdict
from pathlib import Path
import mlx.core as mx
import mlx_whisper
import mlx_whisper.audio as audio
import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from mlx.utils import tree_flatten
import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"
TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"
def _save_model(save_dir, weights, config):
@@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
result = whisper.transcribe(
result = mlx_whisper.transcribe(
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(
@@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
result = whisper.transcribe(
result = mlx_whisper.transcribe(
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(len(result["text"]), 10920)
@@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
result = mlx_whisper.transcribe(
TEST_AUDIO,
path_or_hf_repo=MLX_FP16_MODEL_PATH,
word_timestamps=True,