Whisper improvements (#1080)

* use safetensors in whisper

* speed up decoder

* version
This commit is contained in:
Awni Hannun
2024-11-01 10:52:28 -07:00
committed by GitHub
parent 85ffd2c96a
commit 8160e0c4e5
6 changed files with 85 additions and 64 deletions

View File

@@ -181,7 +181,7 @@ def load_torch_weights_and_config(
)
if name_or_path.endswith(".pt"):
-checkpoint = torch.load(name_or_path, map_location="cpu")
+checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
else:
name_or_path = Path(name_or_path)
@@ -387,7 +387,7 @@ if __name__ == "__main__":
# Save weights
print("[INFO] Saving")
-np.savez(str(mlx_path / "weights.npz"), **weights)
+mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f: