From 8b30acd7eba149104f944f3b1afff0e732a48225 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 11 Sep 2024 16:30:42 -0700 Subject: [PATCH] fix module attribute set, reset, set (#1403) --- python/mlx/nn/layers/base.py | 1 + python/tests/test_nn.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 80764cf68..e493c8542 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -112,6 +112,7 @@ class Module(dict): self[key] = val else: super(Module, self).__setattr__(key, val) + self.pop(key, None) def load_weights( self, diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 38659625f..14959eddd 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -67,6 +67,12 @@ class TestBase(mlx_tests.MLXTestCase): model = Model() self.assertTrue(mx.array_equal(model.val, mx.array(1.0))) + model.val = None + self.assertEqual(model.val, None) + + model.val = mx.array([3]) + self.assertEqual(model.val.item(), 3) + def test_model_with_dict(self): class DictModule(nn.Module): def __init__(self):