mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Tmp
This commit is contained in:
		| @@ -7,7 +7,7 @@ import mlx.core as mx | |||||||
| from mlx.utils import tree_flatten, tree_unflatten | from mlx.utils import tree_flatten, tree_unflatten | ||||||
|  |  | ||||||
|  |  | ||||||
| class Module(dict): | class Module: | ||||||
|     """Base class for building neural networks with MLX. |     """Base class for building neural networks with MLX. | ||||||
|  |  | ||||||
|     All the layers provided in :mod:`mlx.nn.layers` subclass this class and |     All the layers provided in :mod:`mlx.nn.layers` subclass this class and | ||||||
| @@ -58,6 +58,9 @@ class Module(dict): | |||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """Should be called by the subclasses of ``Module``.""" |         """Should be called by the subclasses of ``Module``.""" | ||||||
|  |         # Initialize _keys to implement __setattr__ | ||||||
|  |         super().__setattr__("_keys", set()) | ||||||
|  |  | ||||||
|         self._no_grad = set() |         self._no_grad = set() | ||||||
|         self._training = True |         self._training = True | ||||||
|  |  | ||||||
| @@ -81,14 +84,29 @@ class Module(dict): | |||||||
|  |  | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def __getattr__(self, key: str): |  | ||||||
|         if key in self: |  | ||||||
|             return self[key] |  | ||||||
|         else: |  | ||||||
|             raise AttributeError(f"{type(self)!r} has no attribute {key!r}") |  | ||||||
|  |  | ||||||
|     def __setattr__(self, key: str, val: Any): |     def __setattr__(self, key: str, val: Any): | ||||||
|         self[key] = val |         if not key.startswith("_"): | ||||||
|  |             self._keys.add(key) | ||||||
|  |         super().__setattr__(key, val) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, key: str): | ||||||
|  |         if key not in self._keys: | ||||||
|  |             raise KeyError(key) | ||||||
|  |         return getattr(self, key) | ||||||
|  |  | ||||||
|  |     def __setitem__(self, key: str, val: Any): | ||||||
|  |         if key not in self._keys: | ||||||
|  |             raise KeyError(key) | ||||||
|  |         setattr(self, key, val) | ||||||
|  |  | ||||||
|  |     def __contains__(self, key: str): | ||||||
|  |         return key in self._keys | ||||||
|  |  | ||||||
|  |     def keys(self): | ||||||
|  |         return (k for k in self._keys) | ||||||
|  |  | ||||||
|  |     def items(self): | ||||||
|  |         return ((k, self[k]) for k in self._keys) | ||||||
|  |  | ||||||
|     def load_weights( |     def load_weights( | ||||||
|         self, |         self, | ||||||
| @@ -190,11 +208,13 @@ class Module(dict): | |||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def valid_child_filter(module, key, value): |     def valid_child_filter(module, key, value): | ||||||
|         return isinstance(value, (dict, list)) |         return isinstance(value, (Module, dict, list)) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def valid_parameter_filter(module, key, value): |     def valid_parameter_filter(module, key, value): | ||||||
|         return isinstance(value, (dict, list, mx.array)) and not key.startswith("_") |         return isinstance(value, (Module, dict, list, mx.array)) and not key.startswith( | ||||||
|  |             "_" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def trainable_parameter_filter(module, key, value): |     def trainable_parameter_filter(module, key, value): | ||||||
| @@ -203,6 +223,13 @@ class Module(dict): | |||||||
|             and key not in module._no_grad |             and key not in module._no_grad | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def non_trainable_parameter_filter(module, key, value): | ||||||
|  |         return not key.startswith("_") and ( | ||||||
|  |             isinstance(value, (Module, dict, list)) | ||||||
|  |             or (isinstance(value, mx.array) and key in module._no_grad) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def filter_and_map( |     def filter_and_map( | ||||||
|         self, |         self, | ||||||
|         filter_fn: Callable[["mlx.nn.Module", str, Any], bool], |         filter_fn: Callable[["mlx.nn.Module", str, Any], bool], | ||||||
| @@ -268,6 +295,11 @@ class Module(dict): | |||||||
|         this Module as a dict of dicts and lists.""" |         this Module as a dict of dicts and lists.""" | ||||||
|         return self.filter_and_map(self.trainable_parameter_filter) |         return self.filter_and_map(self.trainable_parameter_filter) | ||||||
|  |  | ||||||
|  |     def non_trainable_parameters(self): | ||||||
|  |         """Recursively return all the frozen :class:`mlx.core.array` members of | ||||||
|  |         this Module as a dict of dicts and lists.""" | ||||||
|  |         return self.filter_and_map(self.non_trainable_parameter_filter) | ||||||
|  |  | ||||||
|     def children(self): |     def children(self): | ||||||
|         """Return the direct descendants of this Module instance.""" |         """Return the direct descendants of this Module instance.""" | ||||||
|         return self.filter_and_map( |         return self.filter_and_map( | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| // Copyright © 2023-2024 Apple Inc. | // Copyright © 2023-2024 Apple Inc. | ||||||
|  |  | ||||||
| #include <pybind11/functional.h> | #include <pybind11/functional.h> | ||||||
| #include <pybind11/pybind11.h> | #include <pybind11/pybind11.h> | ||||||
| #include <pybind11/stl.h> | #include <pybind11/stl.h> | ||||||
| @@ -485,7 +486,7 @@ struct PyCompiledFun { | |||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     // Inputs must be array or tree of arrays |     // Inputs must be array or tree of arrays | ||||||
|     auto inputs = tree_flatten(args, true); |     auto inputs = tree_flatten(args, false); | ||||||
|  |  | ||||||
|     // Get globally enclosed arrays so we don't compile through them |     // Get globally enclosed arrays so we don't compile through them | ||||||
|     // c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638 |     // c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos