[Lora] Fix generate (#282)

* fix generate
* update readme, fix test, better default
* nits
* typo

parent a2bc8426f2
commit 80d18671ad
@@ -110,7 +110,7 @@ For generation use:
 ```
 python lora.py --model <path_to_model> \
                --adapter-file <path_to_adapters.npz> \
-               --num-tokens 50 \
+               --max-tokens 50 \
                --prompt "table: 1-10015132-16
 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
 Q: What is terrence ross' nationality

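This hunk, apparently from the LoRA README, brings the documented flag in line with the script: `lora.py` exposes `--max-tokens`, not `--num-tokens`. As a hedged sketch (the default and help text are assumptions, not taken from this diff), the matching argparse definition would look like:

```python
import argparse

parser = argparse.ArgumentParser(description="LoRA finetuning and generation")
# Sketch of the flag the README now references; default and help text assumed.
parser.add_argument(
    "--max-tokens",
    type=int,
    default=50,
    help="Maximum number of tokens to generate",
)
args = parser.parse_args(["--max-tokens", "50"])
print(args.max_tokens)  # 50
```
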
@@ -265,7 +265,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 def generate(model, prompt, tokenizer, args):
     print(args.prompt, end="", flush=True)

-    prompt = tokenizer.encode(args.prompt)
+    prompt = mx.array(tokenizer.encode(args.prompt))

     tokens = []
     skip = 0

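The one-line "fix generate" change above addresses a type mismatch: `tokenizer.encode` returns a plain Python list of token ids, while the MLX model expects an `mx.array`. A minimal sketch of the difference (the token ids are made up for illustration):

```python
import mlx.core as mx

ids = [1, 15043, 2787]  # stand-in for tokenizer.encode(args.prompt)

# Before the fix: `ids` is a list, which the model cannot consume directly.
# After the fix: wrap it once, up front.
prompt = mx.array(ids)
print(prompt.shape, prompt.dtype)  # (3,) mlx.core.int32
```
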
@@ -19,8 +19,13 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 brew install ffmpeg
 ```

-Next, download the Whisper PyTorch checkpoint and convert the weights to the
-MLX format. For example, to convert the `tiny` model use:
+> [!TIP]
+> Skip the conversion step by using pre-converted checkpoints from the Hugging
+> Face Hub. There are a few available in the [MLX
+> Community](https://huggingface.co/mlx-community) organization.
+
+To convert a model, first download the Whisper PyTorch checkpoint and convert
+the weights to the MLX format. For example, to convert the `tiny` model use:

 ```
 python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny

@@ -34,13 +39,8 @@ To generate a 4-bit quantized model, use `-q`. For a full list of options:
 python convert.py --help
 ```

-By default, the conversion script will make the directory `mlx_models/tiny` and save
-the converted `weights.npz` and `config.json` there.
-
-> [!TIP]
-> Alternatively, you can also download a few converted checkpoints from the
-> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
-> Face and skip the conversion step.
+By default, the conversion script will make the directory `mlx_models/tiny`
+and save the converted `weights.npz` and `config.json` there.

 ### Run

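These two Whisper README hunks move the "skip conversion" tip above the conversion instructions and rewrap the paragraph about the output directory. For orientation, the conversion output described there is just an `.npz` of arrays plus a JSON config. A hedged sketch of inspecting such a checkpoint (the path and the config contents are assumptions based on the README text, not on this diff):

```python
import json
import mlx.core as mx

# Load the converted checkpoint the README describes (path assumed).
weights = mx.load("mlx_models/tiny/weights.npz")  # dict: name -> mx.array
with open("mlx_models/tiny/config.json") as f:
    config = json.load(f)

print(f"{len(weights)} weight arrays")
print("config keys:", sorted(config))
```
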
@@ -52,6 +52,16 @@ import whisper
 text = whisper.transcribe(speech_file)["text"]
 ```

+Choose the model by setting `path_or_hf_repo`. For example:
+
+```python
+result = whisper.transcribe(speech_file, path_or_hf_repo="models/large")
+```
+
+This will load the model contained in `models/large`. The `path_or_hf_repo`
+can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
+case, the model will be automatically downloaded.
+
 The `transcribe` function also supports word-level timestamps. You can generate
 these with:

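(The keyword in the README addition is normalized to `path_or_hf_repo` here; the scraped page showed `hf_path_or_repo`, which does not match the `transcribe` signature changed later in this same commit.) The parameter accepts either a local directory or a Hugging Face repo id. A hedged usage sketch, with the audio file name assumed and the repo id taken from the new default below:

```python
import whisper  # the whisper example package from this repo

# A local directory containing weights.npz and config.json:
result = whisper.transcribe("speech.wav", path_or_hf_repo="mlx_models/tiny")

# Or a repo on the Hugging Face Hub; the weights download automatically:
result = whisper.transcribe("speech.wav", path_or_hf_repo="mlx-community/whisper-tiny")
print(result["text"])
```
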
@@ -189,7 +189,7 @@ class TestWhisper(unittest.TestCase):

     def test_transcribe(self):
         result = whisper.transcribe(
-            TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
+            TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(
             result["text"],

@@ -210,7 +210,7 @@ class TestWhisper(unittest.TestCase):
             return

         result = whisper.transcribe(
-            audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
+            audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")

@@ -313,9 +313,8 @@ class TestWhisper(unittest.TestCase):

     def test_transcribe_word_level_timestamps_confidence_scores(self):
         result = whisper.transcribe(
-            # TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
             TEST_AUDIO,
-            model_path=MLX_FP16_MODEL_PATH,
+            path_or_hf_repo=MLX_FP16_MODEL_PATH,
             word_timestamps=True,
         )

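The three test hunks are mechanical: every call site's `model_path` keyword becomes `path_or_hf_repo`, matching the signature change in the final hunk, and a stale commented-out call is dropped. Because the parameter is keyword-only, any leftover call site fails loudly rather than silently, as this illustrative sketch (toy stand-in, not the real function) shows:

```python
# Illustrative sketch only: a keyword-only parameter rejects the old name outright.
def transcribe(audio, *, path_or_hf_repo="mlx-community/whisper-tiny"):
    return {"text": "..."}

try:
    transcribe("speech.wav", model_path="mlx_models")
except TypeError as e:
    print(e)  # transcribe() got an unexpected keyword argument 'model_path'
```
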
@@ -62,7 +62,7 @@ class ModelHolder:
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    path_or_hf_repo: str = "mlx_models",
+    path_or_hf_repo: str = "mlx-community/whisper-tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,

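This is the "better default" from the commit message: instead of assuming a local `mlx_models` directory exists, `transcribe` now falls back to a small pre-converted checkpoint on the Hub, so the README's no-argument call works out of the box. A brief hedged sketch (audio file name assumed):

```python
import whisper

# With the new default, no model path is needed; mlx-community/whisper-tiny
# is fetched from the Hugging Face Hub on first use.
text = whisper.transcribe("speech.wav")["text"]
print(text)
```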