diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 80764cf68..e493c8542 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -112,6 +112,7 @@ class Module(dict): self[key] = val else: super(Module, self).__setattr__(key, val) + self.pop(key, None) def load_weights( self, diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 38659625f..14959eddd 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -67,6 +67,12 @@ class TestBase(mlx_tests.MLXTestCase): model = Model() self.assertTrue(mx.array_equal(model.val, mx.array(1.0))) + model.val = None + self.assertEqual(model.val, None) + + model.val = mx.array([3]) + self.assertEqual(model.val.item(), 3) + def test_model_with_dict(self): class DictModule(nn.Module): def __init__(self):