diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index e493c8542..f141cfc0f 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -192,7 +192,8 @@ class Module(dict): f"shape {v_new.shape} for parameter {k}" ) - self.update(tree_unflatten(weights)) + if len(weights) != 0: + self.update(tree_unflatten(weights)) return self def save_weights(self, file: str): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9d632b488..9cfa25dae 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -167,6 +167,9 @@ class TestBase(mlx_tests.MLXTestCase): ] ) + # Empty weights is ok if strict is false + m.load_weights([], strict=False) + def test_module_state(self): m = nn.Linear(10, 1) m.state["hello"] = "world"