mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Report number of missing parameters (#2264)
* chore: inform * chore: format --------- Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
This commit is contained in:
parent
7ebb2e0193
commit
004c1d8ef2
@ -174,11 +174,15 @@ class Module(dict):
|
||||
new_weights = dict(weights)
|
||||
curr_weights = dict(tree_flatten(self.parameters()))
|
||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||
extras = " ".join(extras)
|
||||
raise ValueError(f"Received parameters not in model: {extras}.")
|
||||
num_extra = len(extras)
|
||||
extras = ",\n".join(sorted(extras))
|
||||
raise ValueError(
|
||||
f"Received {num_extra} parameters not in model: \n{extras}."
|
||||
)
|
||||
if missing := (curr_weights.keys() - new_weights.keys()):
|
||||
missing = " ".join(missing)
|
||||
raise ValueError(f"Missing parameters: {missing}.")
|
||||
num_missing = len(missing)
|
||||
missing = ",\n".join(sorted(missing))
|
||||
raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
|
||||
for k, v in curr_weights.items():
|
||||
v_new = new_weights[k]
|
||||
if not isinstance(v_new, mx.array):
|
||||
|
Loading…
Reference in New Issue
Block a user