mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 16:21:14 +08:00
Fix module update in strict mode (#2321)
* fix module update in strict mode * allow GELU to be pickled
This commit is contained in:
parent
772f471ff2
commit
33bf1a244b
@ -554,20 +554,19 @@ class GELU(Module):
|
||||
|
||||
def __init__(self, approx="none"):
|
||||
super().__init__()
|
||||
|
||||
if approx == "none":
|
||||
self._act = gelu
|
||||
elif approx == "precise" or approx == "tanh":
|
||||
self._act = gelu_approx
|
||||
elif approx == "fast":
|
||||
self._act = gelu_fast_approx
|
||||
else:
|
||||
self._approx = approx
|
||||
allowed = ["none", "precise", "tanh", "fast"]
|
||||
if approx not in allowed:
|
||||
raise ValueError(
|
||||
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
|
||||
f"The approximation should be in {allowed} but '{approx}' was given"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self._act(x)
|
||||
if self._approx == "none":
|
||||
return gelu(x)
|
||||
elif self._approx in ["precise", "tanh"]:
|
||||
return gelu_approx(x)
|
||||
return gelu_fast_approx(x)
|
||||
|
||||
|
||||
@_make_activation_module(tanh)
|
||||
|
@ -404,7 +404,7 @@ class Module(dict):
|
||||
dst[k] = new_value
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
elif strict and new_value != {}:
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
@ -420,7 +420,7 @@ class Module(dict):
|
||||
dst[i] = new_value
|
||||
elif isinstance(current_value, (dict, list)):
|
||||
apply(current_value, new_value)
|
||||
elif strict:
|
||||
elif strict and new_value != {}:
|
||||
raise ValueError(
|
||||
f"Received invalid type: {type(new_value).__name__}."
|
||||
)
|
||||
|
@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||
|
||||
# Using leaf_modules in the update should always work
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)]
|
||||
self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0}
|
||||
|
||||
m = MyModel()
|
||||
m.update_modules(m.leaf_modules())
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Loading…
Reference in New Issue
Block a user