diff --git a/lora/README.md b/lora/README.md
index c4a80341..a3476cbd 100644
--- a/lora/README.md
+++ b/lora/README.md
@@ -110,7 +110,7 @@ For generation use:
 ```
 python lora.py --model <path_to_model> \
                --adapter-file <path_to_adapters.npz> \
-               --num-tokens 50 \
+               --max-tokens 50 \
                --prompt "table: 1-10015132-16
 columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
 Q: What is terrence ross' nationality
diff --git a/lora/lora.py b/lora/lora.py
index 6a79d239..eb672996 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -265,7 +265,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
 def generate(model, prompt, tokenizer, args):
     print(args.prompt, end="", flush=True)
-    prompt = tokenizer.encode(args.prompt)
+    prompt = mx.array(tokenizer.encode(args.prompt))
 
     tokens = []
     skip = 0
diff --git a/whisper/README.md b/whisper/README.md
index 071b3fc4..e785d9bb 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -19,8 +19,13 @@ Install [`ffmpeg`](https://ffmpeg.org/):
 brew install ffmpeg
 ```
 
-Next, download the Whisper PyTorch checkpoint and convert the weights to the
-MLX format. For example, to convert the `tiny` model use:
+> [!TIP]
+> Skip the conversion step by using pre-converted checkpoints from the Hugging
+> Face Hub. There are a few available in the [MLX
+> Community](https://huggingface.co/mlx-community) organization.
+
+To convert a model, first download the Whisper PyTorch checkpoint and convert
+the weights to the MLX format. For example, to convert the `tiny` model use:
 
 ```
 python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
@@ -34,13 +39,8 @@ To generate a 4-bit quantized model, use `-q`. For a full list of options:
 python convert.py --help
 ```
 
-By default, the conversion script will make the directory `mlx_models/tiny` and save
-the converted `weights.npz` and `config.json` there.
-
-> [!TIP]
-> Alternatively, you can also download a few converted checkpoints from the
-> [MLX Community](https://huggingface.co/mlx-community) organization on Hugging
-> Face and skip the conversion step.
+By default, the conversion script will make the directory `mlx_models/tiny`
+and save the converted `weights.npz` and `config.json` there.
 
 ### Run
 
@@ -52,6 +52,16 @@ import whisper
 
 text = whisper.transcribe(speech_file)["text"]
 ```
 
+Choose the model by setting `path_or_hf_repo`. For example:
+
+```python
+result = whisper.transcribe(speech_file, path_or_hf_repo="models/large")
+```
+
+This will load the model contained in `models/large`. The `path_or_hf_repo`
+can also point to an MLX-style Whisper model on the Hugging Face Hub. In this
+case, the model will be automatically downloaded.
+
 The `transcribe` function also supports word-level timestamps.
 You can generate these with:
diff --git a/whisper/test.py b/whisper/test.py
index 36ad6f1c..835fc179 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -189,7 +189,7 @@ class TestWhisper(unittest.TestCase):
 
     def test_transcribe(self):
         result = whisper.transcribe(
-            TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
+            TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(
             result["text"],
@@ -210,7 +210,7 @@
             return
 
         result = whisper.transcribe(
-            audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
+            audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
         )
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
@@ -313,9 +313,8 @@
 
     def test_transcribe_word_level_timestamps_confidence_scores(self):
         result = whisper.transcribe(
-            # TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
             TEST_AUDIO,
-            model_path=MLX_FP16_MODEL_PATH,
+            path_or_hf_repo=MLX_FP16_MODEL_PATH,
             word_timestamps=True,
         )
 
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index 05ff1fd1..43f07802 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -62,7 +62,7 @@ class ModelHolder:
 def transcribe(
     audio: Union[str, np.ndarray, mx.array],
     *,
-    path_or_hf_repo: str = "mlx_models",
+    path_or_hf_repo: str = "mlx-community/whisper-tiny",
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
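
With these changes, `transcribe` resolves its model through `path_or_hf_repo`. A short usage sketch, assuming a local `speech.wav` file; the keyword arguments are the ones exercised in the patch:

```python
import whisper

# With no path given, the new default "mlx-community/whisper-tiny" is used
# and downloaded automatically from the Hugging Face Hub on first use.
text = whisper.transcribe("speech.wav")["text"]

# path_or_hf_repo can also name a local converted model directory, here the
# mlx_models/tiny directory produced by convert.py in the README above.
result = whisper.transcribe(
    "speech.wav",
    path_or_hf_repo="mlx_models/tiny",
    word_timestamps=True,  # word-level timestamps, as in test.py
)
print(result["text"])
print(result["language"])
```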
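
A note on the `lora.py` fix above: `tokenizer.encode` returns a plain Python list of token ids, while generation expects an `mx.array`, so the encoded prompt is wrapped before use. A minimal sketch of the pattern (the `encode_prompt` helper is illustrative, not part of the repo):

```python
import mlx.core as mx


def encode_prompt(tokenizer, prompt: str) -> mx.array:
    # tokenizer.encode returns a Python list of token ids; wrapping it in
    # mx.array produces the array input the model's forward pass expects.
    return mx.array(tokenizer.encode(prompt))
```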