From 570f2bf29eb1eee243f1f8de7e096f7fec783f91 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 26 Mar 2024 11:19:59 -0700 Subject: [PATCH] pick up preivously set attributes (#905) --- python/mlx/nn/layers/base.py | 5 +++++ python/tests/test_nn.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 51c1f7af9..9bd336690 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -138,6 +138,11 @@ class Module(dict): def __setattr__(self, key: str, val: Any): if isinstance(val, (mx.array, dict, list, tuple)): + # If attribute was previously set but not in the + # dictionary, delete it so we pick it up in future + # calls to __getattr__ + if hasattr(self, key) and key not in self: + delattr(self, key) self[key] = val else: super(Module, self).__setattr__(key, val) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index e8abb2227..9225a97d9 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -54,6 +54,21 @@ class TestBase(mlx_tests.MLXTestCase): m.apply_to_modules(assert_training) + def test_module_attributes(self): + + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.val = None + self.initialize() + + def initialize(self): + self.val = mx.array(1.0) + + model = Model() + self.assertTrue(mx.array_equal(model.val, mx.array(1.0))) + def test_model_with_dict(self): class DictModule(nn.Module): def __init__(self):