diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index ea12703e..ff793ee5 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -719,9 +719,6 @@ def load_model( weights = {} for wf in weight_files: weights.update(mx.load(wf)) - for k in weights.keys(): - if "conv1d.weight" in k: - weights[k] = weights[k].transpose(0, 2, 1) model_class, model_args_class = get_model_classes(config=config) @@ -1051,9 +1048,6 @@ def convert( model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters())) - for k in weights.keys(): - if "conv1d.weight" in k: - weights[k] = weights[k].transpose(0, 2, 1) dtype = getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()}