pick up preivously set attributes (#905)

This commit is contained in:
Awni Hannun 2024-03-26 11:19:59 -07:00 committed by GitHub
parent 9948eddf11
commit 570f2bf29e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 0 deletions

View File

@ -138,6 +138,11 @@ class Module(dict):
def __setattr__(self, key: str, val: Any):
if isinstance(val, (mx.array, dict, list, tuple)):
# If attribute was previously set but not in the
# dictionary, delete it so we pick it up in future
# calls to __getattr__
if hasattr(self, key) and key not in self:
delattr(self, key)
self[key] = val
else:
super(Module, self).__setattr__(key, val)

View File

@ -54,6 +54,21 @@ class TestBase(mlx_tests.MLXTestCase):
m.apply_to_modules(assert_training)
def test_module_attributes(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.val = None
self.initialize()
def initialize(self):
self.val = mx.array(1.0)
model = Model()
self.assertTrue(mx.array_equal(model.val, mx.array(1.0)))
def test_model_with_dict(self):
class DictModule(nn.Module):
def __init__(self):