mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Tmp
This commit is contained in:
		| @@ -7,7 +7,7 @@ import mlx.core as mx | ||||
| from mlx.utils import tree_flatten, tree_unflatten | ||||
|  | ||||
|  | ||||
| class Module(dict): | ||||
| class Module: | ||||
|     """Base class for building neural networks with MLX. | ||||
|  | ||||
|     All the layers provided in :mod:`mlx.nn.layers` subclass this class and | ||||
| @@ -58,6 +58,9 @@ class Module(dict): | ||||
|  | ||||
|     def __init__(self): | ||||
|         """Should be called by the subclasses of ``Module``.""" | ||||
|         # Initialize _keys to implement __setattr__ | ||||
|         super().__setattr__("_keys", set()) | ||||
|  | ||||
|         self._no_grad = set() | ||||
|         self._training = True | ||||
|  | ||||
| @@ -81,14 +84,29 @@ class Module(dict): | ||||
|  | ||||
|         return value | ||||
|  | ||||
|     def __getattr__(self, key: str): | ||||
|         if key in self: | ||||
|             return self[key] | ||||
|         else: | ||||
|             raise AttributeError(f"{type(self)!r} has no attribute {key!r}") | ||||
|  | ||||
|     def __setattr__(self, key: str, val: Any): | ||||
|         self[key] = val | ||||
|         if not key.startswith("_"): | ||||
|             self._keys.add(key) | ||||
|         super().__setattr__(key, val) | ||||
|  | ||||
|     def __getitem__(self, key: str): | ||||
|         if key not in self._keys: | ||||
|             raise KeyError(key) | ||||
|         return getattr(self, key) | ||||
|  | ||||
|     def __setitem__(self, key: str, val: Any): | ||||
|         if key not in self._keys: | ||||
|             raise KeyError(key) | ||||
|         setattr(self, key, val) | ||||
|  | ||||
|     def __contains__(self, key: str): | ||||
|         return key in self._keys | ||||
|  | ||||
|     def keys(self): | ||||
|         return (k for k in self._keys) | ||||
|  | ||||
|     def items(self): | ||||
|         return ((k, self[k]) for k in self._keys) | ||||
|  | ||||
|     def load_weights( | ||||
|         self, | ||||
| @@ -190,11 +208,13 @@ class Module(dict): | ||||
|  | ||||
|     @staticmethod | ||||
|     def valid_child_filter(module, key, value): | ||||
|         return isinstance(value, (dict, list)) | ||||
|         return isinstance(value, (Module, dict, list)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def valid_parameter_filter(module, key, value): | ||||
|         return isinstance(value, (dict, list, mx.array)) and not key.startswith("_") | ||||
|         return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith( | ||||
|             "_" | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def trainable_parameter_filter(module, key, value): | ||||
| @@ -203,6 +223,13 @@ class Module(dict): | ||||
|             and key not in module._no_grad | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def non_trainable_parameter_filter(module, key, value): | ||||
|         return not key.startswith("_") and ( | ||||
|             isinstance(value, (Module, dict, list)) | ||||
|             or (isinstance(value, mx.array) and key in module._no_grad) | ||||
|         ) | ||||
|  | ||||
|     def filter_and_map( | ||||
|         self, | ||||
|         filter_fn: Callable[["mlx.nn.Module", str, Any], bool], | ||||
| @@ -268,6 +295,11 @@ class Module(dict): | ||||
|         this Module as a dict of dicts and lists.""" | ||||
|         return self.filter_and_map(self.trainable_parameter_filter) | ||||
|  | ||||
|     def non_trainable_parameters(self): | ||||
|         """Recursively return all the frozen :class:`mlx.core.array` members of | ||||
|         this Module as a dict of dicts and lists.""" | ||||
|         return self.filter_and_map(self.non_trainable_parameter_filter) | ||||
|  | ||||
|     def children(self): | ||||
|         """Return the direct descendants of this Module instance.""" | ||||
|         return self.filter_and_map( | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <pybind11/functional.h> | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
| @@ -485,7 +486,7 @@ struct PyCompiledFun { | ||||
|     }; | ||||
|  | ||||
|     // Inputs must be array or tree of arrays | ||||
|     auto inputs = tree_flatten(args, true); | ||||
|     auto inputs = tree_flatten(args, false); | ||||
|  | ||||
|     // Get globally enclosed arrays so we don't compile through them | ||||
|     // c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos