From 5adf185f861383fed84d2c0177397cf152970176 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 20 Jun 2025 17:19:46 -0700 Subject: [PATCH] Fix `update_modules()` when providing a subset (#2308) --- python/mlx/nn/layers/base.py | 2 +- python/tests/test_nn.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index af639dc4e..ce2ccb209 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -413,7 +413,7 @@ class Module(dict): f'Module does not have sub-module named "{k}".' ) elif isinstance(modules, list): - for i in range(len(dst)): + for i in range(len(modules)): current_value = dst[i] new_value = modules[i] if self.is_module(current_value) and self.is_module(new_value): diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 10bbe821e..7753224b3 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): m = m.update_modules({"list": ["hi"]}) + # Allow updating a strict subset + m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) + self.assertEqual(m.layers[1].weight.shape, (4, 3)) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):