Fix model

This commit is contained in:
Shunta Saito 2025-02-13 18:57:00 +09:00
parent 40c7ce8048
commit 197fd6aad8
2 changed files with 125 additions and 1952 deletions

File diff suppressed because it is too large Load Diff

View File

@ -719,8 +719,6 @@ def load_model(
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
if "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
for k in weights.keys():
if "conv1d.weight" in k:
weights[k] = weights[k].transpose(0, 2, 1)
@ -1053,8 +1051,6 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))
if "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
for k in weights.keys():
if "conv1d.weight" in k:
weights[k] = weights[k].transpose(0, 2, 1)