mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
Use sanitize()
This commit is contained in:
parent
197fd6aad8
commit
72269c306c
@ -719,9 +719,6 @@ def load_model(
|
|||||||
weights = {}
|
weights = {}
|
||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
for k in weights.keys():
|
|
||||||
if "conv1d.weight" in k:
|
|
||||||
weights[k] = weights[k].transpose(0, 2, 1)
|
|
||||||
|
|
||||||
model_class, model_args_class = get_model_classes(config=config)
|
model_class, model_args_class = get_model_classes(config=config)
|
||||||
|
|
||||||
@ -1051,9 +1048,6 @@ def convert(
|
|||||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||||
|
|
||||||
weights = dict(tree_flatten(model.parameters()))
|
weights = dict(tree_flatten(model.parameters()))
|
||||||
for k in weights.keys():
|
|
||||||
if "conv1d.weight" in k:
|
|
||||||
weights[k] = weights[k].transpose(0, 2, 1)
|
|
||||||
|
|
||||||
dtype = getattr(mx, dtype)
|
dtype = getattr(mx, dtype)
|
||||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||||
|
Loading…
Reference in New Issue
Block a user