From 43708f3d97d363f9fb314fef6d00a9013a84fa81 Mon Sep 17 00:00:00 2001
From: FL33TW00D
Date: Tue, 10 Jun 2025 13:22:52 +0100
Subject: [PATCH] chore: inform

---
 python/mlx/nn/layers/base.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index 783ef446d..89c511508 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -174,11 +174,13 @@ class Module(dict):
             new_weights = dict(weights)
             curr_weights = dict(tree_flatten(self.parameters()))
             if extras := (new_weights.keys() - curr_weights.keys()):
-                extras = " ".join(extras)
-                raise ValueError(f"Received parameters not in model: {extras}.")
+                num_extra = len(extras)
+                extras = ",\n".join(sorted(extras))
+                raise ValueError(f"Received {num_extra} parameters not in model: \n{extras}.")
             if missing := (curr_weights.keys() - new_weights.keys()):
-                missing = " ".join(missing)
-                raise ValueError(f"Missing parameters: {missing}.")
+                num_missing = len(missing)
+                missing = ",\n".join(sorted(missing))
+                raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
             for k, v in curr_weights.items():
                 v_new = new_weights[k]
                 if not isinstance(v_new, mx.array):
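
Editor's note, not part of the patch: a minimal sketch of how the reworded error reads once this change is applied. The module choice (nn.Linear) and the extra key "gamma" are hypothetical, picked only to trigger the "parameters not in model" branch of Module.load_weights.

    # Minimal sketch (assumes mlx is installed); "gamma" is a made-up key
    # that is not a parameter of nn.Linear, so strict loading raises.
    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(2, 2)  # parameters: "weight" (2, 2) and "bias" (2,)
    try:
        model.load_weights(
            [
                ("weight", mx.zeros((2, 2))),
                ("bias", mx.zeros((2,))),
                ("gamma", mx.zeros((2,))),  # extra key, not in the model
            ]
        )
    except ValueError as e:
        print(e)
    # With the patch applied, prints roughly:
    # Received 1 parameters not in model:
    # gamma.

With the old message the extra keys were space-joined on a single line; the patched version reports a count and lists the offending keys sorted, one per line, which is easier to scan when many keys mismatch.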