mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Report number of missing parameters (#2264)
* chore: inform * chore: format --------- Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
This commit is contained in:
parent
7ebb2e0193
commit
004c1d8ef2
@ -174,11 +174,15 @@ class Module(dict):
|
|||||||
new_weights = dict(weights)
|
new_weights = dict(weights)
|
||||||
curr_weights = dict(tree_flatten(self.parameters()))
|
curr_weights = dict(tree_flatten(self.parameters()))
|
||||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||||
extras = " ".join(extras)
|
num_extra = len(extras)
|
||||||
raise ValueError(f"Received parameters not in model: {extras}.")
|
extras = ",\n".join(sorted(extras))
|
||||||
|
raise ValueError(
|
||||||
|
f"Received {num_extra} parameters not in model: \n{extras}."
|
||||||
|
)
|
||||||
if missing := (curr_weights.keys() - new_weights.keys()):
|
if missing := (curr_weights.keys() - new_weights.keys()):
|
||||||
missing = " ".join(missing)
|
num_missing = len(missing)
|
||||||
raise ValueError(f"Missing parameters: {missing}.")
|
missing = ",\n".join(sorted(missing))
|
||||||
|
raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
|
||||||
for k, v in curr_weights.items():
|
for k, v in curr_weights.items():
|
||||||
v_new = new_weights[k]
|
v_new = new_weights[k]
|
||||||
if not isinstance(v_new, mx.array):
|
if not isinstance(v_new, mx.array):
|
||||||
|
Loading…
Reference in New Issue
Block a user