mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Use sanitize()
This commit is contained in:
parent
197fd6aad8
commit
72269c306c
@ -719,9 +719,6 @@ def load_model(
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
|
||||
model_class, model_args_class = get_model_classes(config=config)
|
||||
|
||||
@ -1051,9 +1048,6 @@ def convert(
|
||||
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
for k in weights.keys():
|
||||
if "conv1d.weight" in k:
|
||||
weights[k] = weights[k].transpose(0, 2, 1)
|
||||
|
||||
dtype = getattr(mx, dtype)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
|
Loading…
Reference in New Issue
Block a user