mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 07:30:06 +08:00
Fix model
This commit is contained in:
parent
40c7ce8048
commit
197fd6aad8
File diff suppressed because it is too large
Load Diff
@ -719,8 +719,6 @@ def load_model(
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
@ -1053,8 +1051,6 @@ def convert(
|
||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user