fix tests and add setter attributes

This commit is contained in:
Awni Hannun 2024-03-05 09:36:45 -08:00
parent 8918a437bb
commit 1368bce280
2 changed files with 6 additions and 3 deletions

View File

@ -137,7 +137,10 @@ class Module(dict):
super(Module, self).__getattribute__(key)
def __setattr__(self, key: str, val: Any):
if isinstance(val, (mx.array, dict, list, tuple)):
# Allow setter properties to pass through to base class
prop = vars(self.__class__).get(key, None)
is_prop = isinstance(prop, property) and prop.fset is not None
if not is_prop and isinstance(val, (mx.array, dict, list, tuple)):
self[key] = val
else:
super(Module, self).__setattr__(key, val)

View File

@ -66,8 +66,8 @@ def checkpoint(module: Module):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
class _(type(module)):
def __call__(self, *arg, **kwarg):
return wrapped_checkpointed_fn(*arg, **kwarg)
def __call__(self, *args, **kwargs):
return wrapped_checkpointed_fn(*args, **kwargs)
module.__class__ = _
return module