diff --git a/whisper/convert.py b/whisper/convert.py index 7369fafa..9cc8b861 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -382,7 +382,7 @@ if __name__ == "__main__": # Save weights print("[INFO] Saving") - mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) + mx.save_safetensors(str(mlx_path / "model.safetensors"), weights) # Save config.json with model_type with open(str(mlx_path / "config.json"), "w") as f: diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py index 60766ab2..c8de5685 100644 --- a/whisper/mlx_whisper/load_models.py +++ b/whisper/mlx_whisper/load_models.py @@ -26,7 +26,10 @@ def load_model( model_args = whisper.ModelDimensions(**config) - wf = model_path / "weights.safetensors" + # Prefer model.safetensors, fall back to weights.safetensors, then weights.npz + wf = model_path / "model.safetensors" + if not wf.exists(): + wf = model_path / "weights.safetensors" if not wf.exists(): wf = model_path / "weights.npz" weights = mx.load(str(wf))