mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 05:04:37 +08:00
Fix model
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -719,8 +719,6 @@ def load_model(
|
|||||||
weights = {}
|
weights = {}
|
||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
if "lm_head.weight" not in weights:
|
|
||||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
|
||||||
for k in weights.keys():
|
for k in weights.keys():
|
||||||
if "conv1d.weight" in k:
|
if "conv1d.weight" in k:
|
||||||
weights[k] = weights[k].transpose(0, 2, 1)
|
weights[k] = weights[k].transpose(0, 2, 1)
|
||||||
@@ -1053,8 +1051,6 @@ def convert(
|
|||||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||||
|
|
||||||
weights = dict(tree_flatten(model.parameters()))
|
weights = dict(tree_flatten(model.parameters()))
|
||||||
if "lm_head.weight" not in weights:
|
|
||||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
|
||||||
for k in weights.keys():
|
for k in weights.keys():
|
||||||
if "conv1d.weight" in k:
|
if "conv1d.weight" in k:
|
||||||
weights[k] = weights[k].transpose(0, 2, 1)
|
weights[k] = weights[k].transpose(0, 2, 1)
|
||||||
|
Reference in New Issue
Block a user