From 1368bce280c6c1192733ad28754dc2c9a2276425 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 5 Mar 2024 09:36:45 -0800 Subject: [PATCH] fix tests and add setter attributes --- python/mlx/nn/layers/base.py | 5 ++++- python/mlx/nn/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 447002594d..143624fd48 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -137,7 +137,10 @@ class Module(dict): super(Module, self).__getattribute__(key) def __setattr__(self, key: str, val: Any): - if isinstance(val, (mx.array, dict, list, tuple)): + # Allow setter properties to pass through to base class + prop = vars(self.__class__).get(key, None) + is_prop = isinstance(prop, property) and prop.fset is not None + if not is_prop and isinstance(val, (mx.array, dict, list, tuple)): self[key] = val else: super(Module, self).__setattr__(key, val) diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index 7035450e4a..07ab217055 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -66,8 +66,8 @@ def checkpoint(module: Module): return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) class _(type(module)): - def __call__(self, *arg, **kwarg): - return wrapped_checkpointed_fn(*arg, **kwarg) + def __call__(self, *args, **kwargs): + return wrapped_checkpointed_fn(*args, **kwargs) module.__class__ = _ return module