From cfb6a244ea39006febbc4f551518b53740f46b69 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 1 Jul 2025 21:27:23 -0700 Subject: [PATCH] allow parameters to be deleted (#2325) --- python/mlx/nn/layers/base.py | 6 ++++++ python/tests/test_nn.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 4a548c80d..e99943834 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -114,6 +114,12 @@ class Module(dict): super(Module, self).__setattr__(key, val) self.pop(key, None) + def __delattr__(self, name): + if (val := self.get(name, None)) is not None: + del self[name] + else: + super().__delattr__(name) + def load_weights( self, file_or_weights: Union[str, List[Tuple[str, mx.array]]], diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 53bcb3141..ae3fae4da 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase): m = MyModel() m.update_modules(m.leaf_modules()) + def test_parameter_deletion(self): + m = nn.Linear(32, 32) + del m.weight + self.assertFalse(hasattr(m, "weight")) + class TestLayers(mlx_tests.MLXTestCase): def test_identity(self):