avoid nested closures in module (#759)

This commit is contained in:
Awni Hannun 2024-02-29 09:39:52 -08:00 committed by GitHub
parent 776c3d226d
commit 4494970f47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,6 +7,42 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in v.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict):
"""Base class for building neural networks with MLX.
@ -98,10 +134,13 @@ class Module(dict):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
super(Module, self).__getattr__(key, val)
def __setattr__(self, key: str, val: Any):
self[key] = val
if isinstance(val, (mx.array, dict, list, tuple)):
self[key] = val
else:
super(Module, self).__setattr__(key, val)
def load_weights(
self,
@ -245,31 +284,11 @@ class Module(dict):
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)
def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)
if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd
if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
return {
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in self.items()
if filter_fn(self, k, v)
}
def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module