Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
Compile with capture (#629)

* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
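Taken together, these changes let a compiled function capture external state (module parameters, optimizer state, the global RNG key) as extra inputs and outputs. The sketch below is not part of the commit; it only illustrates how the pieces are meant to compose into a compiled training step. The model, loss, and data are placeholders, while `mx.compile(inputs=..., outputs=...)`, `Optimizer.init`, `Module.state`, and `mx.random.state` are the APIs introduced here.

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
optimizer.init(model.trainable_parameters())  # optional explicit state init

def loss_fn(model, x, y):
    return mx.mean(mx.square(model(x) - y))

# Everything the step reads or writes besides its explicit arguments.
state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

x, y = mx.random.normal(shape=(4, 2)), mx.random.normal(shape=(4, 2))
loss = step(x, y)
mx.eval(loss, state)  # evaluate the captured state along with the loss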
		| @@ -1,19 +0,0 @@ | |||||||
| {{ fullname | escape | underline}} |  | ||||||
|  |  | ||||||
| .. currentmodule:: {{ module }} |  | ||||||
|  |  | ||||||
| .. autoclass:: {{ objname }} |  | ||||||
|  |  | ||||||
|    {#{% block methods %} |  | ||||||
|  |  | ||||||
|    {% if methods %} |  | ||||||
|    .. rubric:: {{ _('Methods') }} |  | ||||||
|  |  | ||||||
|    .. autosummary:: |  | ||||||
|    {% for item in methods %} |  | ||||||
|       {%- if item not in inherited_members and item != '__init__' %} |  | ||||||
|          ~{{ name }}.{{ item }} |  | ||||||
|       {%- endif %} |  | ||||||
|    {%- endfor %} |  | ||||||
|    {% endif %} |  | ||||||
|    {% endblock %}#} |  | ||||||
| @@ -11,6 +11,7 @@ Module | |||||||
|       :toctree: _autosummary |       :toctree: _autosummary | ||||||
|     |     | ||||||
|       Module.training |       Module.training | ||||||
|  |       Module.state | ||||||
|     |     | ||||||
|    .. rubric:: Methods |    .. rubric:: Methods | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
docs/src/python/optimizer.rst (new file, +23 lines)
							| @@ -0,0 +1,23 @@ | |||||||
|  | Optimizer | ||||||
|  | ========= | ||||||
|  |  | ||||||
|  | .. currentmodule:: mlx.optimizers | ||||||
|  |  | ||||||
|  | .. autoclass:: Optimizer  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |    .. rubric:: Attributes | ||||||
|  |  | ||||||
|  |    .. autosummary:: | ||||||
|  |       :toctree: _autosummary | ||||||
|  |  | ||||||
|  |       Optimizer.state | ||||||
|  |     | ||||||
|  |    .. rubric:: Methods | ||||||
|  |  | ||||||
|  |    .. autosummary:: | ||||||
|  |       :toctree: _autosummary | ||||||
|  |     | ||||||
|  |       Optimizer.apply_gradients | ||||||
|  |       Optimizer.init | ||||||
|  |       Optimizer.update | ||||||
| @@ -29,14 +29,16 @@ model's parameters and the **optimizer state**. | |||||||
|             # Compute the new parameters but also the optimizer state. |             # Compute the new parameters but also the optimizer state. | ||||||
|             mx.eval(model.parameters(), optimizer.state) |             mx.eval(model.parameters(), optimizer.state) | ||||||
|  |  | ||||||
|  | .. toctree:: | ||||||
|  |  | ||||||
|  |    optimizer | ||||||
|  |  | ||||||
| .. currentmodule:: mlx.optimizers | .. currentmodule:: mlx.optimizers | ||||||
|  |  | ||||||
| .. autosummary:: | .. autosummary:: | ||||||
|    :toctree: _autosummary |    :toctree: _autosummary | ||||||
|    :template: optimizers-template.rst |    :template: optimizers-template.rst | ||||||
|  |  | ||||||
|    OptimizerState |  | ||||||
|    Optimizer |  | ||||||
|    SGD |    SGD | ||||||
|    RMSprop |    RMSprop | ||||||
|    Adagrad |    Adagrad | ||||||
|   | |||||||
| @@ -191,10 +191,7 @@ struct CompilerCache { | |||||||
|     auto is_match = [](const std::vector<array>& in1, |     auto is_match = [](const std::vector<array>& in1, | ||||||
|                        const std::vector<array>& in2) { |                        const std::vector<array>& in2) { | ||||||
|       if (in1.size() != in2.size()) { |       if (in1.size() != in2.size()) { | ||||||
|         std::ostringstream msg; |         return false; | ||||||
|         msg << "[compiler] Unexpected number of inputs to compiled function:" |  | ||||||
|             << " expected " << in2.size() << " got " << in1.size() << "."; |  | ||||||
|         throw std::invalid_argument(msg.str()); |  | ||||||
|       } |       } | ||||||
|       for (int i = 0; i < in1.size(); ++i) { |       for (int i = 0; i < in1.size(); ++i) { | ||||||
|         if (in1[i].shape() != in2[i].shape()) { |         if (in1[i].shape() != in2[i].shape()) { | ||||||
| @@ -603,7 +600,7 @@ void compile_fuse( | |||||||
|         shapes, |         shapes, | ||||||
|         types, |         types, | ||||||
|         std::make_shared<Compiled>( |         std::make_shared<Compiled>( | ||||||
|             outputs.back().primitive().stream(), |             old_outputs.back().primitive().stream(), | ||||||
|             inputs, |             inputs, | ||||||
|             old_outputs, |             old_outputs, | ||||||
|             std::move(fused_tape), |             std::move(fused_tape), | ||||||
|   | |||||||
| @@ -66,6 +66,19 @@ class Module(dict): | |||||||
|         """Boolean indicating if the model is in training mode.""" |         """Boolean indicating if the model is in training mode.""" | ||||||
|         return self._training |         return self._training | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def state(self): | ||||||
|  |         """The module's state dictionary | ||||||
|  |  | ||||||
|  |         The module's state dictionary contains any attribute set on the | ||||||
|  |         module including parameters in :meth:`Module.parameters`. | ||||||
|  |  | ||||||
|  |         Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is | ||||||
|  |         a reference to the module's state. Updates to it will be reflected in | ||||||
|  |         the original module. | ||||||
|  |         """ | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def _extra_repr(self): |     def _extra_repr(self): | ||||||
|         return "" |         return "" | ||||||
|  |  | ||||||
|   | |||||||
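The reference semantics described in the new `Module.state` docstring are easiest to see with a tiny example. This sketch is illustrative and not part of the diff; it relies only on `nn.Linear` and the `state` property added above.

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(2, 2)
state = layer.state               # a reference to the module, not a copy
state["bias"] = mx.zeros((2,))    # writing through the state dict ...
print(layer.bias)                 # ... is visible on the module itself

This is what allows `model.state` to be passed as a captured input/output of a compiled function: updates written into the captured tree land directly on the module.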
| @@ -7,39 +7,14 @@ import mlx.core as mx | |||||||
| from mlx.utils import tree_map | from mlx.utils import tree_map | ||||||
|  |  | ||||||
|  |  | ||||||
| class OptimizerState(dict): |  | ||||||
|     """The optimizer state implements a recursively defined |  | ||||||
|     :class:`collections.defaultdict`, namely a missing key in an optimizer |  | ||||||
|     state is an :class:`OptimizerState`. |  | ||||||
|  |  | ||||||
|     .. note:: |  | ||||||
|        :meth:`OptimizerState.get` in contrast to a normal dictionary also sets |  | ||||||
|        the key to the ``default`` value if the ``key`` was not present in the |  | ||||||
|        dictionary. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def __getitem__(self, key): |  | ||||||
|         if key not in self: |  | ||||||
|             self[key] = OptimizerState() |  | ||||||
|         return super().__getitem__(key) |  | ||||||
|  |  | ||||||
|     def get(self, key, default): |  | ||||||
|         """If ``key`` doesn't exist set its value to ``default`` and then return it.""" |  | ||||||
|         if key not in self: |  | ||||||
|             self[key] = default |  | ||||||
|         return super().__getitem__(key) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Optimizer: | class Optimizer: | ||||||
|     """The base class for all optimizers. It allows us to implement an |     """The base class for all optimizers. It allows us to implement an | ||||||
|     optimizer on a per-parameter basis and apply it to a parameter tree. |     optimizer on a per-parameter basis and apply it to a parameter tree. | ||||||
|  |  | ||||||
|     Attributes: |  | ||||||
|         state (OptimizerState): It holds the optimizer's state dictionary. |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.state = OptimizerState() |         self._initialized = False | ||||||
|  |         self._state = {} | ||||||
|  |  | ||||||
|     def update(self, model: "mlx.nn.Module", gradients: dict): |     def update(self, model: "mlx.nn.Module", gradients: dict): | ||||||
|         """Apply the gradients to the parameters of the model and update the |         """Apply the gradients to the parameters of the model and update the | ||||||
| @@ -52,7 +27,41 @@ class Optimizer: | |||||||
|         """ |         """ | ||||||
|         model.update(self.apply_gradients(gradients, model)) |         model.update(self.apply_gradients(gradients, model)) | ||||||
|  |  | ||||||
|     def apply_gradients(self, gradients: dict, model: dict): |     def init(self, parameters: dict): | ||||||
|  |         """Initialize the optimizer's state | ||||||
|  |  | ||||||
|  |         This function can be used to initialize optimizers which have state | ||||||
|  |         (like momentum in :class:`SGD`). Using this method is optional as the | ||||||
|  |         optimizer will initialize itself if the state is not yet set. However, | ||||||
|  |         there are some cases where explicit initialization is useful in order | ||||||
|  |         to have access to the :attr:`Optimizer.state` before the first call to | ||||||
|  |         :meth:`Optimizer.update`. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             parameters (dict): A Python tree of parameters. | ||||||
|  |  | ||||||
|  |         Example: | ||||||
|  |             >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) | ||||||
|  |             >>> model = nn.Linear(2, 2) | ||||||
|  |             >>> optimizer.init(model.trainable_parameters()) | ||||||
|  |             >>> optimizer.state | ||||||
|  |             {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], | ||||||
|  |                    [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} | ||||||
|  |         """ | ||||||
|  |         self._state.update(tree_map(lambda x: {}, parameters)) | ||||||
|  |         tree_map(self.init_single, parameters, self._state) | ||||||
|  |         self._initialized = True | ||||||
|  |  | ||||||
|  |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|  |         """To be extended by the children classes to implement each optimizer's | ||||||
|  |         state initialization. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             parameter (mx.array): A single parameter that will be optimized. | ||||||
|  |         """ | ||||||
|  |         raise NotImplementedError() | ||||||
|  |  | ||||||
|  |     def apply_gradients(self, gradients: dict, parameters: dict): | ||||||
|         """Apply the gradients to the parameters and return the updated parameters. |         """Apply the gradients to the parameters and return the updated parameters. | ||||||
|  |  | ||||||
|         Can be used to update a model via |         Can be used to update a model via | ||||||
| @@ -61,19 +70,41 @@ class Optimizer: | |||||||
|  |  | ||||||
|         Args: |         Args: | ||||||
|             gradients (dict): A Python tree of gradients. |             gradients (dict): A Python tree of gradients. | ||||||
|             model (dict): A Python tree of parameters. It can be a superset of |             parameters (dict): A Python tree of parameters. It can be a | ||||||
|                           the gradients. In that case the returned python tree |               superset of the gradients. In that case the returned python | ||||||
|                           will be of the same structure as the gradients. |               tree will be of the same structure as the gradients. | ||||||
|         """ |         """ | ||||||
|         return tree_map(self.apply_single, gradients, model, self.state) |         if not self._initialized: | ||||||
|  |             self.init(gradients) | ||||||
|  |         return tree_map(self.apply_single, gradients, parameters, self.state) | ||||||
|  |  | ||||||
|     def apply_single( |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """To be extended by derived classes to implement the optimizer's update. | ||||||
|     ): |  | ||||||
|         """To be extended by the children classes to implement each optimizer's |         Args: | ||||||
|         update.""" |             gradient (mx.array): The ``parameter`` gradient. | ||||||
|  |             parameter (mx.array): The ``parameter`` to update. | ||||||
|  |             state (dict): The optimizer's state. | ||||||
|  |         """ | ||||||
|         raise NotImplementedError() |         raise NotImplementedError() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def state(self): | ||||||
|  |         """The optimizer's state dictionary.""" | ||||||
|  |         return self._state | ||||||
|  |  | ||||||
|  |     @state.setter | ||||||
|  |     def state(self, state: dict): | ||||||
|  |         self._state = state | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def learning_rate(self): | ||||||
|  |         return self.state["learning_rate"] | ||||||
|  |  | ||||||
|  |     @learning_rate.setter | ||||||
|  |     def learning_rate(self, learning_rate: mx.array): | ||||||
|  |         self.state["learning_rate"] = mx.array(learning_rate) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SGD(Optimizer): | class SGD(Optimizer): | ||||||
|     r"""The stochastic gradient descent optimizer. |     r"""The stochastic gradient descent optimizer. | ||||||
| @@ -113,9 +144,11 @@ class SGD(Optimizer): | |||||||
|         self.dampening = dampening |         self.dampening = dampening | ||||||
|         self.nesterov = nesterov |         self.nesterov = nesterov | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the SGD parameter update and stores :math:`v` in the |         """Performs the SGD parameter update and stores :math:`v` in the | ||||||
|         optimizer state.""" |         optimizer state.""" | ||||||
|  |  | ||||||
| @@ -123,24 +156,21 @@ class SGD(Optimizer): | |||||||
|             gradient += self.weight_decay * parameter |             gradient += self.weight_decay * parameter | ||||||
|  |  | ||||||
|         if self.momentum <= 0: |         if self.momentum <= 0: | ||||||
|             return parameter - self.learning_rate * gradient |             return parameter - self.learning_rate.astype(gradient.dtype) * gradient | ||||||
|  |  | ||||||
|  |         v = self.momentum * state.get("v") | ||||||
|         if self.dampening > 0: |         if self.dampening > 0: | ||||||
|             v = ( |  | ||||||
|                 state.get("v", (self.dampening / self.momentum) * gradient) |  | ||||||
|                 * self.momentum |  | ||||||
|             ) |  | ||||||
|             v += (1 - self.dampening) * gradient |             v += (1 - self.dampening) * gradient | ||||||
|         else: |         else: | ||||||
|             v = state.get("v", mx.zeros_like(gradient)) * self.momentum |  | ||||||
|             v += gradient |             v += gradient | ||||||
|  |  | ||||||
|         if self.nesterov: |         if self.nesterov: | ||||||
|             update = gradient + self.momentum * v |             update = gradient + self.momentum * v | ||||||
|         else: |         else: | ||||||
|             update = v |             update = v | ||||||
|  |  | ||||||
|         state["v"] = v |         state["v"] = v | ||||||
|         return parameter - self.learning_rate * update |         return parameter - self.learning_rate.astype(gradient.dtype) * update | ||||||
|  |  | ||||||
|  |  | ||||||
| class RMSprop(Optimizer): | class RMSprop(Optimizer): | ||||||
| @@ -177,15 +207,17 @@ class RMSprop(Optimizer): | |||||||
|                 f"RMSprop epsilon should be >0, {self.eps} was provided instead" |                 f"RMSprop epsilon should be >0, {self.eps} was provided instead" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" |         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         alpha = self.alpha |         alpha = self.alpha | ||||||
|         eps = self.eps |         eps = self.eps | ||||||
|  |  | ||||||
|         v = state.get("v", mx.zeros_like(gradient)) |         v = state["v"] | ||||||
|         v = alpha * v + (1 - alpha) * mx.square(gradient) |         v = alpha * v + (1 - alpha) * mx.square(gradient) | ||||||
|         state["v"] = v |         state["v"] = v | ||||||
|  |  | ||||||
| @@ -222,16 +254,17 @@ class Adagrad(Optimizer): | |||||||
|                 f"Adagrad epsilon should be >0, {self.eps} was provided instead" |                 f"Adagrad epsilon should be >0, {self.eps} was provided instead" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the Adagrad parameter update and stores :math:`v` in the |         """Performs the Adagrad parameter update and stores :math:`v` in the | ||||||
|         optimizer state.""" |         optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         eps = self.eps |         eps = self.eps | ||||||
|  |  | ||||||
|         v = state.get("v", mx.zeros_like(gradient)) |         v = state["v"] + mx.square(gradient) | ||||||
|         v = v + mx.square(gradient) |  | ||||||
|         state["v"] = v |         state["v"] = v | ||||||
|  |  | ||||||
|         return parameter - lr * gradient / (mx.sqrt(v) + eps) |         return parameter - lr * gradient / (mx.sqrt(v) + eps) | ||||||
| @@ -274,17 +307,20 @@ class AdaDelta(Optimizer): | |||||||
|                 f"AdaDelta epsilon should be >0, {self.eps} was provided instead" |                 f"AdaDelta epsilon should be >0, {self.eps} was provided instead" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |         state["u"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the AdaDelta parameter update and stores :math:`v` and |         """Performs the AdaDelta parameter update and stores :math:`v` and | ||||||
|         :math:`u` in the optimizer state.""" |         :math:`u` in the optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         rho = self.rho |         rho = self.rho | ||||||
|         eps = self.eps |         eps = self.eps | ||||||
|  |  | ||||||
|         v = state.get("v", mx.zeros_like(gradient)) |         v = state["v"] | ||||||
|         u = state.get("u", mx.zeros_like(gradient)) |         u = state["u"] | ||||||
|  |  | ||||||
|         v = rho * v + (1 - rho) * mx.square(gradient) |         v = rho * v + (1 - rho) * mx.square(gradient) | ||||||
|         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient |         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient | ||||||
| @@ -329,17 +365,20 @@ class Adam(Optimizer): | |||||||
|         self.betas = betas |         self.betas = betas | ||||||
|         self.eps = eps |         self.eps = eps | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["m"] = mx.zeros_like(parameter) | ||||||
|  |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the Adam parameter update and stores :math:`v` and |         """Performs the Adam parameter update and stores :math:`v` and | ||||||
|         :math:`m` in the optimizer state.""" |         :math:`m` in the optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         b1, b2 = self.betas |         b1, b2 = self.betas | ||||||
|         eps = self.eps |         eps = self.eps | ||||||
|  |  | ||||||
|         m = state.get("m", gradient) |         m = state["m"] | ||||||
|         v = state.get("v", mx.square(gradient)) |         v = state["v"] | ||||||
|         m = b1 * m + (1 - b1) * gradient |         m = b1 * m + (1 - b1) * gradient | ||||||
|         v = b2 * v + (1 - b2) * mx.square(gradient) |         v = b2 * v + (1 - b2) * mx.square(gradient) | ||||||
|         state["m"] = m |         state["m"] = m | ||||||
| @@ -385,15 +424,14 @@ class AdamW(Adam): | |||||||
|         super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) |         super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) | ||||||
|         self.weight_decay = weight_decay |         self.weight_decay = weight_decay | ||||||
|  |  | ||||||
|     def apply_single( |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |  | ||||||
|     ): |  | ||||||
|         """Performs the AdamW parameter update by modifying the parameters |         """Performs the AdamW parameter update by modifying the parameters | ||||||
|         passed into Adam. |         passed into Adam. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         return super().apply_single( |         return super().apply_single( | ||||||
|             gradient, parameter * (1 - self.learning_rate * self.weight_decay), state |             gradient, parameter * (1 - lr * self.weight_decay), state | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -430,17 +468,20 @@ class Adamax(Adam): | |||||||
|                 f"Epsilon value should be >=0, {self.eps} was provided instead" |                 f"Epsilon value should be >=0, {self.eps} was provided instead" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["m"] = mx.zeros_like(parameter) | ||||||
|  |         state["v"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the Adamax parameter update and stores :math:`v` and |         """Performs the Adamax parameter update and stores :math:`v` and | ||||||
|         :math:`m` in the optimizer state.""" |         :math:`m` in the optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         b1, b2 = self.betas |         b1, b2 = self.betas | ||||||
|         eps = self.eps |         eps = self.eps | ||||||
|  |  | ||||||
|         m = state.get("m", mx.zeros_like(gradient)) |         m = state["m"] | ||||||
|         v = state.get("v", mx.zeros_like(gradient)) |         v = state["v"] | ||||||
|  |  | ||||||
|         m = b1 * m + (1 - b1) * gradient |         m = b1 * m + (1 - b1) * gradient | ||||||
|         v = mx.maximum(b2 * v, mx.abs(gradient)) |         v = mx.maximum(b2 * v, mx.abs(gradient)) | ||||||
| @@ -489,16 +530,18 @@ class Lion(Optimizer): | |||||||
|         self.betas = betas |         self.betas = betas | ||||||
|         self.weight_decay = weight_decay |         self.weight_decay = weight_decay | ||||||
|  |  | ||||||
|     def apply_single( |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |         """Initialize optimizer state""" | ||||||
|     ): |         state["m"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         """Performs the Lion parameter update and stores :math:`m` |         """Performs the Lion parameter update and stores :math:`m` | ||||||
|         in the optimizer state.""" |         in the optimizer state.""" | ||||||
|         lr = self.learning_rate |         lr = self.learning_rate.astype(gradient.dtype) | ||||||
|         b1, b2 = self.betas |         b1, b2 = self.betas | ||||||
|         weight_decay = self.weight_decay |         weight_decay = self.weight_decay | ||||||
|  |  | ||||||
|         m = state.get("m", gradient) |         m = state["m"] | ||||||
|         c = b1 * m + (1 - b1) * gradient |         c = b1 * m + (1 - b1) * gradient | ||||||
|         state["m"] = b2 * m + (1 - b2) * gradient |         state["m"] = b2 * m + (1 - b2) * gradient | ||||||
|         if weight_decay > 0: |         if weight_decay > 0: | ||||||
| @@ -552,7 +595,8 @@ class Adafactor(Optimizer): | |||||||
|         warmup_init: bool = False, |         warmup_init: bool = False, | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.learning_rate = learning_rate |         if learning_rate is not None: | ||||||
|  |             self.learning_rate = learning_rate | ||||||
|         self.eps = eps |         self.eps = eps | ||||||
|         self.clip_threshold = clip_threshold |         self.clip_threshold = clip_threshold | ||||||
|         self.decay_rate = decay_rate |         self.decay_rate = decay_rate | ||||||
| @@ -562,14 +606,29 @@ class Adafactor(Optimizer): | |||||||
|         self.relative_step = relative_step |         self.relative_step = relative_step | ||||||
|         self.warmup_init = warmup_init |         self.warmup_init = warmup_init | ||||||
|  |  | ||||||
|  |     def init_single(self, parameter: mx.array, state: dict): | ||||||
|  |         """Initialize optimizer state""" | ||||||
|  |         state["step"] = 0 | ||||||
|  |         if parameter.ndim >= 2: | ||||||
|  |             shape = parameter.shape | ||||||
|  |             dtype = parameter.dtype | ||||||
|  |             state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype) | ||||||
|  |             state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype) | ||||||
|  |         else: | ||||||
|  |             state["exp_avg_sq"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|  |         if self.beta_1 is not None: | ||||||
|  |             state["exp_avg"] = mx.zeros_like(parameter) | ||||||
|  |  | ||||||
|     def _compute_rms(self, inputs): |     def _compute_rms(self, inputs): | ||||||
|         return mx.sqrt(mx.mean(mx.square(inputs))) |         return mx.sqrt(mx.mean(mx.square(inputs))) | ||||||
|  |  | ||||||
|     def _compute_learning_rate(self, step, parameter_rms): |     def _compute_learning_rate(self, step, parameter_rms): | ||||||
|         relative_step_size = self.learning_rate |  | ||||||
|         if self.relative_step: |         if self.relative_step: | ||||||
|             min_step = 1e-6 * step if self.warmup_init else 1e-2 |             min_step = 1e-6 * step if self.warmup_init else 1e-2 | ||||||
|             relative_step_size = min(min_step, 1 / math.sqrt(step)) |             relative_step_size = min(min_step, 1 / math.sqrt(step)) | ||||||
|  |         else: | ||||||
|  |             relative_step_size = self.learning_rate.astype(parameter_rms.dtype) | ||||||
|  |  | ||||||
|         parameter_scale = 1.0 |         parameter_scale = 1.0 | ||||||
|         if self.scale_parameter: |         if self.scale_parameter: | ||||||
| @@ -585,13 +644,11 @@ class Adafactor(Optimizer): | |||||||
|             mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) |             mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def apply_single( |     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState |  | ||||||
|     ): |  | ||||||
|         """Performs the Adafactor parameter and state update.""" |         """Performs the Adafactor parameter and state update.""" | ||||||
|         gradient_shape = gradient.shape |         factored = gradient.ndim >= 2 | ||||||
|         factored = len(gradient_shape) >= 2 |  | ||||||
|         step = state.get("step", 0) + 1 |         step = state["step"] + 1 | ||||||
|         state["step"] = step |         state["step"] = step | ||||||
|         use_first_moment = self.beta_1 is not None |         use_first_moment = self.beta_1 is not None | ||||||
|  |  | ||||||
| @@ -601,15 +658,8 @@ class Adafactor(Optimizer): | |||||||
|         update = mx.square(gradient) + self.eps[0] |         update = mx.square(gradient) + self.eps[0] | ||||||
|  |  | ||||||
|         if factored: |         if factored: | ||||||
|             exp_avg_sq_row = state.get( |             exp_avg_sq_row = state["exp_avg_sq_row"] | ||||||
|                 "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) |             exp_avg_sq_col = state["exp_avg_sq_col"] | ||||||
|             ) |  | ||||||
|             exp_avg_sq_col = state.get( |  | ||||||
|                 "exp_avg_sq_col", |  | ||||||
|                 mx.zeros( |  | ||||||
|                     gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype |  | ||||||
|                 ), |  | ||||||
|             ) |  | ||||||
|             exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( |             exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( | ||||||
|                 (1 - beta_2) * mx.mean(update, axis=-1) |                 (1 - beta_2) * mx.mean(update, axis=-1) | ||||||
|             ) |             ) | ||||||
| @@ -621,7 +671,7 @@ class Adafactor(Optimizer): | |||||||
|             update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) |             update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) | ||||||
|             update = update * gradient |             update = update * gradient | ||||||
|         else: |         else: | ||||||
|             exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) |             exp_avg_sq = state["exp_avg_sq"] | ||||||
|             exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) |             exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) | ||||||
|             state["exp_avg_sq"] = exp_avg_sq |             state["exp_avg_sq"] = exp_avg_sq | ||||||
|             update = mx.rsqrt(exp_avg_sq) * gradient |             update = mx.rsqrt(exp_avg_sq) * gradient | ||||||
| @@ -632,7 +682,7 @@ class Adafactor(Optimizer): | |||||||
|         update = learning_rate * update |         update = learning_rate * update | ||||||
|  |  | ||||||
|         if use_first_moment: |         if use_first_moment: | ||||||
|             exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) |             exp_avg = state["exp_avg"] | ||||||
|             exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) |             exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) | ||||||
|             state["exp_avg"] = exp_avg |             state["exp_avg"] = exp_avg | ||||||
|             update = exp_avg |             update = exp_avg | ||||||
|   | |||||||
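With `OptimizerState` gone, per-parameter state is now allocated explicitly in `init_single` and consumed in `apply_single`. Below is a hedged sketch of what a new optimizer looks like under this API; `SignSGD` is a made-up toy update rule, not something added by this commit.

import mlx.core as mx
from mlx.optimizers import Optimizer


class SignSGD(Optimizer):
    """Toy optimizer: step against the sign of the gradient."""

    def __init__(self, learning_rate: float = 1e-2):
        super().__init__()
        # Stored in self.state["learning_rate"] via the new property setter.
        self.learning_rate = learning_rate

    def init_single(self, parameter: mx.array, state: dict):
        # This toy rule keeps no per-parameter state; heavier optimizers
        # would allocate zeros_like(parameter) entries here.
        pass

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        return parameter - lr * mx.sign(gradient)

Because the state is set up front (lazily on the first `apply_gradients`, or explicitly via `init`), `apply_single` no longer needs `state.get(..., default)` fallbacks, which keeps the update graph identical from step to step and therefore compile-friendly.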
| @@ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| #include <pybind11/pybind11.h> | #include <pybind11/pybind11.h> | ||||||
| #include <pybind11/stl.h> | #include <pybind11/stl.h> | ||||||
|  | #include <chrono> | ||||||
|  |  | ||||||
| #include "python/src/utils.h" | #include "python/src/utils.h" | ||||||
|  |  | ||||||
| @@ -13,13 +14,55 @@ using namespace py::literals; | |||||||
| using namespace mlx::core; | using namespace mlx::core; | ||||||
| using namespace mlx::core::random; | using namespace mlx::core::random; | ||||||
|  |  | ||||||
|  | class PyKeySequence { | ||||||
|  |  public: | ||||||
|  |   explicit PyKeySequence(uint64_t seed) { | ||||||
|  |     state_.append(key(seed)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void seed(uint64_t seed) { | ||||||
|  |     state_[0] = key(seed); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   array next() { | ||||||
|  |     auto out = split(py::cast<array>(state_[0])); | ||||||
|  |     state_[0] = out.first; | ||||||
|  |     return out.second; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   py::list state() { | ||||||
|  |     return state_; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void release() { | ||||||
|  |     py::gil_scoped_acquire gil; | ||||||
|  |     state_.release().dec_ref(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |  private: | ||||||
|  |   py::list state_; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | PyKeySequence& default_key() { | ||||||
|  |   auto get_current_time_seed = []() { | ||||||
|  |     auto now = std::chrono::system_clock::now(); | ||||||
|  |     return std::chrono::duration_cast<std::chrono::milliseconds>( | ||||||
|  |                now.time_since_epoch()) | ||||||
|  |         .count(); | ||||||
|  |   }; | ||||||
|  |   static PyKeySequence ks(get_current_time_seed()); | ||||||
|  |   return ks; | ||||||
|  | } | ||||||
|  |  | ||||||
| void init_random(py::module_& parent_module) { | void init_random(py::module_& parent_module) { | ||||||
|   auto m = parent_module.def_submodule( |   auto m = parent_module.def_submodule( | ||||||
|       "random", |       "random", | ||||||
|       "mlx.core.random: functionality related to random number generation"); |       "mlx.core.random: functionality related to random number generation"); | ||||||
|  |  | ||||||
|  |   m.attr("state") = default_key().state(); | ||||||
|   m.def( |   m.def( | ||||||
|       "seed", |       "seed", | ||||||
|       &seed, |       [](uint64_t seed) { default_key().seed(seed); }, | ||||||
|       "seed"_a, |       "seed"_a, | ||||||
|       R"pbdoc( |       R"pbdoc( | ||||||
|         Seed the global PRNG. |         Seed the global PRNG. | ||||||
| @@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) { | |||||||
|          const ScalarOrArray& high, |          const ScalarOrArray& high, | ||||||
|          const std::vector<int>& shape, |          const std::vector<int>& shape, | ||||||
|          std::optional<Dtype> type, |          std::optional<Dtype> type, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         return uniform( |         return uniform( | ||||||
|             to_array(low), |             to_array(low), | ||||||
|             to_array(high), |             to_array(high), | ||||||
| @@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) { | |||||||
|          std::optional<Dtype> type, |          std::optional<Dtype> type, | ||||||
|          float loc, |          float loc, | ||||||
|          float scale, |          float scale, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         return normal(shape, type.value_or(float32), loc, scale, key, s); |         return normal(shape, type.value_or(float32), loc, scale, key, s); | ||||||
|       }, |       }, | ||||||
|  |  | ||||||
|       "shape"_a = std::vector<int>{}, |       "shape"_a = std::vector<int>{}, | ||||||
|       "dtype"_a = std::optional{float32}, |       "dtype"_a = std::optional{float32}, | ||||||
|       "loc"_a = 0.0, |       "loc"_a = 0.0, | ||||||
| @@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) { | |||||||
|          const ScalarOrArray& high, |          const ScalarOrArray& high, | ||||||
|          const std::vector<int>& shape, |          const std::vector<int>& shape, | ||||||
|          std::optional<Dtype> type, |          std::optional<Dtype> type, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         return randint( |         return randint( | ||||||
|             to_array(low), to_array(high), shape, type.value_or(int32), key, s); |             to_array(low), to_array(high), shape, type.value_or(int32), key, s); | ||||||
|       }, |       }, | ||||||
| @@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) { | |||||||
|       "bernoulli", |       "bernoulli", | ||||||
|       [](const ScalarOrArray& p_, |       [](const ScalarOrArray& p_, | ||||||
|          const std::optional<std::vector<int>> shape, |          const std::optional<std::vector<int>> shape, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         auto p = to_array(p_); |         auto p = to_array(p_); | ||||||
|         if (shape.has_value()) { |         if (shape.has_value()) { | ||||||
|           return bernoulli(p, shape.value(), key, s); |           return bernoulli(p, shape.value(), key, s); | ||||||
| @@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) { | |||||||
|          const ScalarOrArray& upper_, |          const ScalarOrArray& upper_, | ||||||
|          const std::optional<std::vector<int>> shape_, |          const std::optional<std::vector<int>> shape_, | ||||||
|          std::optional<Dtype> type, |          std::optional<Dtype> type, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         auto lower = to_array(lower_); |         auto lower = to_array(lower_); | ||||||
|         auto upper = to_array(upper_); |         auto upper = to_array(upper_); | ||||||
|         auto t = type.value_or(float32); |         auto t = type.value_or(float32); | ||||||
| @@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) { | |||||||
|       "gumbel", |       "gumbel", | ||||||
|       [](const std::vector<int>& shape, |       [](const std::vector<int>& shape, | ||||||
|          std::optional<Dtype> type, |          std::optional<Dtype> type, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         return gumbel(shape, type.value_or(float32), key, s); |         return gumbel(shape, type.value_or(float32), key, s); | ||||||
|       }, |       }, | ||||||
|       "shape"_a = std::vector<int>{}, |       "shape"_a = std::vector<int>{}, | ||||||
| @@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) { | |||||||
|          int axis, |          int axis, | ||||||
|          const std::optional<std::vector<int>> shape, |          const std::optional<std::vector<int>> shape, | ||||||
|          const std::optional<int> num_samples, |          const std::optional<int> num_samples, | ||||||
|          const std::optional<array>& key, |          const std::optional<array>& key_, | ||||||
|          StreamOrDevice s) { |          StreamOrDevice s) { | ||||||
|  |         auto key = key_ ? key_.value() : default_key().next(); | ||||||
|         if (shape.has_value() && num_samples.has_value()) { |         if (shape.has_value() && num_samples.has_value()) { | ||||||
|           throw std::invalid_argument( |           throw std::invalid_argument( | ||||||
|               "[categorical] At most one of shape or num_samples can be specified."); |               "[categorical] At most one of shape or num_samples can be specified."); | ||||||
| @@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) { | |||||||
|         Returns: |         Returns: | ||||||
|             array: The ``shape``-sized output array with type ``uint32``. |             array: The ``shape``-sized output array with type ``uint32``. | ||||||
|       )pbdoc"); |       )pbdoc"); | ||||||
|  |   // Register static Python object cleanup before the interpreter exits | ||||||
|  |   auto atexit = py::module_::import("atexit"); | ||||||
|  |   atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); | ||||||
| } | } | ||||||
|   | |||||||
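On the Python side these bindings expose the key sequence as `mx.random.state` and make every sampling function split the global key when no explicit `key` is given. A brief sketch of the resulting behavior, not taken from the diff:

import mlx.core as mx

mx.random.seed(0)
a = mx.random.uniform(shape=(2,))   # consumes a split of the global key
mx.random.seed(0)
b = mx.random.uniform(shape=(2,))   # same seed, so the same draw
assert mx.array_equal(a, b)

# Explicit keys still bypass the global state entirely.
key = mx.random.key(42)
c = mx.random.uniform(shape=(2,), key=key)

Exposing the state as an ordinary list of arrays is what lets it be captured by `mx.compile` (see `test_compile_rng` below), so compiled functions can draw fresh random numbers on every call.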
| @@ -135,6 +135,64 @@ py::object tree_map( | |||||||
|   }); |   }); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void tree_visit_update( | ||||||
|  |     py::object tree, | ||||||
|  |     std::function<py::object(py::handle)> visitor) { | ||||||
|  |   std::function<py::object(py::handle)> recurse; | ||||||
|  |   recurse = [&](py::handle subtree) { | ||||||
|  |     if (py::isinstance<py::list>(subtree)) { | ||||||
|  |       auto l = py::cast<py::list>(subtree); | ||||||
|  |       for (int i = 0; i < l.size(); ++i) { | ||||||
|  |         l[i] = recurse(l[i]); | ||||||
|  |       } | ||||||
|  |       return py::cast<py::object>(l); | ||||||
|  |     } else if (py::isinstance<py::tuple>(subtree)) { | ||||||
|  |       for (auto item : subtree) { | ||||||
|  |         recurse(item); | ||||||
|  |       } | ||||||
|  |       return py::cast<py::object>(subtree); | ||||||
|  |     } else if (py::isinstance<py::dict>(subtree)) { | ||||||
|  |       auto d = py::cast<py::dict>(subtree); | ||||||
|  |       for (auto item : d) { | ||||||
|  |         d[item.first] = recurse(item.second); | ||||||
|  |       } | ||||||
|  |       return py::cast<py::object>(d); | ||||||
|  |     } else if (py::isinstance<array>(subtree)) { | ||||||
|  |       return visitor(subtree); | ||||||
|  |     } else { | ||||||
|  |       return py::cast<py::object>(subtree); | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |   recurse(tree); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Fill a pytree (recursive dict or list of dict or list) | ||||||
|  | // in place with the given arrays | ||||||
|  | // Non dict or list nodes are ignored | ||||||
|  | void tree_fill(py::object& tree, const std::vector<array>& values) { | ||||||
|  |   size_t index = 0; | ||||||
|  |   tree_visit_update( | ||||||
|  |       tree, [&](py::handle node) { return py::cast(values[index++]); }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Replace all the arrays from the src values with the dst values in the tree | ||||||
|  | void tree_replace( | ||||||
|  |     py::object& tree, | ||||||
|  |     const std::vector<array>& src, | ||||||
|  |     const std::vector<array>& dst) { | ||||||
|  |   std::unordered_map<uintptr_t, array> src_to_dst; | ||||||
|  |   for (int i = 0; i < src.size(); ++i) { | ||||||
|  |     src_to_dst.insert({src[i].id(), dst[i]}); | ||||||
|  |   } | ||||||
|  |   tree_visit_update(tree, [&](py::handle node) { | ||||||
|  |     auto arr = py::cast<array>(node); | ||||||
|  |     if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { | ||||||
|  |       return py::cast(it->second); | ||||||
|  |     } | ||||||
|  |     return py::cast(arr); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
| std::vector<array> tree_flatten(py::object tree, bool strict = true) { | std::vector<array> tree_flatten(py::object tree, bool strict = true) { | ||||||
|   std::vector<array> flat_tree; |   std::vector<array> flat_tree; | ||||||
|  |  | ||||||
| @@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() { | |||||||
| struct PyCompiledFun { | struct PyCompiledFun { | ||||||
|   py::function fun; |   py::function fun; | ||||||
|   size_t fun_id; |   size_t fun_id; | ||||||
|  |   py::object captured_inputs; | ||||||
|  |   py::object captured_outputs; | ||||||
|  |   size_t num_outputs{0}; | ||||||
|  |  | ||||||
|   PyCompiledFun(const py::function& fun) |   PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) | ||||||
|       : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {} |       : fun(fun), | ||||||
|  |         fun_id(reinterpret_cast<size_t>(fun.ptr())), | ||||||
|  |         captured_inputs(inputs), | ||||||
|  |         captured_outputs(outputs) {} | ||||||
|  |  | ||||||
|   PyCompiledFun(const PyCompiledFun&) = delete; |   PyCompiledFun(const PyCompiledFun&) = delete; | ||||||
|   PyCompiledFun& operator=(const PyCompiledFun&) = delete; |   PyCompiledFun& operator=(const PyCompiledFun&) = delete; | ||||||
| @@ -505,23 +569,61 @@ struct PyCompiledFun { | |||||||
|   PyCompiledFun(PyCompiledFun&& other) |   PyCompiledFun(PyCompiledFun&& other) | ||||||
|       : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) { |       : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) { | ||||||
|     other.fun_id = 0; |     other.fun_id = 0; | ||||||
|  |     captured_inputs = std::move(other.captured_inputs); | ||||||
|  |     captured_outputs = std::move(other.captured_outputs); | ||||||
|  |     num_outputs = other.num_outputs; | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   py::object operator()(const py::args& args) { |   py::object operator()(const py::args& args) { | ||||||
|     auto compile_fun = [this, &args](const std::vector<array>& a) { |     auto compile_fun = [this, &args](const std::vector<array>& a) { | ||||||
|       // Call the python function and flatten the outputs |       // Put tracers into captured inputs | ||||||
|       auto [outputs, py_outputs] = tree_flatten_with_structure( |       std::vector<array> flat_in_captures; | ||||||
|           std::move(this->fun(*tree_unflatten(args, a))), true); |       std::vector<array> trace_captures; | ||||||
|  |       if (!py::isinstance<py::none>(captured_inputs)) { | ||||||
|  |         flat_in_captures = tree_flatten(captured_inputs, false); | ||||||
|  |         trace_captures.insert( | ||||||
|  |             trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); | ||||||
|  |         tree_fill(captured_inputs, trace_captures); | ||||||
|  |       } | ||||||
|  |  | ||||||
|       tree_cache().insert({this->fun_id, py_outputs}); |       auto [outputs, py_outputs] = tree_flatten_with_structure( | ||||||
|  |           std::move(fun(*tree_unflatten(args, a))), false); | ||||||
|  |  | ||||||
|  |       tree_cache().insert({fun_id, py_outputs}); | ||||||
|  |  | ||||||
|  |       num_outputs = outputs.size(); | ||||||
|  |       if (!py::isinstance<py::none>(captured_outputs)) { | ||||||
|  |         auto flat_out_captures = tree_flatten(captured_outputs, false); | ||||||
|  |         outputs.insert( | ||||||
|  |             outputs.end(), | ||||||
|  |             std::make_move_iterator(flat_out_captures.begin()), | ||||||
|  |             std::make_move_iterator(flat_out_captures.end())); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Replace tracers with originals in captured inputs | ||||||
|  |       if (!py::isinstance<py::none>(captured_inputs)) { | ||||||
|  |         tree_replace(captured_inputs, trace_captures, flat_in_captures); | ||||||
|  |       } | ||||||
|       return outputs; |       return outputs; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     // Inputs must be array or tree of arrays |     auto inputs = tree_flatten(args, false); | ||||||
|     auto inputs = tree_flatten(args, true); |     if (!py::isinstance<py::none>(captured_inputs)) { | ||||||
|  |       auto flat_in_captures = tree_flatten(captured_inputs, false); | ||||||
|  |       inputs.insert( | ||||||
|  |           inputs.end(), | ||||||
|  |           std::make_move_iterator(flat_in_captures.begin()), | ||||||
|  |           std::make_move_iterator(flat_in_captures.end())); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // Compile and call |     // Compile and call | ||||||
|     auto outputs = detail::compile(compile_fun, fun_id)(inputs); |     auto outputs = detail::compile(compile_fun, fun_id)(inputs); | ||||||
|  |     if (!py::isinstance<py::none>(captured_outputs)) { | ||||||
|  |       std::vector<array> captures( | ||||||
|  |           std::make_move_iterator(outputs.begin() + num_outputs), | ||||||
|  |           std::make_move_iterator(outputs.end())); | ||||||
|  |       tree_fill(captured_outputs, captures); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // Put the outputs back in the container |     // Put the outputs back in the container | ||||||
|     py::object py_outputs = tree_cache().at(fun_id); |     py::object py_outputs = tree_cache().at(fun_id); | ||||||
| @@ -534,6 +636,8 @@ struct PyCompiledFun { | |||||||
|     tree_cache().erase(fun_id); |     tree_cache().erase(fun_id); | ||||||
|     detail::compile_erase(fun_id); |     detail::compile_erase(fun_id); | ||||||
|     fun.release().dec_ref(); |     fun.release().dec_ref(); | ||||||
|  |     captured_inputs.release().dec_ref(); | ||||||
|  |     captured_outputs.release().dec_ref(); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| @@ -601,7 +705,7 @@ void init_transforms(py::module_& m) { | |||||||
|   m.def( |   m.def( | ||||||
|       "eval", |       "eval", | ||||||
|       [](const py::args& args) { |       [](const py::args& args) { | ||||||
|         std::vector<array> arrays = tree_flatten(args); |         std::vector<array> arrays = tree_flatten(args, false); | ||||||
|         { |         { | ||||||
|           py::gil_scoped_release nogil; |           py::gil_scoped_release nogil; | ||||||
|           eval(arrays); |           eval(arrays); | ||||||
| @@ -615,8 +719,8 @@ void init_transforms(py::module_& m) { | |||||||
|         Args: |         Args: | ||||||
|             *args (arrays or trees of arrays): Each argument can be a single array |             *args (arrays or trees of arrays): Each argument can be a single array | ||||||
|               or a tree of arrays. If a tree is given the nodes can be a Python |               or a tree of arrays. If a tree is given the nodes can be a Python | ||||||
|               :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be |               :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not | ||||||
|               an :class:`array`. |               arrays are ignored. | ||||||
|       )pbdoc"); |       )pbdoc"); | ||||||
|   m.def( |   m.def( | ||||||
|       "jvp", |       "jvp", | ||||||
| @@ -859,10 +963,14 @@ void init_transforms(py::module_& m) { | |||||||
|       "file"_a); |       "file"_a); | ||||||
|   m.def( |   m.def( | ||||||
|       "compile", |       "compile", | ||||||
|       [](const py::function& fun) { |       [](const py::function& fun, | ||||||
|         return py::cpp_function(PyCompiledFun{fun}); |          const py::object& inputs, | ||||||
|  |          const py::object& outputs) { | ||||||
|  |         return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); | ||||||
|       }, |       }, | ||||||
|       "fun"_a, |       "fun"_a, | ||||||
|  |       "inputs"_a = std::nullopt, | ||||||
|  |       "outputs"_a = std::nullopt, | ||||||
|       R"pbdoc( |       R"pbdoc( | ||||||
|         compile(fun: function) -> function |         compile(fun: function) -> function | ||||||
|  |  | ||||||
| @@ -872,6 +980,16 @@ void init_transforms(py::module_& m) { | |||||||
|             fun (function): A function which takes a variable number of |             fun (function): A function which takes a variable number of | ||||||
|               :class:`array` or trees of :class:`array` and returns |               :class:`array` or trees of :class:`array` and returns | ||||||
|               a variable number of :class:`array` or trees of :class:`array`. |               a variable number of :class:`array` or trees of :class:`array`. | ||||||
|  |             inputs (list or dict, optional): These inputs will be captured during | ||||||
|  |               the function compilation along with the inputs to ``fun``. The ``inputs`` | ||||||
|  |               can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested | ||||||
|  |               lists, dictionaries, or arrays. Leaf nodes that are not | ||||||
|  |               :obj:`array` are ignored. Default: ``None`` | ||||||
|  |             outputs (list or dict, optional): These outputs will be captured and | ||||||
|  |               updated in a compiled function. The ``outputs`` can be a | ||||||
|  |               :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, | ||||||
|  |               dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. | ||||||
|  |               Default: ``None`` | ||||||
|  |  | ||||||
|         Returns: |         Returns: | ||||||
|             function: A compiled function which has the same input arguments |             function: A compiled function which has the same input arguments | ||||||
|   | |||||||
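A short sketch, separate from the diff, of what the `inputs` capture buys at the Python level: without it, an array closed over by `fun` is baked into the trace as a constant, whereas declaring it as a captured input makes the compiled function read its current value on every call. The tests below exercise the same behavior more thoroughly.

from functools import partial

import mlx.core as mx

scale = {"s": mx.array(2.0)}

@partial(mx.compile, inputs=scale)
def times_scale(x):
    return x * scale["s"]

print(times_scale(mx.array(3.0)))  # 6.0
scale["s"] = mx.array(10.0)
print(times_scale(mx.array(3.0)))  # 30.0, the new value is picked up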
| @@ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| import io | import io | ||||||
| import unittest | import unittest | ||||||
|  | from functools import partial | ||||||
|  |  | ||||||
| import mlx.core as mx | import mlx.core as mx | ||||||
| import mlx_tests | import mlx_tests | ||||||
| @@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase): | |||||||
|         cdfdx = mx.grad(outer)(x) |         cdfdx = mx.grad(outer)(x) | ||||||
|         self.assertTrue(mx.allclose(dfdx, cdfdx)) |         self.assertTrue(mx.allclose(dfdx, cdfdx)) | ||||||
|  |  | ||||||
|  |     def test_compile_capture(self): | ||||||
|  |         # Test update captured state outside compiled function | ||||||
|  |         state = {"y": mx.array(2)} | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=state) | ||||||
|  |         def test_state(x): | ||||||
|  |             x = x + state["y"] | ||||||
|  |             return x | ||||||
|  |  | ||||||
|  |         test_state(mx.array(1)) | ||||||
|  |         # Check the state is unchanged | ||||||
|  |         self.assertEqual(state["y"], 2) | ||||||
|  |  | ||||||
|  |         # Check the updated state is used | ||||||
|  |         state["y"] = mx.array(3) | ||||||
|  |         out = test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(out.item(), 4) | ||||||
|  |  | ||||||
|  |         # Capture list | ||||||
|  |         state = [mx.array(2)] | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=state) | ||||||
|  |         def test_state(x): | ||||||
|  |             x = x + state[0] | ||||||
|  |             return x | ||||||
|  |  | ||||||
|  |         out = test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(out.item(), 3) | ||||||
|  |         state[0] = mx.array(3) | ||||||
|  |         out = test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(out.item(), 4) | ||||||
|  |  | ||||||
|  |         # Capture tuple of list | ||||||
|  |         state = ([mx.array(2)],) | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=state) | ||||||
|  |         def test_state(x): | ||||||
|  |             x = x + state[0][0] | ||||||
|  |             return x | ||||||
|  |  | ||||||
|  |         out = test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(out.item(), 3) | ||||||
|  |         state[0][0] = mx.array(3) | ||||||
|  |         out = test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(out.item(), 4) | ||||||
|  |  | ||||||
|  |         # Test state updated inside compiled function | ||||||
|  |         state = {} | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, outputs=state) | ||||||
|  |         def test_state(x): | ||||||
|  |             state["y"] = x + 3 | ||||||
|  |             return mx.abs(x) | ||||||
|  |  | ||||||
|  |         test_state(mx.array(-1)) | ||||||
|  |         self.assertEqual(state["y"].item(), 2) | ||||||
|  |  | ||||||
|  |         # Test state changed inside compiled function | ||||||
|  |         # triggers recompile | ||||||
|  |         state = {} | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=state, outputs=state) | ||||||
|  |         def test_state(x): | ||||||
|  |             y = state.get("y", mx.array(0)) | ||||||
|  |             state["y"] = x + y | ||||||
|  |             return x + 2 * y | ||||||
|  |  | ||||||
|  |         test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(state["y"].item(), 1) | ||||||
|  |         test_state(mx.array(1)) | ||||||
|  |         self.assertEqual(state["y"].item(), 2) | ||||||
|  |  | ||||||
|  |     def test_compile_rng(self): | ||||||
|  |         @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) | ||||||
|  |         def fun(): | ||||||
|  |             return mx.random.uniform(shape=(10, 10)) | ||||||
|  |  | ||||||
|  |         self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
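
The RNG test above applies the same capture mechanism to the global mx.random.state; without capturing it, a compiled function would reuse the random key recorded at trace time and return the same values on every call. A minimal sketch, where sample is an illustrative name:

    import mlx.core as mx
    from functools import partial

    @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
    def sample():
        # Capturing the RNG state lets each call advance it and draw
        # fresh values instead of repeating the traced ones.
        return mx.random.uniform(shape=(4,))

    mx.random.seed(0)
    a = sample()
    b = sample()  # differs from `a` because the captured state advanced
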
| @@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase): | |||||||
|         y = dfun_dx(mx.array(1.0)) |         y = dfun_dx(mx.array(1.0)) | ||||||
|         self.assertEqual(y.item(), 6.0) |         self.assertEqual(y.item(), 6.0) | ||||||
|  |  | ||||||
|  |     def test_eval_mixed(self): | ||||||
|  |         x = mx.array(1) + 1 + 1 | ||||||
|  |         y = 0 | ||||||
|  |         z = "hello" | ||||||
|  |         state = [x, y, z] | ||||||
|  |         mx.eval(state) | ||||||
|  |         self.assertEqual(x.item(), 3) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase): | |||||||
|                 ] |                 ] | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def test_module_state(self): | ||||||
|  |         m = nn.Linear(10, 1) | ||||||
|  |         m.state["hello"] = "world" | ||||||
|  |         self.assertEqual(m.state["hello"], "world") | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestLayers(mlx_tests.MLXTestCase): | class TestLayers(mlx_tests.MLXTestCase): | ||||||
|     def test_identity(self): |     def test_identity(self): | ||||||
|   | |||||||
| @@ -2,47 +2,209 @@ | |||||||
|  |  | ||||||
| import inspect | import inspect | ||||||
| import unittest | import unittest | ||||||
|  | from functools import partial | ||||||
|  |  | ||||||
| import mlx.core as mx | import mlx.core as mx | ||||||
|  | import mlx.nn as nn | ||||||
| import mlx.optimizers as opt | import mlx.optimizers as opt | ||||||
| import mlx.utils | import mlx.utils | ||||||
| import mlx_tests | import mlx_tests | ||||||
|  | from mlx.utils import tree_flatten, tree_map | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_all_optimizers(): | def get_all_optimizers(): | ||||||
|     classes = dict() |     classes = dict() | ||||||
|     for name, obj in inspect.getmembers(opt): |     for name, obj in inspect.getmembers(opt): | ||||||
|         if inspect.isclass(obj): |         if inspect.isclass(obj): | ||||||
|             if obj.__name__ not in ["OptimizerState", "Optimizer"]: |             if obj.__name__ not in ["Optimizer"]: | ||||||
|                 classes[name] = obj |                 classes[name] = obj | ||||||
|     return classes |     return classes | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def tree_equal(fn, *args): | ||||||
|  |     return all(v for _, v in tree_flatten(tree_map(fn, *args))) | ||||||
|  |  | ||||||
|  |  | ||||||
| optimizers_dict = get_all_optimizers() | optimizers_dict = get_all_optimizers() | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestOptimizers(mlx_tests.MLXTestCase): | class TestOptimizers(mlx_tests.MLXTestCase): | ||||||
|  |     def test_optimizer_state(self): | ||||||
|  |         optim = opt.SGD(0.1) | ||||||
|  |         optim.state["hello"] = "world" | ||||||
|  |         self.assertEqual(optim.state["hello"], "world") | ||||||
|  |  | ||||||
|  |         optim.state = {0: 1} | ||||||
|  |         self.assertEqual(optim.state, {0: 1}) | ||||||
|  |  | ||||||
|     def test_optimizers(self): |     def test_optimizers(self): | ||||||
|         params = { |         params = { | ||||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|             "second": mx.zeros((1,)), |             "second": mx.zeros((1,)), | ||||||
|         } |         } | ||||||
|         grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|         for optim_class in optimizers_dict.values(): |         for optim_class in optimizers_dict.values(): | ||||||
|             optim = optim_class(0.1) |             optim = optim_class(0.1) | ||||||
|             update = optim.apply_gradients(grads, params) |             update = optim.apply_gradients(grads, params) | ||||||
|             mx.eval(update) |             mx.eval(update) | ||||||
|             equal_shape = mlx.utils.tree_map( |             equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) | ||||||
|                 lambda x, y: x.shape == y.shape, params, update |  | ||||||
|             ) |  | ||||||
|             all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) |             all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) | ||||||
|             self.assertTrue(all_equal) |             self.assertTrue(all_equal) | ||||||
|  |  | ||||||
|  |     def test_types_conserved(self): | ||||||
|  |         params = {"w": mx.ones((5, 5), mx.float16)} | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |         for optim_class in optimizers_dict.values(): | ||||||
|  |             optim = optim_class(0.1) | ||||||
|  |             update = optim.apply_gradients(grads, params) | ||||||
|  |             self.assertEqual(update["w"].dtype, mx.float16) | ||||||
|  |  | ||||||
|  |     def test_sgd(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||||
|  |         optim.init(params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # Implicit init | ||||||
|  |         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||||
|  |         optim.apply_gradients(grads, params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_rmsprop(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         optim = opt.RMSprop(learning_rate=1e-2) | ||||||
|  |         optim.init(params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # Implicit init | ||||||
|  |         alpha = 0.99 | ||||||
|  |         optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) | ||||||
|  |         optim.apply_gradients(grads, params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_adagrad(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         optim = opt.Adagrad(learning_rate=1e-2) | ||||||
|  |         optim.init(params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_adadelta(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         optim = opt.AdaDelta(learning_rate=1e-2) | ||||||
|  |         optim.init(params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_adam(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: | ||||||
|  |             optim = optimizer(learning_rate=1e-2) | ||||||
|  |             optim.init(params) | ||||||
|  |             self.assertTrue( | ||||||
|  |                 tree_equal( | ||||||
|  |                     lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||||
|  |                     params, | ||||||
|  |                     optim.state, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |             self.assertTrue( | ||||||
|  |                 tree_equal( | ||||||
|  |                     lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||||
|  |                     params, | ||||||
|  |                     optim.state, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def test_lion(self): | ||||||
|  |         params = { | ||||||
|  |             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||||
|  |             "second": mx.zeros((1,)), | ||||||
|  |         } | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |  | ||||||
|  |         # Explicit init | ||||||
|  |         optim = opt.Lion(learning_rate=1e-2) | ||||||
|  |         optim.init(params) | ||||||
|  |         self.assertTrue( | ||||||
|  |             tree_equal( | ||||||
|  |                 lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||||
|  |                 params, | ||||||
|  |                 optim.state, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def test_adafactor(self): |     def test_adafactor(self): | ||||||
|         x = mx.zeros((5, 5)) |         x = mx.zeros((5, 5)) | ||||||
|         grad = mx.ones_like(x) |         grad = mx.ones_like(x) | ||||||
|         optimizer = opt.Adafactor() |         optimizer = opt.Adafactor() | ||||||
|  |         optimizer.init(x) | ||||||
|         for _ in range(2): |         for _ in range(2): | ||||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) |             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||||
|             self.assertEqual(xp.dtype, x.dtype) |             self.assertEqual(xp.dtype, x.dtype) | ||||||
| @@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase): | |||||||
|         x = mx.zeros((5, 5), mx.float16) |         x = mx.zeros((5, 5), mx.float16) | ||||||
|         grad = mx.ones_like(x) |         grad = mx.ones_like(x) | ||||||
|         optimizer = opt.Adafactor() |         optimizer = opt.Adafactor() | ||||||
|  |         optimizer.init(x) | ||||||
|         for _ in range(2): |         for _ in range(2): | ||||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) |             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||||
|             self.assertEqual(xp.dtype, x.dtype) |             self.assertEqual(xp.dtype, x.dtype) | ||||||
|             self.assertEqual(xp.shape, x.shape) |             self.assertEqual(xp.shape, x.shape) | ||||||
|         self.assertEqual(optimizer.state["step"], 2) |         self.assertEqual(optimizer.state["step"], 2) | ||||||
|  |  | ||||||
|  |     def test_compiled_optimizer(self): | ||||||
|  |         model = nn.Linear(10, 10) | ||||||
|  |         x = mx.random.uniform(shape=(2, 10)) | ||||||
|  |         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||||
|  |  | ||||||
|  |         orig_params = model.parameters() | ||||||
|  |  | ||||||
|  |         def loss(model, x): | ||||||
|  |             return model(x).sum() | ||||||
|  |  | ||||||
|  |         # Uncompiled version | ||||||
|  |         def step(x): | ||||||
|  |             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||||
|  |             optim.update(model, grad) | ||||||
|  |  | ||||||
|  |         step(x) | ||||||
|  |         uncompiled_params = model.parameters() | ||||||
|  |  | ||||||
|  |         # Pure version | ||||||
|  |         def loss(params, x): | ||||||
|  |             model.update(params) | ||||||
|  |             return model(x).sum() | ||||||
|  |  | ||||||
|  |         model.update(orig_params) | ||||||
|  |         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||||
|  |  | ||||||
|  |         @mx.compile | ||||||
|  |         def step(params, opt_state, x): | ||||||
|  |             grad = mx.grad(loss)(params, x) | ||||||
|  |             optim.state = opt_state | ||||||
|  |             params = optim.apply_gradients(grad, params) | ||||||
|  |             return params, optim.state | ||||||
|  |  | ||||||
|  |         optim.init(model.parameters()) | ||||||
|  |         pure_params, _ = step(model.parameters(), optim.state, x) | ||||||
|  |         self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) | ||||||
|  |         self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) | ||||||
|  |  | ||||||
|  |         # Impure version | ||||||
|  |         def loss(model, x): | ||||||
|  |             return model(x).sum() | ||||||
|  |  | ||||||
|  |         model.update(orig_params) | ||||||
|  |         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||||
|  |         state = [model.state, optim.state] | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=state, outputs=state) | ||||||
|  |         def step(x): | ||||||
|  |             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||||
|  |             optim.update(model, grad) | ||||||
|  |  | ||||||
|  |         step(x) | ||||||
|  |         impure_params = model.parameters() | ||||||
|  |         self.assertTrue( | ||||||
|  |             mx.allclose(impure_params["weight"], uncompiled_params["weight"]) | ||||||
|  |         ) | ||||||
|  |         self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) | ||||||
|  |  | ||||||
|  |     def test_update_lr_compiled(self): | ||||||
|  |         params = {"w": mx.ones((5, 5))} | ||||||
|  |         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||||
|  |         optim = opt.SGD(-1.0) | ||||||
|  |  | ||||||
|  |         @partial(mx.compile, inputs=optim.state) | ||||||
|  |         def update(grads): | ||||||
|  |             return optim.apply_gradients(grads, params) | ||||||
|  |  | ||||||
|  |         result = update(grads) | ||||||
|  |         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) | ||||||
|  |         optim.learning_rate = -2.0 | ||||||
|  |         result = update(grads) | ||||||
|  |         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
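
Taken together, the new optimizer tests exercise the training-step pattern this change enables: capture the module and optimizer state so the in-place updates made by optim.update persist across compiled calls. A minimal sketch mirroring the "impure" test above, with assumed shapes and illustrative names (model, loss, step):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt
    from functools import partial

    model = nn.Linear(10, 1)
    optim = opt.SGD(learning_rate=1e-2, momentum=0.9)

    def loss(model, x):
        return model(x).sum()

    # Capture model parameters and optimizer state so their in-place
    # updates survive across compiled calls.
    state = [model.state, optim.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x):
        _, grads = nn.value_and_grad(model, loss)(model, x)
        optim.update(model, grads)

    x = mx.random.uniform(shape=(2, 10))
    for _ in range(3):
        step(x)
        mx.eval(state)  # materialize the updated parameters and optimizer state
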