mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Use config.json in llama (#159)
* Use config.json in llama * Fix pop * Fix convert * Typo
This commit is contained in:
parent
27c0a8c002
commit
ce30cc3d8f
@ -46,7 +46,9 @@ def llama(model_path):
|
||||
|
||||
for k, v in weights.items():
|
||||
weights[k] = unshard(k, v)
|
||||
return weights, None
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
return weights, params
|
||||
|
||||
|
||||
def tiny_llama(model_path):
|
||||
@ -123,7 +125,7 @@ if __name__ == "__main__":
|
||||
help=(
|
||||
"Name of the model to convert. Use 'llama' for models in the "
|
||||
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||
"Coda Llama, and Llama chat."
|
||||
"Code Llama, and Llama chat."
|
||||
),
|
||||
choices=["tiny_llama", "llama"],
|
||||
default="llama",
|
||||
@ -133,7 +135,7 @@ if __name__ == "__main__":
|
||||
|
||||
model_path = Path(args.model_path)
|
||||
weights, params = globals()[args.model_name](model_path)
|
||||
params["model_type"] = "llama"
|
||||
np.savez(str(model_path / "weights.npz"), **weights)
|
||||
if params is not None:
|
||||
with open(model_path / "params.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
||||
with open(model_path / "config.json", "w") as fid:
|
||||
json.dump(params, fid, indent=4)
|
||||
|
@ -329,8 +329,9 @@ def few_shot_generate(args):
|
||||
def load_model(model_path):
|
||||
model_path = Path(model_path)
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
with open(model_path / "params.json", "r") as f:
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
n_heads = config["n_heads"]
|
||||
if "n_kv_heads" not in config:
|
||||
config["n_kv_heads"] = n_heads
|
||||
|
Loading…
Reference in New Issue
Block a user