mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Fix whipser conversion for safetensors models (#935)
* fix whipser conversion for safetensor only. error in mlx lm for existing paths * fix tests
This commit is contained in:
@@ -163,7 +163,12 @@ def load_torch_weights_and_config(
|
||||
|
||||
name_or_path = snapshot_download(
|
||||
repo_id=name_or_path,
|
||||
allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"pytorch_model.bin",
|
||||
"model.safetensors",
|
||||
"*.txt",
|
||||
],
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@@ -176,10 +181,11 @@ def load_torch_weights_and_config(
|
||||
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
||||
else:
|
||||
name_or_path = Path(name_or_path)
|
||||
weights = torch.load(
|
||||
name_or_path / "pytorch_model.bin",
|
||||
map_location="cpu",
|
||||
)
|
||||
pt_path = name_or_path / "pytorch_model.bin"
|
||||
if pt_path.is_file():
|
||||
weights = torch.load(pt_path, map_location="cpu")
|
||||
else:
|
||||
weights = mx.load(str(name_or_path / "model.safetensors"))
|
||||
with open(name_or_path / "config.json", "r") as fp:
|
||||
config = json.load(fp)
|
||||
weights, config = hf_to_pt(weights, config)
|
||||
@@ -230,7 +236,9 @@ def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
|
||||
key = key.replace("mlp.2", "mlp2")
|
||||
if "conv" in key and value.ndim == 3:
|
||||
value = value.swapaxes(1, 2)
|
||||
return key, mx.array(value.detach()).astype(dtype)
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = mx.array(value.detach())
|
||||
return key, value.astype(dtype)
|
||||
|
||||
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
|
||||
weights.pop("encoder.positional_embedding", None)
|
||||
@@ -262,12 +270,16 @@ This model was converted to MLX format from [`{torch_name_or_path}`]().
|
||||
|
||||
## Use with mlx
|
||||
```bash
|
||||
git clone https://github.com/ml-explore/mlx-examples.git
|
||||
cd mlx-examples/whisper/
|
||||
pip install -r requirements.txt
|
||||
pip install mlx-whisper
|
||||
```
|
||||
|
||||
>> import whisper
|
||||
>> whisper.transcribe("FILE_NAME")
|
||||
```python
|
||||
import mlx_whisper
|
||||
|
||||
result = mlx_whisper.transcribe(
|
||||
"FILE_NAME",
|
||||
path_or_hf_repo={repo_id},
|
||||
)
|
||||
```
|
||||
"""
|
||||
card = ModelCard(text)
|
||||
|
Reference in New Issue
Block a user