mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Fix update_modules()
when providing a subset (#2308)
This commit is contained in:
parent
c9a9180584
commit
5adf185f86
@ -413,7 +413,7 @@ class Module(dict):
|
|||||||
f'Module does not have sub-module named "{k}".'
|
f'Module does not have sub-module named "{k}".'
|
||||||
)
|
)
|
||||||
elif isinstance(modules, list):
|
elif isinstance(modules, list):
|
||||||
for i in range(len(dst)):
|
for i in range(len(modules)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = modules[i]
|
new_value = modules[i]
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
if self.is_module(current_value) and self.is_module(new_value):
|
||||||
|
@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
m = m.update_modules({"list": ["hi"]})
|
m = m.update_modules({"list": ["hi"]})
|
||||||
|
|
||||||
|
# Allow updating a strict subset
|
||||||
|
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||||
|
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||||
|
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user