This commit is contained in:
Angelos Katharopoulos 2024-01-21 21:27:09 -08:00
parent 1e4d3c7fb2
commit f5ece1d98d
2 changed files with 44 additions and 11 deletions

View File

@ -7,7 +7,7 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
class Module(dict):
class Module:
"""Base class for building neural networks with MLX.
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
@ -58,6 +58,9 @@ class Module(dict):
def __init__(self):
"""Should be called by the subclasses of ``Module``."""
# Initialize _keys to implement __setattr__
super().__setattr__("_keys", set())
self._no_grad = set()
self._training = True
@ -81,14 +84,29 @@ class Module(dict):
return value
def __getattr__(self, key: str):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
def __setattr__(self, key: str, val: Any):
self[key] = val
if not key.startswith("_"):
self._keys.add(key)
super().__setattr__(key, val)
def __getitem__(self, key: str):
if key not in self._keys:
raise KeyError(key)
return getattr(self, key)
def __setitem__(self, key: str, val: Any):
if key not in self._keys:
raise KeyError(key)
setattr(self, key, val)
def __contains__(self, key: str):
return key in self._keys
def keys(self):
return (k for k in self._keys)
def items(self):
return ((k, self[k]) for k in self._keys)
def load_weights(
self,
@ -190,11 +208,13 @@ class Module(dict):
@staticmethod
def valid_child_filter(module, key, value):
return isinstance(value, (dict, list))
return isinstance(value, (Module, dict, list))
@staticmethod
def valid_parameter_filter(module, key, value):
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith(
"_"
)
@staticmethod
def trainable_parameter_filter(module, key, value):
@ -203,6 +223,13 @@ class Module(dict):
and key not in module._no_grad
)
@staticmethod
def non_trainable_parameter_filter(module, key, value):
return not key.startswith("_") and (
isinstance(value, (Module, dict, list))
or (isinstance(value, mx.array) and key in module._no_grad)
)
def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
@ -268,6 +295,11 @@ class Module(dict):
this Module as a dict of dicts and lists."""
return self.filter_and_map(self.trainable_parameter_filter)
def non_trainable_parameters(self):
"""Recursively return all the frozen :class:`mlx.core.array` members of
this Module as a dict of dicts and lists."""
return self.filter_and_map(self.non_trainable_parameter_filter)
def children(self):
"""Return the direct descendants of this Module instance."""
return self.filter_and_map(

View File

@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@ -485,7 +486,7 @@ struct PyCompiledFun {
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
auto inputs = tree_flatten(args, false);
// Get globally enclosed arrays so we don't compile through them
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638