diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 447002594d..143624fd48 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -137,7 +137,10 @@ class Module(dict): super(Module, self).__getattribute__(key) def __setattr__(self, key: str, val: Any): - if isinstance(val, (mx.array, dict, list, tuple)): + # Allow setter properties to pass through to base class + prop = vars(self.__class__).get(key, None) + is_prop = isinstance(prop, property) and prop.fset is not None + if not is_prop and isinstance(val, (mx.array, dict, list, tuple)): self[key] = val else: super(Module, self).__setattr__(key, val) diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index 7035450e4a..07ab217055 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -66,8 +66,8 @@ def checkpoint(module: Module): return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) class _(type(module)): - def __call__(self, *arg, **kwarg): - return wrapped_checkpointed_fn(*arg, **kwarg) + def __call__(self, *args, **kwargs): + return wrapped_checkpointed_fn(*args, **kwargs) module.__class__ = _ return module