mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
Use config.json in llama (#159)
* Use config.json in llama * Fix pop * Fix convert * Typo
This commit is contained in:
parent
27c0a8c002
commit
ce30cc3d8f
@ -46,7 +46,9 @@ def llama(model_path):
|
|||||||
|
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
weights[k] = unshard(k, v)
|
weights[k] = unshard(k, v)
|
||||||
return weights, None
|
with open(model_path / "params.json", "r") as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
return weights, params
|
||||||
|
|
||||||
|
|
||||||
def tiny_llama(model_path):
|
def tiny_llama(model_path):
|
||||||
@ -123,7 +125,7 @@ if __name__ == "__main__":
|
|||||||
help=(
|
help=(
|
||||||
"Name of the model to convert. Use 'llama' for models in the "
|
"Name of the model to convert. Use 'llama' for models in the "
|
||||||
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||||
"Coda Llama, and Llama chat."
|
"Code Llama, and Llama chat."
|
||||||
),
|
),
|
||||||
choices=["tiny_llama", "llama"],
|
choices=["tiny_llama", "llama"],
|
||||||
default="llama",
|
default="llama",
|
||||||
@ -133,7 +135,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model_path = Path(args.model_path)
|
model_path = Path(args.model_path)
|
||||||
weights, params = globals()[args.model_name](model_path)
|
weights, params = globals()[args.model_name](model_path)
|
||||||
|
params["model_type"] = "llama"
|
||||||
np.savez(str(model_path / "weights.npz"), **weights)
|
np.savez(str(model_path / "weights.npz"), **weights)
|
||||||
if params is not None:
|
with open(model_path / "config.json", "w") as fid:
|
||||||
with open(model_path / "params.json", "w") as fid:
|
json.dump(params, fid, indent=4)
|
||||||
json.dump(params, fid, indent=4)
|
|
||||||
|
@ -329,8 +329,9 @@ def few_shot_generate(args):
|
|||||||
def load_model(model_path):
|
def load_model(model_path):
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
with open(model_path / "params.json", "r") as f:
|
with open(model_path / "config.json", "r") as f:
|
||||||
config = json.loads(f.read())
|
config = json.loads(f.read())
|
||||||
|
config.pop("model_type", None)
|
||||||
n_heads = config["n_heads"]
|
n_heads = config["n_heads"]
|
||||||
if "n_kv_heads" not in config:
|
if "n_kv_heads" not in config:
|
||||||
config["n_kv_heads"] = n_heads
|
config["n_kv_heads"] = n_heads
|
||||||
|
Loading…
Reference in New Issue
Block a user