Use config.json in llama (#159)

* Use config.json in llama

* Fix pop

* Fix convert

* Fix typo in help text ("Coda Llama" -> "Code Llama")
This commit is contained in:
Pedro Cuenca 2023-12-20 19:34:44 +01:00 committed by GitHub
parent 27c0a8c002
commit ce30cc3d8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 6 deletions

View File

@ -46,7 +46,9 @@ def llama(model_path):
for k, v in weights.items():
weights[k] = unshard(k, v)
return weights, None
with open(model_path / "params.json", "r") as f:
params = json.loads(f.read())
return weights, params
def tiny_llama(model_path):
@ -123,7 +125,7 @@ if __name__ == "__main__":
help=(
"Name of the model to convert. Use 'llama' for models in the "
"Llama family distributed by Meta including Llama 1, Llama 2, "
"Coda Llama, and Llama chat."
"Code Llama, and Llama chat."
),
choices=["tiny_llama", "llama"],
default="llama",
@ -133,7 +135,7 @@ if __name__ == "__main__":
model_path = Path(args.model_path)
weights, params = globals()[args.model_name](model_path)
params["model_type"] = "llama"
np.savez(str(model_path / "weights.npz"), **weights)
if params is not None:
with open(model_path / "params.json", "w") as fid:
with open(model_path / "config.json", "w") as fid:
json.dump(params, fid, indent=4)

View File

@ -329,8 +329,9 @@ def few_shot_generate(args):
def load_model(model_path):
model_path = Path(model_path)
weights = mx.load(str(model_path / "weights.npz"))
with open(model_path / "params.json", "r") as f:
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
n_heads = config["n_heads"]
if "n_kv_heads" not in config:
config["n_kv_heads"] = n_heads