Fix update_modules() when providing a subset (#2308)

This commit is contained in:
Angelos Katharopoulos 2025-06-20 17:19:46 -07:00 committed by GitHub
parent c9a9180584
commit 5adf185f86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 1 deletions

View File

@ -413,7 +413,7 @@ class Module(dict):
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):

View File

@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):