Use config.json, add model_type (#157)

* Use config.json, add model_type

* Update convert to generate config.json
This commit is contained in:
Pedro Cuenca
2023-12-20 17:39:37 +01:00
committed by GitHub
parent 06b7c45c59
commit 920bd13668
2 changed files with 11 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import argparse
import json
import numpy as np
from pathlib import Path
import torch
@@ -22,3 +23,10 @@ if __name__ == "__main__":
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()}
)
# Save config.json with model_type
with open(model_path / "params.json", "r") as f:
config = json.loads(f.read())
config["model_type"] = "mistral"
with open(model_path / "config.json", "w") as f:
json.dump(config, f, indent=4)