diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index cffa2a89..0e7f7a39 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -660,6 +660,16 @@ def convert( revision: Optional[str] = None, dequantize: bool = False, ): + # Check the save path is empty + if isinstance(mlx_path, str): + mlx_path = Path(mlx_path) + + if mlx_path.exists(): + raise ValueError( + f"Cannot save to the path {mlx_path} as it already exists." + " Please delete the file/directory or specify a new path to save to." + ) + print("[INFO] Loading") model_path = get_model_path(hf_path, revision=revision) model, config, tokenizer = fetch_from_hub(model_path, lazy=True) @@ -681,9 +691,6 @@ def convert( model = dequantize_model(model) weights = dict(tree_flatten(model.parameters())) - if isinstance(mlx_path, str): - mlx_path = Path(mlx_path) - del model save_weights(mlx_path, weights, donate_weights=True) diff --git a/llms/tests/test_utils.py b/llms/tests/test_utils.py index 576c2820..18cfa8c7 100644 --- a/llms/tests/test_utils.py +++ b/llms/tests/test_utils.py @@ -82,6 +82,7 @@ class TestUtils(unittest.TestCase): self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear)) # Check model weights have right type + mlx_path = os.path.join(self.test_dir, "mlx_model_bf16") utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16") model, _ = utils.load(mlx_path) diff --git a/whisper/convert.py b/whisper/convert.py index 85ce5fba..da7195e0 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -163,7 +163,12 @@ def load_torch_weights_and_config( name_or_path = snapshot_download( repo_id=name_or_path, - allow_patterns=["*.json", "pytorch_model.bin", "*.txt"], + allow_patterns=[ + "*.json", + "pytorch_model.bin", + "model.safetensors", + "*.txt", + ], ) else: raise RuntimeError( @@ -176,10 +181,11 @@ def load_torch_weights_and_config( weights, config = checkpoint["model_state_dict"], checkpoint["dims"] else: name_or_path = Path(name_or_path) - weights = 
torch.load( - name_or_path / "pytorch_model.bin", - map_location="cpu", - ) + pt_path = name_or_path / "pytorch_model.bin" + if pt_path.is_file(): + weights = torch.load(pt_path, map_location="cpu") + else: + weights = mx.load(str(name_or_path / "model.safetensors")) with open(name_or_path / "config.json", "r") as fp: config = json.load(fp) weights, config = hf_to_pt(weights, config) @@ -230,7 +236,9 @@ def convert(name_or_path: str, dtype: mx.Dtype = mx.float16): key = key.replace("mlp.2", "mlp2") if "conv" in key and value.ndim == 3: value = value.swapaxes(1, 2) - return key, mx.array(value.detach()).astype(dtype) + if isinstance(value, torch.Tensor): + value = mx.array(value.detach()) + return key, value.astype(dtype) weights, config, alignment_heads = load_torch_weights_and_config(name_or_path) weights.pop("encoder.positional_embedding", None) @@ -262,12 +270,16 @@ This model was converted to MLX format from [`{torch_name_or_path}`](). ## Use with mlx ```bash -git clone https://github.com/ml-explore/mlx-examples.git -cd mlx-examples/whisper/ -pip install -r requirements.txt +pip install mlx-whisper +``` ->> import whisper ->> whisper.transcribe("FILE_NAME") +```python +import mlx_whisper + +result = mlx_whisper.transcribe( + "FILE_NAME", + path_or_hf_repo="{repo_id}", +) ``` """ card = ModelCard(text)