diff --git a/mixtral/convert.py b/mixtral/convert.py index a1a423d0..e67f4453 100644 --- a/mixtral/convert.py +++ b/mixtral/convert.py @@ -16,7 +16,7 @@ if __name__ == "__main__": ) args = parser.parse_args() model_path = Path(args.model_path) - state = torch.load(str(model_path / "consolidated.00.pt")) + state = torch.load(str(model_path / "consolidated.00.pth")) np.savez( str(model_path / "weights.npz"), **{k: v.to(torch.float16).numpy() for k, v in state.items()},