diff --git a/llms/llama/convert.py b/llms/llama/convert.py
index ecd75c6f..69fe1af8 100644
--- a/llms/llama/convert.py
+++ b/llms/llama/convert.py
@@ -46,7 +46,9 @@ def llama(model_path):
 
     for k, v in weights.items():
         weights[k] = unshard(k, v)
-    return weights, None
+    with open(model_path / "params.json", "r") as f:
+        params = json.loads(f.read())
+    return weights, params
 
 
 def tiny_llama(model_path):
@@ -123,7 +125,7 @@ if __name__ == "__main__":
         help=(
             "Name of the model to convert. Use 'llama' for models in the "
            "Llama family distributed by Meta including Llama 1, Llama 2, "
-            "Coda Llama, and Llama chat."
+            "Code Llama, and Llama chat."
         ),
         choices=["tiny_llama", "llama"],
         default="llama",
@@ -133,7 +135,7 @@ if __name__ == "__main__":
 
     model_path = Path(args.model_path)
     weights, params = globals()[args.model_name](model_path)
+    params["model_type"] = "llama"
     np.savez(str(model_path / "weights.npz"), **weights)
-    if params is not None:
-        with open(model_path / "params.json", "w") as fid:
-            json.dump(params, fid, indent=4)
+    with open(model_path / "config.json", "w") as fid:
+        json.dump(params, fid, indent=4)
diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index 3d13350b..293f7210 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -329,8 +329,9 @@ def few_shot_generate(args):
 def load_model(model_path):
     model_path = Path(model_path)
     weights = mx.load(str(model_path / "weights.npz"))
-    with open(model_path / "params.json", "r") as f:
+    with open(model_path / "config.json", "r") as f:
         config = json.loads(f.read())
+    config.pop("model_type", None)
    n_heads = config["n_heads"]
    if "n_kv_heads" not in config:
        config["n_kv_heads"] = n_heads
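
For context, the net effect of this change is that `convert.py` always writes a `config.json` (the old `params.json` contents plus a `"model_type"` key), and `load_model` in `llama.py` strips that key before using the remaining fields as model arguments. A minimal sketch of the round-trip, assuming a hypothetical model directory name:

```python
import json
from pathlib import Path

# Hypothetical model directory containing the Meta-distributed params.json.
model_path = Path("Llama-2-7b")

# convert.py side: original params plus a "model_type" tag.
params = json.loads((model_path / "params.json").read_text())
params["model_type"] = "llama"
with open(model_path / "config.json", "w") as fid:
    json.dump(params, fid, indent=4)

# llama.py side: read config.json and drop the tag, leaving only
# the fields the model constructor expects.
config = json.loads((model_path / "config.json").read_text())
config.pop("model_type", None)
```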