mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Tmp
This commit is contained in:
parent
1e4d3c7fb2
commit
f5ece1d98d
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
class Module(dict):
|
class Module:
|
||||||
"""Base class for building neural networks with MLX.
|
"""Base class for building neural networks with MLX.
|
||||||
|
|
||||||
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
|
All the layers provided in :mod:`mlx.nn.layers` subclass this class and
|
||||||
@ -58,6 +58,9 @@ class Module(dict):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Should be called by the subclasses of ``Module``."""
|
"""Should be called by the subclasses of ``Module``."""
|
||||||
|
# Initialize _keys to implement __setattr__
|
||||||
|
super().__setattr__("_keys", set())
|
||||||
|
|
||||||
self._no_grad = set()
|
self._no_grad = set()
|
||||||
self._training = True
|
self._training = True
|
||||||
|
|
||||||
@ -81,14 +84,29 @@ class Module(dict):
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
|
||||||
if key in self:
|
|
||||||
return self[key]
|
|
||||||
else:
|
|
||||||
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
|
|
||||||
|
|
||||||
def __setattr__(self, key: str, val: Any):
|
def __setattr__(self, key: str, val: Any):
|
||||||
self[key] = val
|
if not key.startswith("_"):
|
||||||
|
self._keys.add(key)
|
||||||
|
super().__setattr__(key, val)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
if key not in self._keys:
|
||||||
|
raise KeyError(key)
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, val: Any):
|
||||||
|
if key not in self._keys:
|
||||||
|
raise KeyError(key)
|
||||||
|
setattr(self, key, val)
|
||||||
|
|
||||||
|
def __contains__(self, key: str):
|
||||||
|
return key in self._keys
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return (k for k in self._keys)
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return ((k, self[k]) for k in self._keys)
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
self,
|
self,
|
||||||
@ -190,11 +208,13 @@ class Module(dict):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def valid_child_filter(module, key, value):
|
def valid_child_filter(module, key, value):
|
||||||
return isinstance(value, (dict, list))
|
return isinstance(value, (Module, dict, list))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def valid_parameter_filter(module, key, value):
|
def valid_parameter_filter(module, key, value):
|
||||||
return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
|
return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith(
|
||||||
|
"_"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trainable_parameter_filter(module, key, value):
|
def trainable_parameter_filter(module, key, value):
|
||||||
@ -203,6 +223,13 @@ class Module(dict):
|
|||||||
and key not in module._no_grad
|
and key not in module._no_grad
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def non_trainable_parameter_filter(module, key, value):
|
||||||
|
return not key.startswith("_") and (
|
||||||
|
isinstance(value, (Module, dict, list))
|
||||||
|
or (isinstance(value, mx.array) and key in module._no_grad)
|
||||||
|
)
|
||||||
|
|
||||||
def filter_and_map(
|
def filter_and_map(
|
||||||
self,
|
self,
|
||||||
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
|
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
|
||||||
@ -268,6 +295,11 @@ class Module(dict):
|
|||||||
this Module as a dict of dicts and lists."""
|
this Module as a dict of dicts and lists."""
|
||||||
return self.filter_and_map(self.trainable_parameter_filter)
|
return self.filter_and_map(self.trainable_parameter_filter)
|
||||||
|
|
||||||
|
def non_trainable_parameters(self):
|
||||||
|
"""Recursively return all the frozen :class:`mlx.core.array` members of
|
||||||
|
this Module as a dict of dicts and lists."""
|
||||||
|
return self.filter_and_map(self.non_trainable_parameter_filter)
|
||||||
|
|
||||||
def children(self):
|
def children(self):
|
||||||
"""Return the direct descendants of this Module instance."""
|
"""Return the direct descendants of this Module instance."""
|
||||||
return self.filter_and_map(
|
return self.filter_and_map(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
@ -485,7 +486,7 @@ struct PyCompiledFun {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Inputs must be array or tree of arrays
|
// Inputs must be array or tree of arrays
|
||||||
auto inputs = tree_flatten(args, true);
|
auto inputs = tree_flatten(args, false);
|
||||||
|
|
||||||
// Get globally enclosed arrays so we don't compile through them
|
// Get globally enclosed arrays so we don't compile through them
|
||||||
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
|
// c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638
|
||||||
|
Loading…
Reference in New Issue
Block a user