diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index de7097673..a770a5d95 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -7,6 +7,42 @@ import mlx.core as mx from mlx.utils import tree_flatten, tree_unflatten +def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn): + if is_leaf_fn(model, value_key, value): + return map_fn(value) + + elif isinstance(value, Module): + return { + k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn) + for k, v in value.items() + if filter_fn(value, k, v) + } + + elif isinstance(value, dict): + nd = {} + for k, v in v.items(): + tk = f"{value_key}.{k}" + nd[k] = ( + _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn) + if filter_fn(model, tk, v) + else {} + ) + return nd + + elif isinstance(value, list): + nl = [] + for i, vi in enumerate(value): + tk = f"{value_key}.{i}" + nl.append( + _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn) + if filter_fn(model, tk, vi) + else {} + ) + return nl + + raise RuntimeError("Unexpected leaf found while traversing the module") + + class Module(dict): """Base class for building neural networks with MLX. @@ -98,10 +134,13 @@ class Module(dict): if key in self: return self[key] else: - raise AttributeError(f"{type(self)!r} has no attribute {key!r}") + super(Module, self).__getattr__(key, val) def __setattr__(self, key: str, val: Any): - self[key] = val + if isinstance(val, (mx.array, dict, list, tuple)): + self[key] = val + else: + super(Module, self).__setattr__(key, val) def load_weights( self, @@ -245,31 +284,11 @@ class Module(dict): is_leaf_fn = is_leaf_fn or ( lambda m, k, v: not isinstance(v, (Module, dict, list)) ) - - def unwrap(vk, v): - if is_leaf_fn(self, vk, v): - return map_fn(v) - - if isinstance(v, Module): - return v.filter_and_map(filter_fn, map_fn, is_leaf_fn) - - if isinstance(v, dict): - nd = {} - for k, v in v.items(): - tk = f"{vk}.{k}" - nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {} - return nd - - if isinstance(v, list): - nl = [] - for i, vi in enumerate(v): - tk = f"{vk}.{i}" - nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {}) - return nl - - raise RuntimeError("Unexpected leaf found while traversing the module") - - return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)} + return { + k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn) + for k, v in self.items() + if filter_fn(self, k, v) + } def parameters(self): """Recursively return all the :class:`mlx.core.array` members of this Module