mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 05:04:37 +08:00
Whisper improvements (#1080)
* use safetensors in whisper * speed up decoder * version
This commit is contained in:
@@ -181,7 +181,7 @@ def load_torch_weights_and_config(
|
||||
)
|
||||
|
||||
if name_or_path.endswith(".pt"):
|
||||
checkpoint = torch.load(name_or_path, map_location="cpu")
|
||||
checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
|
||||
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
||||
else:
|
||||
name_or_path = Path(name_or_path)
|
||||
@@ -387,7 +387,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Save weights
|
||||
print("[INFO] Saving")
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(str(mlx_path / "config.json"), "w") as f:
|
||||
|
Reference in New Issue
Block a user