Fix module update in strict mode (#2321)

* fix module update in strict mode

* allow GELU to be pickled
This commit is contained in:
Awni Hannun 2025-06-29 11:12:29 -07:00 committed by GitHub
parent 772f471ff2
commit 33bf1a244b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 13 deletions

View File

@ -554,20 +554,19 @@ class GELU(Module):
def __init__(self, approx="none"): def __init__(self, approx="none"):
super().__init__() super().__init__()
self._approx = approx
if approx == "none": allowed = ["none", "precise", "tanh", "fast"]
self._act = gelu if approx not in allowed:
elif approx == "precise" or approx == "tanh":
self._act = gelu_approx
elif approx == "fast":
self._act = gelu_fast_approx
else:
raise ValueError( raise ValueError(
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given" f"The approximation should be in {allowed} but '{approx}' was given"
) )
def __call__(self, x): def __call__(self, x):
return self._act(x) if self._approx == "none":
return gelu(x)
elif self._approx in ["precise", "tanh"]:
return gelu_approx(x)
return gelu_fast_approx(x)
@_make_activation_module(tanh) @_make_activation_module(tanh)

View File

@ -404,7 +404,7 @@ class Module(dict):
dst[k] = new_value dst[k] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict: elif strict and new_value != {}:
raise ValueError( raise ValueError(
f"Received invalid type: {type(new_value).__name__}." f"Received invalid type: {type(new_value).__name__}."
) )
@ -420,7 +420,7 @@ class Module(dict):
dst[i] = new_value dst[i] = new_value
elif isinstance(current_value, (dict, list)): elif isinstance(current_value, (dict, list)):
apply(current_value, new_value) apply(current_value, new_value)
elif strict: elif strict and new_value != {}:
raise ValueError( raise ValueError(
f"Received invalid type: {type(new_value).__name__}." f"Received invalid type: {type(new_value).__name__}."
) )

View File

@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase):
m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3)) self.assertEqual(m.layers[1].weight.shape, (4, 3))
# Using leaf_modules in the update should always work
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
m = MyModel()
m.update_modules(m.leaf_modules())
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):