diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 3da1993ec..6d1978103 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -7,7 +7,7 @@ import mlx.core as mx from mlx.utils import tree_flatten, tree_unflatten -class Module(dict): +class Module: """Base class for building neural networks with MLX. All the layers provided in :mod:`mlx.nn.layers` subclass this class and @@ -58,6 +58,9 @@ class Module(dict): def __init__(self): """Should be called by the subclasses of ``Module``.""" + # Initialize _keys to implement __setattr__ + super().__setattr__("_keys", set()) + self._no_grad = set() self._training = True @@ -81,14 +84,29 @@ class Module(dict): return value - def __getattr__(self, key: str): - if key in self: - return self[key] - else: - raise AttributeError(f"{type(self)!r} has no attribute {key!r}") - def __setattr__(self, key: str, val: Any): - self[key] = val + if not key.startswith("_"): + self._keys.add(key) + super().__setattr__(key, val) + + def __getitem__(self, key: str): + if key not in self._keys: + raise KeyError(key) + return getattr(self, key) + + def __setitem__(self, key: str, val: Any): + if key not in self._keys: + raise KeyError(key) + setattr(self, key, val) + + def __contains__(self, key: str): + return key in self._keys + + def keys(self): + return (k for k in self._keys) + + def items(self): + return ((k, self[k]) for k in self._keys) def load_weights( self, @@ -190,11 +208,13 @@ class Module(dict): @staticmethod def valid_child_filter(module, key, value): - return isinstance(value, (dict, list)) + return isinstance(value, (Module, dict, list)) @staticmethod def valid_parameter_filter(module, key, value): - return isinstance(value, (dict, list, mx.array)) and not key.startswith("_") + return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith( + "_" + ) @staticmethod def trainable_parameter_filter(module, key, value): @@ -203,6 +223,13 @@ class Module(dict): and key not in module._no_grad ) + @staticmethod + def non_trainable_parameter_filter(module, key, value): + return not key.startswith("_") and ( + isinstance(value, (Module, dict, list)) + or (isinstance(value, mx.array) and key in module._no_grad) + ) + def filter_and_map( self, filter_fn: Callable[["mlx.nn.Module", str, Any], bool], @@ -268,6 +295,11 @@ class Module(dict): this Module as a dict of dicts and lists.""" return self.filter_and_map(self.trainable_parameter_filter) + def non_trainable_parameters(self): + """Recursively return all the frozen :class:`mlx.core.array` members of + this Module as a dict of dicts and lists.""" + return self.filter_and_map(self.non_trainable_parameter_filter) + def children(self): """Return the direct descendants of this Module instance.""" return self.filter_and_map( diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index a7233a468..3d07c1ee7 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,4 +1,5 @@ // Copyright © 2023-2024 Apple Inc. + #include #include #include @@ -485,7 +486,7 @@ struct PyCompiledFun { }; // Inputs must be array or tree of arrays - auto inputs = tree_flatten(args, true); + auto inputs = tree_flatten(args, false); // Get globally enclosed arrays so we don't compile through them // c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638