Use config.json in llama (#159)

* Use config.json in llama

* Fix pop

* Fix convert

* Typo
This commit is contained in:
Pedro Cuenca
2023-12-20 19:34:44 +01:00
committed by GitHub
parent 27c0a8c002
commit ce30cc3d8f
2 changed files with 9 additions and 6 deletions

View File

@@ -329,8 +329,9 @@ def few_shot_generate(args):
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "params.json", "r") as f:
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads