From 004c1d8ef2fcdc04f29acbc497c9fea7c190591a Mon Sep 17 00:00:00 2001
From: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Date: Tue, 10 Jun 2025 14:37:50 +0100
Subject: [PATCH] Report number of missing parameters (#2264)

* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D
---
 python/mlx/nn/layers/base.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index 783ef446d..af639dc4e 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -174,11 +174,15 @@ class Module(dict):
         new_weights = dict(weights)
         curr_weights = dict(tree_flatten(self.parameters()))
         if extras := (new_weights.keys() - curr_weights.keys()):
-            extras = " ".join(extras)
-            raise ValueError(f"Received parameters not in model: {extras}.")
+            num_extra = len(extras)
+            extras = ",\n".join(sorted(extras))
+            raise ValueError(
+                f"Received {num_extra} parameters not in model: \n{extras}."
+            )
         if missing := (curr_weights.keys() - new_weights.keys()):
-            missing = " ".join(missing)
-            raise ValueError(f"Missing parameters: {missing}.")
+            num_missing = len(missing)
+            missing = ",\n".join(sorted(missing))
+            raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
         for k, v in curr_weights.items():
             v_new = new_weights[k]
             if not isinstance(v_new, mx.array):
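
Usage sketch (not part of the patch): a minimal illustration of how the new
error message surfaces from Module.load_weights, assuming an installed mlx
with this change applied; the Linear layer and the deliberately omitted
"bias" key below are chosen only for the example.

import mlx.core as mx
import mlx.nn as nn

# nn.Linear(4, 2) has two parameters: "weight" and "bias".
model = nn.Linear(4, 2)

try:
    # Pass only "weight" so the strict check reports the missing "bias".
    model.load_weights([("weight", mx.zeros((2, 4)))])
except ValueError as e:
    print(e)
    # With this patch the message includes a count, roughly:
    #   Missing 1 parameters:
    #   bias.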