From ca305afdbe6ea5aab3cc6374fee655bd0ab68294 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Feb 2025 16:19:27 -0800 Subject: [PATCH] loading empty list is ok when strict = false (#1834) --- python/mlx/nn/layers/base.py | 3 ++- python/tests/test_nn.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index e493c8542..f141cfc0f 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -192,7 +192,8 @@ class Module(dict): f"shape {v_new.shape} for parameter {k}" ) - self.update(tree_unflatten(weights)) + if len(weights) != 0: + self.update(tree_unflatten(weights)) return self def save_weights(self, file: str): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9d632b488..9cfa25dae 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -167,6 +167,9 @@ class TestBase(mlx_tests.MLXTestCase): ] ) + # Empty weights is ok if strict is false + m.load_weights([], strict=False) + def test_module_state(self): m = nn.Linear(10, 1) m.state["hello"] = "world"