Use model.safetensors with Whisper (#1399)
Some checks failed
Test / check_lint (push) Has been cancelled

This commit is contained in:
Anthony
2025-12-15 15:17:08 +01:00
committed by GitHub
parent 7ddca42f4d
commit e52c128d11
2 changed files with 5 additions and 2 deletions

View File

@@ -382,7 +382,7 @@ if __name__ == "__main__":
# Save weights # Save weights
print("[INFO] Saving") print("[INFO] Saving")
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
# Save config.json with model_type # Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f: with open(str(mlx_path / "config.json"), "w") as f:

View File

@@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config) model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors" # Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
wf = model_path / "model.safetensors"
if not wf.exists():
wf = model_path / "weights.safetensors"
if not wf.exists(): if not wf.exists():
wf = model_path / "weights.npz" wf = model_path / "weights.npz"
weights = mx.load(str(wf)) weights = mx.load(str(wf))