mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Whisper: Add pip distribution configuration to support pip installations. (#739)
* Whisper: rename whisper to mlx_whisper * Whisper: add setup.py config for publish * Whisper: add assets data to setup config * Whisper: pre-commit for setup.py * Whisper: Update README.md * Whisper: Update README.md * nits * fix package data * nit in readme --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
4bf2eb17f2
commit
6775d6cb3f
@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
4
whisper/MANIFEST.in
Normal file
4
whisper/MANIFEST.in
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
include mlx_whisper/requirements.txt
|
||||||
|
include mlx_whisper/assets/mel_filters.npz
|
||||||
|
include mlx_whisper/assets/multilingual.tiktoken
|
||||||
|
include mlx_whisper/assets/gpt2.tiktoken
|
@ -6,12 +6,6 @@ parameters.[^1]
|
|||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
First, install the dependencies:
|
|
||||||
|
|
||||||
```
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
Install [`ffmpeg`](https://ffmpeg.org/):
|
Install [`ffmpeg`](https://ffmpeg.org/):
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -19,19 +13,72 @@ Install [`ffmpeg`](https://ffmpeg.org/):
|
|||||||
brew install ffmpeg
|
brew install ffmpeg
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install the `mlx-whisper` package with:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install mlx-whisper
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run
|
||||||
|
|
||||||
|
Transcribe audio with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import mlx_whisper
|
||||||
|
|
||||||
|
text = mlx_whisper.transcribe(speech_file)["text"]
|
||||||
|
```
|
||||||
|
|
||||||
|
The default model is "mlx-community/whisper-tiny". Choose the model by
|
||||||
|
setting `path_or_hf_repo`. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = mlx_whisper.transcribe(speech_file, path_or_hf_repo="models/large")
|
||||||
|
```
|
||||||
|
|
||||||
|
This will load the model contained in `models/large`. The `path_or_hf_repo` can
|
||||||
|
also point to an MLX-style Whisper model on the Hugging Face Hub. In this case,
|
||||||
|
the model will be automatically downloaded. A [collection of pre-converted
|
||||||
|
Whisper
|
||||||
|
models](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc)
|
||||||
|
are in the Hugging Face MLX Community.
|
||||||
|
|
||||||
|
The `transcribe` function also supports word-level timestamps. You can generate
|
||||||
|
these with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
output = mlx_whisper.transcribe(speech_file, word_timestamps=True)
|
||||||
|
print(output["segments"][0]["words"])
|
||||||
|
```
|
||||||
|
|
||||||
|
To see more transcription options use:
|
||||||
|
|
||||||
|
```
|
||||||
|
>>> help(mlx_whisper.transcribe)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Converting models
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Skip the conversion step by using pre-converted checkpoints from the Hugging
|
> Skip the conversion step by using pre-converted checkpoints from the Hugging
|
||||||
> Face Hub. There are a few available in the [MLX
|
> Face Hub. There are a few available in the [MLX
|
||||||
> Community](https://huggingface.co/mlx-community) organization.
|
> Community](https://huggingface.co/mlx-community) organization.
|
||||||
|
|
||||||
To convert a model, first download the Whisper PyTorch checkpoint and convert
|
To convert a model, first clone the MLX Examples repo:
|
||||||
the weights to the MLX format. For example, to convert the `tiny` model use:
|
|
||||||
|
```
|
||||||
|
git clone https://github.com/ml-explore/mlx-examples.git
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run `convert.py` from `mlx-examples/whisper`. For example, to convert the
|
||||||
|
`tiny` model use:
|
||||||
|
|
||||||
```
|
```
|
||||||
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
|
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
|
||||||
```
|
```
|
||||||
|
|
||||||
Note you can also convert a local PyTorch checkpoint which is in the original OpenAI format.
|
Note you can also convert a local PyTorch checkpoint which is in the original
|
||||||
|
OpenAI format.
|
||||||
|
|
||||||
To generate a 4-bit quantized model, use `-q`. For a full list of options:
|
To generate a 4-bit quantized model, use `-q`. For a full list of options:
|
||||||
|
|
||||||
@ -53,38 +100,4 @@ python convert.py --torch-name-or-path ${model} --dtype float32 --mlx-path mlx_m
|
|||||||
python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits
|
python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run
|
|
||||||
|
|
||||||
Transcribe audio with:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import whisper
|
|
||||||
|
|
||||||
text = whisper.transcribe(speech_file)["text"]
|
|
||||||
```
|
|
||||||
|
|
||||||
Choose the model by setting `path_or_hf_repo`. For example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
result = whisper.transcribe(speech_file, path_or_hf_repo="models/large")
|
|
||||||
```
|
|
||||||
|
|
||||||
This will load the model contained in `models/large`. The `path_or_hf_repo`
|
|
||||||
can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
|
|
||||||
case, the model will be automatically downloaded.
|
|
||||||
|
|
||||||
The `transcribe` function also supports word-level timestamps. You can generate
|
|
||||||
these with:
|
|
||||||
|
|
||||||
```python
|
|
||||||
output = whisper.transcribe(speech_file, word_timestamps=True)
|
|
||||||
print(output["segments"][0]["words"])
|
|
||||||
```
|
|
||||||
|
|
||||||
To see more transcription options use:
|
|
||||||
|
|
||||||
```
|
|
||||||
>>> help(whisper.transcribe)
|
|
||||||
```
|
|
||||||
|
|
||||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
|
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from mlx_whisper import audio, decoding, load_models, transcribe
|
||||||
|
|
||||||
from whisper import audio, decoding, load_models, transcribe
|
audio_file = "mlx_whisper/assets/ls_test.flac"
|
||||||
|
|
||||||
audio_file = "whisper/assets/ls_test.flac"
|
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
@ -83,16 +81,7 @@ if __name__ == "__main__":
|
|||||||
print(f"\nFeature time {feat_time:.3f}")
|
print(f"\nFeature time {feat_time:.3f}")
|
||||||
|
|
||||||
for model_name in models:
|
for model_name in models:
|
||||||
model_path = f"{args.mlx_dir}/{model_name}"
|
model_path = f"mlx-community/whisper-{model_name}-mlx"
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(
|
|
||||||
f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
|
|
||||||
)
|
|
||||||
subprocess.run(
|
|
||||||
f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
|
|
||||||
shell=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\nModel: {model_name.upper()}")
|
print(f"\nModel: {model_name.upper()}")
|
||||||
tokens = mx.array(
|
tokens = mx.array(
|
||||||
[
|
[
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
@ -16,11 +16,10 @@ import mlx.nn as nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
from mlx_whisper import torch_whisper
|
||||||
|
from mlx_whisper.whisper import ModelDimensions, Whisper
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from whisper import torch_whisper
|
|
||||||
from whisper.whisper import ModelDimensions, Whisper
|
|
||||||
|
|
||||||
_VALID_DTYPES = {"float16", "float32"}
|
_VALID_DTYPES = {"float16", "float32"}
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from . import audio, decoding, load_models
|
from . import audio, decoding, load_models
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
|
from .version import __version__
|
@ -4,6 +4,6 @@ numpy
|
|||||||
torch
|
torch
|
||||||
tqdm
|
tqdm
|
||||||
more-itertools
|
more-itertools
|
||||||
tiktoken==0.3.3
|
tiktoken
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
scipy
|
scipy
|
3
whisper/mlx_whisper/version.py
Normal file
3
whisper/mlx_whisper/version.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
32
whisper/setup.py
Normal file
32
whisper/setup.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
package_dir = Path(__file__).parent / "mlx_whisper"
|
||||||
|
|
||||||
|
with open(package_dir / "requirements.txt") as fid:
|
||||||
|
requirements = [l.strip() for l in fid.readlines()]
|
||||||
|
|
||||||
|
sys.path.append(str(package_dir))
|
||||||
|
|
||||||
|
from version import __version__
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="mlx-whisper",
|
||||||
|
version=__version__,
|
||||||
|
description="OpenAI Whisper on Apple silicon with MLX and the Hugging Face Hub",
|
||||||
|
long_description=open("README.md", encoding="utf-8").read(),
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
readme="README.md",
|
||||||
|
author_email="mlx@group.apple.com",
|
||||||
|
author="MLX Contributors",
|
||||||
|
url="https://github.com/ml-explore/mlx-examples",
|
||||||
|
license="MIT",
|
||||||
|
install_requires=requirements,
|
||||||
|
packages=find_packages(),
|
||||||
|
include_package_data=True,
|
||||||
|
python_requires=">=3.8",
|
||||||
|
)
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -7,21 +7,20 @@ from dataclasses import asdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx_whisper
|
||||||
|
import mlx_whisper.audio as audio
|
||||||
|
import mlx_whisper.decoding as decoding
|
||||||
|
import mlx_whisper.load_models as load_models
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from convert import load_torch_model, quantize, torch_to_mlx
|
from convert import load_torch_model, quantize, torch_to_mlx
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
import whisper
|
|
||||||
import whisper.audio as audio
|
|
||||||
import whisper.decoding as decoding
|
|
||||||
import whisper.load_models as load_models
|
|
||||||
|
|
||||||
MODEL_NAME = "tiny"
|
MODEL_NAME = "tiny"
|
||||||
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
|
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
|
||||||
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
|
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
|
||||||
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
|
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
|
||||||
TEST_AUDIO = "whisper/assets/ls_test.flac"
|
TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"
|
||||||
|
|
||||||
|
|
||||||
def _save_model(save_dir, weights, config):
|
def _save_model(save_dir, weights, config):
|
||||||
@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||||
|
|
||||||
def test_transcribe(self):
|
def test_transcribe(self):
|
||||||
result = whisper.transcribe(
|
result = mlx_whisper.transcribe(
|
||||||
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase):
|
|||||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||||
return
|
return
|
||||||
|
|
||||||
result = whisper.transcribe(
|
result = mlx_whisper.transcribe(
|
||||||
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
||||||
)
|
)
|
||||||
self.assertEqual(len(result["text"]), 10920)
|
self.assertEqual(len(result["text"]), 10920)
|
||||||
@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase):
|
|||||||
check_segment(result["segments"][73], expected_73)
|
check_segment(result["segments"][73], expected_73)
|
||||||
|
|
||||||
def test_transcribe_word_level_timestamps_confidence_scores(self):
|
def test_transcribe_word_level_timestamps_confidence_scores(self):
|
||||||
result = whisper.transcribe(
|
result = mlx_whisper.transcribe(
|
||||||
TEST_AUDIO,
|
TEST_AUDIO,
|
||||||
path_or_hf_repo=MLX_FP16_MODEL_PATH,
|
path_or_hf_repo=MLX_FP16_MODEL_PATH,
|
||||||
word_timestamps=True,
|
word_timestamps=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user