mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Fix module update in strict mode (#2321)
* fix module update in strict mode * allow GELU to be pickled
This commit is contained in:
@@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||
|
||||
# Using leaf_modules in the update should always work
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
|
||||
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
|
||||
|
||||
m = MyModel()
|
||||
m.update_modules(m.leaf_modules())
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Reference in New Issue
Block a user