mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
avoid nested closures in module (#759)
This commit is contained in:
parent
776c3d226d
commit
4494970f47
@ -7,6 +7,42 @@ import mlx.core as mx
|
|||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
|
||||||
|
if is_leaf_fn(model, value_key, value):
|
||||||
|
return map_fn(value)
|
||||||
|
|
||||||
|
elif isinstance(value, Module):
|
||||||
|
return {
|
||||||
|
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||||
|
for k, v in value.items()
|
||||||
|
if filter_fn(value, k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
nd = {}
|
||||||
|
for k, v in v.items():
|
||||||
|
tk = f"{value_key}.{k}"
|
||||||
|
nd[k] = (
|
||||||
|
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
|
||||||
|
if filter_fn(model, tk, v)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
return nd
|
||||||
|
|
||||||
|
elif isinstance(value, list):
|
||||||
|
nl = []
|
||||||
|
for i, vi in enumerate(value):
|
||||||
|
tk = f"{value_key}.{i}"
|
||||||
|
nl.append(
|
||||||
|
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
|
||||||
|
if filter_fn(model, tk, vi)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
return nl
|
||||||
|
|
||||||
|
raise RuntimeError("Unexpected leaf found while traversing the module")
|
||||||
|
|
||||||
|
|
||||||
class Module(dict):
|
class Module(dict):
|
||||||
"""Base class for building neural networks with MLX.
|
"""Base class for building neural networks with MLX.
|
||||||
|
|
||||||
@ -98,10 +134,13 @@ class Module(dict):
|
|||||||
if key in self:
|
if key in self:
|
||||||
return self[key]
|
return self[key]
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
|
super(Module, self).__getattr__(key, val)
|
||||||
|
|
||||||
def __setattr__(self, key: str, val: Any):
|
def __setattr__(self, key: str, val: Any):
|
||||||
|
if isinstance(val, (mx.array, dict, list, tuple)):
|
||||||
self[key] = val
|
self[key] = val
|
||||||
|
else:
|
||||||
|
super(Module, self).__setattr__(key, val)
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
self,
|
self,
|
||||||
@ -245,31 +284,11 @@ class Module(dict):
|
|||||||
is_leaf_fn = is_leaf_fn or (
|
is_leaf_fn = is_leaf_fn or (
|
||||||
lambda m, k, v: not isinstance(v, (Module, dict, list))
|
lambda m, k, v: not isinstance(v, (Module, dict, list))
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
def unwrap(vk, v):
|
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
|
||||||
if is_leaf_fn(self, vk, v):
|
for k, v in self.items()
|
||||||
return map_fn(v)
|
if filter_fn(self, k, v)
|
||||||
|
}
|
||||||
if isinstance(v, Module):
|
|
||||||
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
|
|
||||||
|
|
||||||
if isinstance(v, dict):
|
|
||||||
nd = {}
|
|
||||||
for k, v in v.items():
|
|
||||||
tk = f"{vk}.{k}"
|
|
||||||
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
|
|
||||||
return nd
|
|
||||||
|
|
||||||
if isinstance(v, list):
|
|
||||||
nl = []
|
|
||||||
for i, vi in enumerate(v):
|
|
||||||
tk = f"{vk}.{i}"
|
|
||||||
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
|
|
||||||
return nl
|
|
||||||
|
|
||||||
raise RuntimeError("Unexpected leaf found while traversing the module")
|
|
||||||
|
|
||||||
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
|
|
||||||
|
|
||||||
def parameters(self):
|
def parameters(self):
|
||||||
"""Recursively return all the :class:`mlx.core.array` members of this Module
|
"""Recursively return all the :class:`mlx.core.array` members of this Module
|
||||||
|
Loading…
Reference in New Issue
Block a user