[Lora] Fix generate (#282)

* fix generate

* update readme, fix test, better default

* nits

* typo
This commit is contained in:
Awni Hannun
2024-01-10 16:13:06 -08:00
committed by GitHub
parent a2bc8426f2
commit 80d18671ad
5 changed files with 25 additions and 16 deletions

View File

@@ -19,8 +19,13 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```
Next, download the Whisper PyTorch checkpoint and convert the weights to the
MLX format. For example, to convert the `tiny` model use:
> [!TIP]
> Skip the conversion step by using pre-converted checkpoints from the Hugging
> Face Hub. There are a few available in the [MLX
> Community](https://huggingface.co/mlx-community) organization.
To convert a model, first download the Whisper PyTorch checkpoint and convert
the weights to the MLX format. For example, to convert the `tiny` model use:
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
@@ -34,13 +39,8 @@ To generate a 4-bit quantized model, use `-q`. For a full list of options:
python convert.py --help
```
By default, the conversion script will make the directory `mlx_models/tiny` and save
the converted `weights.npz` and `config.json` there.
> [!TIP]
> Alternatively, you can also download a few converted checkpoints from the
> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
> Face and skip the conversion step.
By default, the conversion script will make the directory `mlx_models/tiny`
and save the converted `weights.npz` and `config.json` there.
### Run
@@ -52,6 +52,16 @@ import whisper
text = whisper.transcribe(speech_file)["text"]
```
Choose the model by setting `hf_path_or_repo`. For example:
```python
result = whisper.transcribe(speech_file, hf_path_or_repo="models/large")
```
This will load the model contained in `models/large`. The `hf_path_or_repo`
can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
case, the model will be automatically downloaded.
The `transcribe` function also supports word-level timestamps. You can generate
these with:

View File

@@ -189,7 +189,7 @@ class TestWhisper(unittest.TestCase):
def test_transcribe(self):
result = whisper.transcribe(
TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(
result["text"],
@@ -210,7 +210,7 @@ class TestWhisper(unittest.TestCase):
return
result = whisper.transcribe(
audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
@@ -313,9 +313,8 @@ class TestWhisper(unittest.TestCase):
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
path_or_hf_repo=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)

View File

@@ -62,7 +62,7 @@ class ModelHolder:
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
path_or_hf_repo: str = "mlx_models",
path_or_hf_repo: str = "mlx-community/whisper-tiny",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,