mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-06 11:36:48 +08:00
fix tests and add setter attributes
This commit is contained in:
parent
8918a437bb
commit
1368bce280
@ -137,7 +137,10 @@ class Module(dict):
|
|||||||
super(Module, self).__getattribute__(key)
|
super(Module, self).__getattribute__(key)
|
||||||
|
|
||||||
def __setattr__(self, key: str, val: Any):
|
def __setattr__(self, key: str, val: Any):
|
||||||
if isinstance(val, (mx.array, dict, list, tuple)):
|
# Allow setter properties to pass through to base class
|
||||||
|
prop = vars(self.__class__).get(key, None)
|
||||||
|
is_prop = isinstance(prop, property) and prop.fset is not None
|
||||||
|
if not is_prop and isinstance(val, (mx.array, dict, list, tuple)):
|
||||||
self[key] = val
|
self[key] = val
|
||||||
else:
|
else:
|
||||||
super(Module, self).__setattr__(key, val)
|
super(Module, self).__setattr__(key, val)
|
||||||
|
@ -66,8 +66,8 @@ def checkpoint(module: Module):
|
|||||||
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
|
||||||
|
|
||||||
class _(type(module)):
|
class _(type(module)):
|
||||||
def __call__(self, *arg, **kwarg):
|
def __call__(self, *args, **kwargs):
|
||||||
return wrapped_checkpointed_fn(*arg, **kwarg)
|
return wrapped_checkpointed_fn(*args, **kwargs)
|
||||||
|
|
||||||
module.__class__ = _
|
module.__class__ = _
|
||||||
return module
|
return module
|
||||||
|
Loading…
Reference in New Issue
Block a user