From 33bf1a244b8cd0a58c9a9363c13b46c1714221d5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 29 Jun 2025 11:12:29 -0700 Subject: [PATCH] Fix module update in strict mode (#2321) * fix module update in strict mode * allow GELU to be pickled --- python/mlx/nn/layers/activations.py | 21 ++++++++++----------- python/mlx/nn/layers/base.py | 4 ++-- python/tests/test_nn.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py index 8eafd75d3..21994c0e6 100644 --- a/python/mlx/nn/layers/activations.py +++ b/python/mlx/nn/layers/activations.py @@ -546,7 +546,7 @@ class GELU(Module): See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the functional equivalents and information regarding error bounds. - + Args: approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. @@ -554,20 +554,19 @@ class GELU(Module): def __init__(self, approx="none"): super().__init__() - - if approx == "none": - self._act = gelu - elif approx == "precise" or approx == "tanh": - self._act = gelu_approx - elif approx == "fast": - self._act = gelu_fast_approx - else: + self._approx = approx + allowed = ["none", "precise", "tanh", "fast"] + if approx not in allowed: raise ValueError( - f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given" + f"The approximation should be in {allowed} but '{approx}' was given" ) def __call__(self, x): - return self._act(x) + if self._approx == "none": + return gelu(x) + elif self._approx in ["precise", "tanh"]: + return gelu_approx(x) + return gelu_fast_approx(x) @_make_activation_module(tanh) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index ce2ccb209..4a548c80d 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -404,7 +404,7 @@ class Module(dict): dst[k] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) - elif strict: + elif strict and new_value != {}: raise ValueError( f"Received invalid type: {type(new_value).__name__}." ) @@ -420,7 +420,7 @@ class Module(dict): dst[i] = new_value elif isinstance(current_value, (dict, list)): apply(current_value, new_value) - elif strict: + elif strict and new_value != {}: raise ValueError( f"Received invalid type: {type(new_value).__name__}." ) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 7753224b3..53bcb3141 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -264,6 +264,16 @@ class TestBase(mlx_tests.MLXTestCase): m.update_modules({"layers": [{}, nn.Linear(3, 4)]}) self.assertEqual(m.layers[1].weight.shape, (4, 3)) + # Using leaf_modules in the update should always work + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.stuff = [nn.Linear(2, 2), 0, nn.Linear(2, 2)] + self.more_stuff = {"hi": nn.Linear(2, 2), "bye": 0} + + m = MyModel() + m.update_modules(m.leaf_modules()) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):