diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index af639dc4e..ce2ccb209 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -413,7 +413,7 @@ class Module(dict): f'Module does not have sub-module named "{k}".' ) elif isinstance(modules, list): - for i in range(len(dst)): + for i in range(len(modules)): current_value = dst[i] new_value = modules[i] if self.is_module(current_value) and self.is_module(new_value): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 10bbe821e..7753224b3 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): m = m.update_modules({"list": ["hi"]}) + # Allow updating a strict subset + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) + self.assertEqual(m.layers[1].weight.shape, (4, 3)) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):