diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 783ef446d..af639dc4e 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -174,11 +174,15 @@ class Module(dict): new_weights = dict(weights) curr_weights = dict(tree_flatten(self.parameters())) if extras := (new_weights.keys() - curr_weights.keys()): - extras = " ".join(extras) - raise ValueError(f"Received parameters not in model: {extras}.") + num_extra = len(extras) + extras = ",\n".join(sorted(extras)) + raise ValueError( + f"Received {num_extra} parameters not in model: \n{extras}." + ) if missing := (curr_weights.keys() - new_weights.keys()): - missing = " ".join(missing) - raise ValueError(f"Missing parameters: {missing}.") + num_missing = len(missing) + missing = ",\n".join(sorted(missing)) + raise ValueError(f"Missing {num_missing} parameters: \n{missing}.") for k, v in curr_weights.items(): v_new = new_weights[k] if not isinstance(v_new, mx.array):