loading empty list is ok when strict = false (#1834)

This commit is contained in:
Awni Hannun 2025-02-05 16:19:27 -08:00 committed by GitHub
parent fe5987b81d
commit ca305afdbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 1 deletions

View File

@ -192,7 +192,8 @@ class Module(dict):
f"shape {v_new.shape} for parameter {k}" f"shape {v_new.shape} for parameter {k}"
) )
self.update(tree_unflatten(weights)) if len(weights) != 0:
self.update(tree_unflatten(weights))
return self return self
def save_weights(self, file: str): def save_weights(self, file: str):

View File

@ -167,6 +167,9 @@ class TestBase(mlx_tests.MLXTestCase):
] ]
) )
# Empty weights is ok if strict is false
m.load_weights([], strict=False)
def test_module_state(self): def test_module_state(self):
m = nn.Linear(10, 1) m = nn.Linear(10, 1)
m.state["hello"] = "world" m.state["hello"] = "world"