mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Tmp
This commit is contained in:
parent
1e4d3c7fb2
commit
f5ece1d98d
@ -7,7 +7,7 @@ import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
class Module(dict):
|
||||
class Module:
|
||||
"""Base class for building neural networks with MLX.
|
||||
|
||||
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
|
||||
@ -58,6 +58,9 @@ class Module(dict):
|
||||
|
||||
def __init__(self):
|
||||
"""Should be called by the subclasses of ``Module``."""
|
||||
# Initialize _keys to implement __setattr__
|
||||
super().__setattr__("_keys", set())
|
||||
|
||||
self._no_grad = set()
|
||||
self._training = True
|
||||
|
||||
@ -81,14 +84,29 @@ class Module(dict):
|
||||
|
||||
return value
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
|
||||
|
||||
def __setattr__(self, key: str, val: Any):
|
||||
self[key] = val
|
||||
if not key.startswith("_"):
|
||||
self._keys.add(key)
|
||||
super().__setattr__(key, val)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
if key not in self._keys:
|
||||
raise KeyError(key)
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key: str, val: Any):
|
||||
if key not in self._keys:
|
||||
raise KeyError(key)
|
||||
setattr(self, key, val)
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self._keys
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self._keys)
|
||||
|
||||
def items(self):
|
||||
return ((k, self[k]) for k in self._keys)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
@ -190,11 +208,13 @@ class Module(dict):
|
||||
|
||||
@staticmethod
|
||||
def valid_child_filter(module, key, value):
|
||||
return isinstance(value, (dict, list))
|
||||
return isinstance(value, (Module, dict, list))
|
||||
|
||||
@staticmethod
|
||||
def valid_parameter_filter(module, key, value):
|
||||
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
|
||||
return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith(
|
||||
"_"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def trainable_parameter_filter(module, key, value):
|
||||
@ -203,6 +223,13 @@ class Module(dict):
|
||||
and key not in module._no_grad
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def non_trainable_parameter_filter(module, key, value):
|
||||
return not key.startswith("_") and (
|
||||
isinstance(value, (Module, dict, list))
|
||||
or (isinstance(value, mx.array) and key in module._no_grad)
|
||||
)
|
||||
|
||||
def filter_and_map(
|
||||
self,
|
||||
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
|
||||
@ -268,6 +295,11 @@ class Module(dict):
|
||||
this Module as a dict of dicts and lists."""
|
||||
return self.filter_and_map(self.trainable_parameter_filter)
|
||||
|
||||
def non_trainable_parameters(self):
|
||||
"""Recursively return all the frozen :class:`mlx.core.array` members of
|
||||
this Module as a dict of dicts and lists."""
|
||||
return self.filter_and_map(self.non_trainable_parameter_filter)
|
||||
|
||||
def children(self):
|
||||
"""Return the direct descendants of this Module instance."""
|
||||
return self.filter_and_map(
|
||||
|
@ -1,4 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
@ -485,7 +486,7 @@ struct PyCompiledFun {
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto inputs = tree_flatten(args, false);
|
||||
|
||||
// Get globally enclosed arrays so we don't compile through them
|
||||
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
|
||||
|
Loading…
Reference in New Issue
Block a user