mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
avoid nested closures in module (#759)
This commit is contained in:
parent
776c3d226d
commit
4494970f47
@ -7,6 +7,42 @@ import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||
if is_leaf_fn(model, value_key, value):
|
||||
return map_fn(value)
|
||||
|
||||
elif isinstance(value, Module):
|
||||
return {
|
||||
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||
for k, v in value.items()
|
||||
if filter_fn(value, k, v)
|
||||
}
|
||||
|
||||
elif isinstance(value, dict):
|
||||
nd = {}
|
||||
for k, v in v.items():
|
||||
tk = f"{value_key}.{k}"
|
||||
nd[k] = (
|
||||
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, v)
|
||||
else {}
|
||||
)
|
||||
return nd
|
||||
|
||||
elif isinstance(value, list):
|
||||
nl = []
|
||||
for i, vi in enumerate(value):
|
||||
tk = f"{value_key}.{i}"
|
||||
nl.append(
|
||||
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
|
||||
if filter_fn(model, tk, vi)
|
||||
else {}
|
||||
)
|
||||
return nl
|
||||
|
||||
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||
|
||||
|
||||
class Module(dict):
|
||||
"""Base class for building neural networks with MLX.
|
||||
|
||||
@ -98,10 +134,13 @@ class Module(dict):
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
|
||||
super(Module, self).__getattr__(key, val)
|
||||
|
||||
def __setattr__(self, key: str, val: Any):
|
||||
self[key] = val
|
||||
if isinstance(val, (mx.array, dict, list, tuple)):
|
||||
self[key] = val
|
||||
else:
|
||||
super(Module, self).__setattr__(key, val)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
@ -245,31 +284,11 @@ class Module(dict):
|
||||
is_leaf_fn = is_leaf_fn or (
|
||||
lambda m, k, v: not isinstance(v, (Module, dict, list))
|
||||
)
|
||||
|
||||
def unwrap(vk, v):
|
||||
if is_leaf_fn(self, vk, v):
|
||||
return map_fn(v)
|
||||
|
||||
if isinstance(v, Module):
|
||||
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
|
||||
|
||||
if isinstance(v, dict):
|
||||
nd = {}
|
||||
for k, v in v.items():
|
||||
tk = f"{vk}.{k}"
|
||||
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
|
||||
return nd
|
||||
|
||||
if isinstance(v, list):
|
||||
nl = []
|
||||
for i, vi in enumerate(v):
|
||||
tk = f"{vk}.{i}"
|
||||
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
|
||||
return nl
|
||||
|
||||
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||
|
||||
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
|
||||
return {
|
||||
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||
for k, v in self.items()
|
||||
if filter_fn(self, k, v)
|
||||
}
|
||||
|
||||
def parameters(self):
|
||||
"""Recursively return all the :class:`mlx.core.array` members of this Module
|
||||
|
Loading…
Reference in New Issue
Block a user