Report number of missing parameters (#2264)

* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
This commit is contained in:
Christopher Fleetwood 2025-06-10 14:37:50 +01:00 committed by GitHub
parent 7ebb2e0193
commit 004c1d8ef2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -174,11 +174,15 @@ class Module(dict):
new_weights = dict(weights)
curr_weights = dict(tree_flatten(self.parameters()))
if extras := (new_weights.keys() - curr_weights.keys()):
extras = " ".join(extras)
raise ValueError(f"Received parameters not in model: {extras}.")
num_extra = len(extras)
extras = ",\n".join(sorted(extras))
raise ValueError(
f"Received {num_extra} parameters not in model: \n{extras}."
)
if missing := (curr_weights.keys() - new_weights.keys()):
missing = " ".join(missing)
raise ValueError(f"Missing parameters: {missing}.")
num_missing = len(missing)
missing = ",\n".join(sorted(missing))
raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
for k, v in curr_weights.items():
v_new = new_weights[k]
if not isinstance(v_new, mx.array):