mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-05 19:06:44 +08:00
fix tests and add setter attributes
This commit is contained in:
parent
8918a437bb
commit
1368bce280
@ -137,7 +137,10 @@ class Module(dict):
|
||||
super(Module, self).__getattribute__(key)
|
||||
|
||||
def __setattr__(self, key: str, val: Any):
|
||||
if isinstance(val, (mx.array, dict, list, tuple)):
|
||||
# Allow setter properties to pass through to base class
|
||||
prop = vars(self.__class__).get(key, None)
|
||||
is_prop = isinstance(prop, property) and prop.fset is not None
|
||||
if not is_prop and isinstance(val, (mx.array, dict, list, tuple)):
|
||||
self[key] = val
|
||||
else:
|
||||
super(Module, self).__setattr__(key, val)
|
||||
|
@ -66,8 +66,8 @@ def checkpoint(module: Module):
|
||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
||||
|
||||
class _(type(module)):
|
||||
def __call__(self, *arg, **kwarg):
|
||||
return wrapped_checkpointed_fn(*arg, **kwarg)
|
||||
def __call__(self, *args, **kwargs):
|
||||
return wrapped_checkpointed_fn(*args, **kwargs)
|
||||
|
||||
module.__class__ = _
|
||||
return module
|
||||
|
Loading…
Reference in New Issue
Block a user