mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 03:31:13 +08:00
allow parameters to be deleted (#2325)
This commit is contained in:
parent
58f3860306
commit
cfb6a244ea
@ -114,6 +114,12 @@ class Module(dict):
|
||||
super(Module, self).__setattr__(key, val)
|
||||
self.pop(key, None)
|
||||
|
||||
def __delattr__(self, name):
|
||||
if (val := self.get(name, None)) is not None:
|
||||
del self[name]
|
||||
else:
|
||||
super().__delattr__(name)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
|
||||
|
@ -274,6 +274,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
m = MyModel()
|
||||
m.update_modules(m.leaf_modules())
|
||||
|
||||
def test_parameter_deletion(self):
|
||||
m = nn.Linear(32, 32)
|
||||
del m.weight
|
||||
self.assertFalse(hasattr(m, "weight"))
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
|
Loading…
Reference in New Issue
Block a user