allow parameters to be deleted (#2325)

This commit is contained in:
Awni Hannun 2025-07-01 21:27:23 -07:00 committed by GitHub
parent 58f3860306
commit cfb6a244ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 0 deletions

View File

@ -114,6 +114,12 @@ class Module(dict):
super(Module, self).__setattr__(key, val)
self.pop(key, None)
def __delattr__(self, name):
if (val := self.get(name, None)) is not None:
del self[name]
else:
super().__delattr__(name)
def load_weights(
self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],

View File

@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase):
m = MyModel()
m.update_modules(m.leaf_modules())
def test_parameter_deletion(self):
m = nn.Linear(32, 32)
del m.weight
self.assertFalse(hasattr(m, "weight"))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):