diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 4a548c80d..e99943834 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -114,6 +114,12 @@ class Module(dict): super(Module, self).__setattr__(key, val) self.pop(key, None) + def __delattr__(self, name): + if (val := self.get(name, None)) is not None: + del self[name] + else: + super().__delattr__(name) + def load_weights( self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 53bcb3141..ae3fae4da 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase): m = MyModel() m.update_modules(m.leaf_modules()) + def test_parameter_deletion(self): + m = nn.Linear(32, 32) + del m.weight + self.assertFalse(hasattr(m, "weight")) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):