loading empty list is ok when strict = false (#1834)

This commit is contained in:
Awni Hannun 2025-02-05 16:19:27 -08:00 committed by GitHub
parent fe5987b81d
commit ca305afdbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 1 deletions

View File

@ -192,7 +192,8 @@ class Module(dict):
f"shape {v_new.shape} for parameter {k}"
)
self.update(tree_unflatten(weights))
if len(weights) != 0:
self.update(tree_unflatten(weights))
return self
def save_weights(self, file: str):

View File

@ -167,6 +167,9 @@ class TestBase(mlx_tests.MLXTestCase):
]
)
# Empty weights is ok if strict is false
m.load_weights([], strict=False)
def test_module_state(self):
m = nn.Linear(10, 1)
m.state["hello"] = "world"