Update convert.py

Docs are right, however, the code has a typo.
This commit is contained in:
Merrick Christensen
2023-12-12 14:33:33 -07:00
committed by GitHub
parent 9a02dce35c
commit 2206e8f7d9

View File

@@ -16,7 +16,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pt"))
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},