Use sanitize()

This commit is contained in:
Shunta Saito 2025-02-13 19:45:01 +09:00
parent 197fd6aad8
commit 72269c306c

View File

@ -719,9 +719,6 @@ def load_model(
weights = {} weights = {}
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(wf)) weights.update(mx.load(wf))
for k in weights.keys():
if "conv1d.weight" in k:
weights[k] = weights[k].transpose(0, 2, 1)
model_class, model_args_class = get_model_classes(config=config) model_class, model_args_class = get_model_classes(config=config)
@ -1051,9 +1048,6 @@ def convert(
model, config, tokenizer = fetch_from_hub(model_path, lazy=True) model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters())) weights = dict(tree_flatten(model.parameters()))
for k in weights.keys():
if "conv1d.weight" in k:
weights[k] = weights[k].transpose(0, 2, 1)
dtype = getattr(mx, dtype) dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()} weights = {k: v.astype(dtype) for k, v in weights.items()}