mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
chore: inform
This commit is contained in:
parent
9ce77798b1
commit
43708f3d97
@ -174,11 +174,13 @@ class Module(dict):
|
|||||||
new_weights = dict(weights)
|
new_weights = dict(weights)
|
||||||
curr_weights = dict(tree_flatten(self.parameters()))
|
curr_weights = dict(tree_flatten(self.parameters()))
|
||||||
if extras := (new_weights.keys() - curr_weights.keys()):
|
if extras := (new_weights.keys() - curr_weights.keys()):
|
||||||
extras = " ".join(extras)
|
num_extra = len(extras)
|
||||||
raise ValueError(f"Received parameters not in model: {extras}.")
|
extras = ",\n".join(sorted(extras))
|
||||||
|
raise ValueError(f"Received {num_extra} parameters not in model: \n{extras}.")
|
||||||
if missing := (curr_weights.keys() - new_weights.keys()):
|
if missing := (curr_weights.keys() - new_weights.keys()):
|
||||||
missing = " ".join(missing)
|
num_missing = len(missing)
|
||||||
raise ValueError(f"Missing parameters: {missing}.")
|
missing = ",\n".join(sorted(missing))
|
||||||
|
raise ValueError(f"Missing {num_missing} parameters: \n{missing}.")
|
||||||
for k, v in curr_weights.items():
|
for k, v in curr_weights.items():
|
||||||
v_new = new_weights[k]
|
v_new = new_weights[k]
|
||||||
if not isinstance(v_new, mx.array):
|
if not isinstance(v_new, mx.array):
|
||||||
|
Loading…
Reference in New Issue
Block a user