mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-10 11:16:40 +08:00
Fix whisper conversion for safetensors models (#935)
* fix whisper conversion for safetensors-only models; raise an error in mlx-lm for existing save paths * fix tests
This commit is contained in:
parent
33905447f9
commit
95840f32e2
@ -660,6 +660,16 @@ def convert(
|
|||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
dequantize: bool = False,
|
dequantize: bool = False,
|
||||||
):
|
):
|
||||||
|
# Check the save path is empty
|
||||||
|
if isinstance(mlx_path, str):
|
||||||
|
mlx_path = Path(mlx_path)
|
||||||
|
|
||||||
|
if mlx_path.exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot save to the path {mlx_path} as it already exists."
|
||||||
|
" Please delete the file/directory or specify a new path to save to."
|
||||||
|
)
|
||||||
|
|
||||||
print("[INFO] Loading")
|
print("[INFO] Loading")
|
||||||
model_path = get_model_path(hf_path, revision=revision)
|
model_path = get_model_path(hf_path, revision=revision)
|
||||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||||
@ -681,9 +691,6 @@ def convert(
|
|||||||
model = dequantize_model(model)
|
model = dequantize_model(model)
|
||||||
weights = dict(tree_flatten(model.parameters()))
|
weights = dict(tree_flatten(model.parameters()))
|
||||||
|
|
||||||
if isinstance(mlx_path, str):
|
|
||||||
mlx_path = Path(mlx_path)
|
|
||||||
|
|
||||||
del model
|
del model
|
||||||
save_weights(mlx_path, weights, donate_weights=True)
|
save_weights(mlx_path, weights, donate_weights=True)
|
||||||
|
|
||||||
|
@ -82,6 +82,7 @@ class TestUtils(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
|
self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
|
||||||
|
|
||||||
# Check model weights have right type
|
# Check model weights have right type
|
||||||
|
mlx_path = os.path.join(self.test_dir, "mlx_model_bf16")
|
||||||
utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
|
utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
|
||||||
model, _ = utils.load(mlx_path)
|
model, _ = utils.load(mlx_path)
|
||||||
|
|
||||||
|
@ -163,7 +163,12 @@ def load_torch_weights_and_config(
|
|||||||
|
|
||||||
name_or_path = snapshot_download(
|
name_or_path = snapshot_download(
|
||||||
repo_id=name_or_path,
|
repo_id=name_or_path,
|
||||||
allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
|
allow_patterns=[
|
||||||
|
"*.json",
|
||||||
|
"pytorch_model.bin",
|
||||||
|
"model.safetensors",
|
||||||
|
"*.txt",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -176,10 +181,11 @@ def load_torch_weights_and_config(
|
|||||||
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
||||||
else:
|
else:
|
||||||
name_or_path = Path(name_or_path)
|
name_or_path = Path(name_or_path)
|
||||||
weights = torch.load(
|
pt_path = name_or_path / "pytorch_model.bin"
|
||||||
name_or_path / "pytorch_model.bin",
|
if pt_path.is_file():
|
||||||
map_location="cpu",
|
weights = torch.load(pt_path, map_location="cpu")
|
||||||
)
|
else:
|
||||||
|
weights = mx.load(str(name_or_path / "model.safetensors"))
|
||||||
with open(name_or_path / "config.json", "r") as fp:
|
with open(name_or_path / "config.json", "r") as fp:
|
||||||
config = json.load(fp)
|
config = json.load(fp)
|
||||||
weights, config = hf_to_pt(weights, config)
|
weights, config = hf_to_pt(weights, config)
|
||||||
@ -230,7 +236,9 @@ def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
|
|||||||
key = key.replace("mlp.2", "mlp2")
|
key = key.replace("mlp.2", "mlp2")
|
||||||
if "conv" in key and value.ndim == 3:
|
if "conv" in key and value.ndim == 3:
|
||||||
value = value.swapaxes(1, 2)
|
value = value.swapaxes(1, 2)
|
||||||
return key, mx.array(value.detach()).astype(dtype)
|
if isinstance(value, torch.Tensor):
|
||||||
|
value = mx.array(value.detach())
|
||||||
|
return key, value.astype(dtype)
|
||||||
|
|
||||||
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
|
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
|
||||||
weights.pop("encoder.positional_embedding", None)
|
weights.pop("encoder.positional_embedding", None)
|
||||||
@ -262,12 +270,16 @@ This model was converted to MLX format from [`{torch_name_or_path}`]().
|
|||||||
|
|
||||||
## Use with mlx
|
## Use with mlx
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ml-explore/mlx-examples.git
|
pip install mlx-whisper
|
||||||
cd mlx-examples/whisper/
|
```
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
>> import whisper
|
```python
|
||||||
>> whisper.transcribe("FILE_NAME")
|
import mlx_whisper
|
||||||
|
|
||||||
|
result = mlx_whisper.transcribe(
|
||||||
|
"FILE_NAME",
|
||||||
|
path_or_hf_repo={repo_id},
|
||||||
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
card = ModelCard(text)
|
card = ModelCard(text)
|
||||||
|
Loading…
Reference in New Issue
Block a user