Fix module update in strict mode (#2321)

* fix module update in strict mode

* allow GELU to be pickled
This commit is contained in:
Awni Hannun
2025-06-29 11:12:29 -07:00
committed by GitHub
parent 772f471ff2
commit 33bf1a244b
3 changed files with 22 additions and 13 deletions

View File

@@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase):
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
# Using leaf_modules in the update should always work
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
m = MyModel()
m.update_modules(m.leaf_modules())
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):