mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Use model.safetensors with Whisper (#1399)
Some checks failed
Test / check_lint (push) Has been cancelled
Some checks failed
Test / check_lint (push) Has been cancelled
This commit is contained in:
@@ -382,7 +382,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Save weights
|
# Save weights
|
||||||
print("[INFO] Saving")
|
print("[INFO] Saving")
|
||||||
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
|
||||||
|
|
||||||
# Save config.json with model_type
|
# Save config.json with model_type
|
||||||
with open(str(mlx_path / "config.json"), "w") as f:
|
with open(str(mlx_path / "config.json"), "w") as f:
|
||||||
|
|||||||
@@ -26,7 +26,10 @@ def load_model(
|
|||||||
|
|
||||||
model_args = whisper.ModelDimensions(**config)
|
model_args = whisper.ModelDimensions(**config)
|
||||||
|
|
||||||
wf = model_path / "weights.safetensors"
|
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
|
||||||
|
wf = model_path / "model.safetensors"
|
||||||
|
if not wf.exists():
|
||||||
|
wf = model_path / "weights.safetensors"
|
||||||
if not wf.exists():
|
if not wf.exists():
|
||||||
wf = model_path / "weights.npz"
|
wf = model_path / "weights.npz"
|
||||||
weights = mx.load(str(wf))
|
weights = mx.load(str(wf))
|
||||||
|
|||||||
Reference in New Issue
Block a user