mirror of https://github.com/ml-explore/mlx.git
	Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
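The headline change is that mx.compile can now capture implicit state: arrays reached through closures may be declared as extra inputs and/or outputs of the compiled function. A minimal sketch, mirroring the tests added below (functools.partial is only used to pass the extra arguments to mx.compile):

    from functools import partial

    import mlx.core as mx

    state = {"y": mx.array(2)}

    # Declare `state` as both a captured input and a captured output so reads
    # and writes to it inside the compiled function are tracked across calls.
    @partial(mx.compile, inputs=state, outputs=state)
    def step(x):
        state["y"] = state["y"] + x
        return 2 * x

    out = step(mx.array(1))  # out.item() == 2, state["y"] becomes 3
    out = step(mx.array(1))  # state["y"] becomes 4; no recompilation needed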
		| @@ -66,6 +66,19 @@ class Module(dict): | ||||
|         """Boolean indicating if the model is in training mode.""" | ||||
|         return self._training | ||||
|  | ||||
|     @property | ||||
|     def state(self): | ||||
|         """The module's state dictionary | ||||
|  | ||||
|         The module's state dictionary contains any attribute set on the | ||||
|         module including parameters in :meth:`Module.parameters` | ||||
|  | ||||
|         Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is | ||||
|         a reference to the module's state. Updates to it will be reflected in | ||||
|         the original module. | ||||
|         """ | ||||
|         return self | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return "" | ||||
|  | ||||
|   | ||||
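The new Module.state property above is a live view of the module (it returns the module itself), so entries written through it, including non-parameter attributes, are visible on the original module. A small sketch based on the added test:

    import mlx.nn as nn

    m = nn.Linear(10, 1)
    m.state["hello"] = "world"           # writes through to the module itself
    assert m.state["hello"] == "world"
    assert "weight" in m.state           # parameters are part of the state as well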
| @@ -7,39 +7,14 @@ import mlx.core as mx | ||||
| from mlx.utils import tree_map | ||||
|  | ||||
|  | ||||
| class OptimizerState(dict): | ||||
|     """The optimizer state implements a recursively defined | ||||
|     :class:`collections.defaultdict`, namely a missing key in an optimizer | ||||
|     state is an :class:`OptimizerState`. | ||||
|  | ||||
|     .. note:: | ||||
|        :meth:`OptimizerState.get` in contrast to a normal dictionary also sets | ||||
|        the key to the ``default`` value if the ``key`` was not present in the | ||||
|        dictionary. | ||||
|     """ | ||||
|  | ||||
|     def __getitem__(self, key): | ||||
|         if key not in self: | ||||
|             self[key] = OptimizerState() | ||||
|         return super().__getitem__(key) | ||||
|  | ||||
|     def get(self, key, default): | ||||
|         """If ``key`` doesn't exist set its value to ``default`` and then return it.""" | ||||
|         if key not in self: | ||||
|             self[key] = default | ||||
|         return super().__getitem__(key) | ||||
|  | ||||
|  | ||||
| class Optimizer: | ||||
|     """The base class for all optimizers. It allows us to implement an | ||||
|     optimizer on a per-parameter basis and apply it to a parameter tree. | ||||
|  | ||||
|     Attributes: | ||||
|         state (OptimizerState): It holds the optimizer's state dictionary. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.state = OptimizerState() | ||||
|         self._initialized = False | ||||
|         self._state = {} | ||||
|  | ||||
|     def update(self, model: "mlx.nn.Module", gradients: dict): | ||||
|         """Apply the gradients to the parameters of the model and update the | ||||
| @@ -52,7 +27,41 @@ class Optimizer: | ||||
|         """ | ||||
|         model.update(self.apply_gradients(gradients, model)) | ||||
|  | ||||
|     def apply_gradients(self, gradients: dict, model: dict): | ||||
|     def init(self, parameters: dict): | ||||
|         """Initialize the optimizer's state | ||||
|  | ||||
|         This function can be used to initialize optimizers which have state | ||||
|         (like momentum in :class:`SGD`). Using this method is optional as the | ||||
|         optimizer will initialize itself if the state is not yet set. However, | ||||
|         there are some cases where explicit initialization is useful in order | ||||
|         to have access to the :attr:`Optimizer.state` before the first call to | ||||
|         :meth:`Optimizer.update`. | ||||
|  | ||||
|         Args: | ||||
|             parameters (dict): A Python tree of parameters. | ||||
|  | ||||
|         Example: | ||||
|             >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9) | ||||
|             >>> model = nn.Linear(2, 2) | ||||
|             >>> optimizer.init(model.trainable_parameters()) | ||||
|             >>> optimizer.state | ||||
|             {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0], | ||||
|                    [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}} | ||||
|         """ | ||||
|         self._state.update(tree_map(lambda x: {}, parameters)) | ||||
|         tree_map(self.init_single, parameters, self._state) | ||||
|         self._initialized = True | ||||
|  | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """To be extended by the children classes to implement each optimizer's | ||||
|         state initialization. | ||||
|  | ||||
|         Args: | ||||
|             parameter (mx.array): A single parameter that will be optimized. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def apply_gradients(self, gradients: dict, parameters: dict): | ||||
|         """Apply the gradients to the parameters and return the updated parameters. | ||||
|  | ||||
|         Can be used to update a model via | ||||
| @@ -61,19 +70,41 @@ class Optimizer: | ||||
|  | ||||
|         Args: | ||||
|             gradients (dict): A Python tree of gradients. | ||||
|             model (dict): A Python tree of parameters. It can be a superset of | ||||
|                           the gradients. In that case the returned python tree | ||||
|                           will be of the same structure as the gradients. | ||||
|             parameters (dict): A Python tree of parameters. It can be a | ||||
|               superset of the gradients. In that case the returned python | ||||
|               tree will be of the same structure as the gradients. | ||||
|         """ | ||||
|         return tree_map(self.apply_single, gradients, model, self.state) | ||||
|         if not self._initialized: | ||||
|             self.init(gradients) | ||||
|         return tree_map(self.apply_single, gradients, parameters, self.state) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|         """To be extended by the children classes to implement each optimizer's | ||||
|         update.""" | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """To be extended by derived classes to implement the optimizer's update. | ||||
|  | ||||
|         Args: | ||||
|             gradient (mx.array): The ``parameter`` gradient. | ||||
|             parameter (mx.array): The ``parameter`` to update. | ||||
|             state (dict): The optimizer's state. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @property | ||||
|     def state(self): | ||||
|         """The optimizer's state dictionary.""" | ||||
|         return self._state | ||||
|  | ||||
|     @state.setter | ||||
|     def state(self, state: dict): | ||||
|         self._state = state | ||||
|  | ||||
|     @property | ||||
|     def learning_rate(self): | ||||
|         return self.state["learning_rate"] | ||||
|  | ||||
|     @learning_rate.setter | ||||
|     def learning_rate(self, learning_rate: mx.array): | ||||
|         self.state["learning_rate"] = mx.array(learning_rate) | ||||
|  | ||||
|  | ||||
| class SGD(Optimizer): | ||||
|     r"""The stochastic gradient descent optimizer. | ||||
| @@ -113,9 +144,11 @@ class SGD(Optimizer): | ||||
|         self.dampening = dampening | ||||
|         self.nesterov = nesterov | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the SGD parameter update and stores :math:`v` in the | ||||
|         optimizer state.""" | ||||
|  | ||||
| @@ -123,24 +156,21 @@ class SGD(Optimizer): | ||||
|             gradient += self.weight_decay * parameter | ||||
|  | ||||
|         if self.momentum <= 0: | ||||
|             return parameter - self.learning_rate * gradient | ||||
|             return parameter - self.learning_rate.astype(gradient.dtype) * gradient | ||||
|  | ||||
|         v = self.momentum * state.get("v") | ||||
|         if self.dampening > 0: | ||||
|             v = ( | ||||
|                 state.get("v", (self.dampening / self.momentum) * gradient) | ||||
|                 * self.momentum | ||||
|             ) | ||||
|             v += (1 - self.dampening) * gradient | ||||
|         else: | ||||
|             v = state.get("v", mx.zeros_like(gradient)) * self.momentum | ||||
|             v += gradient | ||||
|  | ||||
|         if self.nesterov: | ||||
|             update = gradient + self.momentum * v | ||||
|         else: | ||||
|             update = v | ||||
|  | ||||
|         state["v"] = v | ||||
|         return parameter - self.learning_rate * update | ||||
|         return parameter - self.learning_rate.astype(gradient.dtype) * update | ||||
|  | ||||
|  | ||||
| class RMSprop(Optimizer): | ||||
| @@ -177,15 +207,17 @@ class RMSprop(Optimizer): | ||||
|                 f"RMSprop epsilon should be >0, {self.eps} was provided instead" | ||||
|             ) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         alpha = self.alpha | ||||
|         eps = self.eps | ||||
|  | ||||
|         v = state.get("v", mx.zeros_like(gradient)) | ||||
|         v = state["v"] | ||||
|         v = alpha * v + (1 - alpha) * mx.square(gradient) | ||||
|         state["v"] = v | ||||
|  | ||||
| @@ -222,16 +254,17 @@ class Adagrad(Optimizer): | ||||
|                 f"Adagrad epsilon should be >0, {self.eps} was provided instead" | ||||
|             ) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the Adagrad parameter update and stores :math:`v` in the | ||||
|         optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         eps = self.eps | ||||
|  | ||||
|         v = state.get("v", mx.zeros_like(gradient)) | ||||
|         v = v + mx.square(gradient) | ||||
|         v = state["v"] + mx.square(gradient) | ||||
|         state["v"] = v | ||||
|  | ||||
|         return parameter - lr * gradient / (mx.sqrt(v) + eps) | ||||
| @@ -274,17 +307,20 @@ class AdaDelta(Optimizer): | ||||
|                 f"AdaDelta epsilon should be >0, {self.eps} was provided instead" | ||||
|             ) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|         state["u"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the AdaDelta parameter update and stores :math:`v` and | ||||
|         :math:`u` in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         rho = self.rho | ||||
|         eps = self.eps | ||||
|  | ||||
|         v = state.get("v", mx.zeros_like(gradient)) | ||||
|         u = state.get("u", mx.zeros_like(gradient)) | ||||
|         v = state["v"] | ||||
|         u = state["u"] | ||||
|  | ||||
|         v = rho * v + (1 - rho) * mx.square(gradient) | ||||
|         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient | ||||
| @@ -329,17 +365,20 @@ class Adam(Optimizer): | ||||
|         self.betas = betas | ||||
|         self.eps = eps | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["m"] = mx.zeros_like(parameter) | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the Adam parameter update and stores :math:`v` and | ||||
|         :math:`m` in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         b1, b2 = self.betas | ||||
|         eps = self.eps | ||||
|  | ||||
|         m = state.get("m", gradient) | ||||
|         v = state.get("v", mx.square(gradient)) | ||||
|         m = state["m"] | ||||
|         v = state["v"] | ||||
|         m = b1 * m + (1 - b1) * gradient | ||||
|         v = b2 * v + (1 - b2) * mx.square(gradient) | ||||
|         state["m"] = m | ||||
| @@ -385,15 +424,14 @@ class AdamW(Adam): | ||||
|         super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) | ||||
|         self.weight_decay = weight_decay | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the AdamW parameter update by modifying the parameters | ||||
|         passed into Adam. | ||||
|         """ | ||||
|  | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         return super().apply_single( | ||||
|             gradient, parameter * (1 - self.learning_rate * self.weight_decay), state | ||||
|             gradient, parameter * (1 - lr * self.weight_decay), state | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -430,17 +468,20 @@ class Adamax(Adam): | ||||
|                 f"Epsilon value should be >=0, {self.eps} was provided instead" | ||||
|             ) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["m"] = mx.zeros_like(parameter) | ||||
|         state["v"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the Adamax parameter update and stores :math:`v` and | ||||
|         :math:`m` in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         b1, b2 = self.betas | ||||
|         eps = self.eps | ||||
|  | ||||
|         m = state.get("m", mx.zeros_like(gradient)) | ||||
|         v = state.get("v", mx.zeros_like(gradient)) | ||||
|         m = state["m"] | ||||
|         v = state["v"] | ||||
|  | ||||
|         m = b1 * m + (1 - b1) * gradient | ||||
|         v = mx.maximum(b2 * v, mx.abs(gradient)) | ||||
| @@ -489,16 +530,18 @@ class Lion(Optimizer): | ||||
|         self.betas = betas | ||||
|         self.weight_decay = weight_decay | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["m"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the Lion parameter update and stores :math:`m` | ||||
|         in the optimizer state.""" | ||||
|         lr = self.learning_rate | ||||
|         lr = self.learning_rate.astype(gradient.dtype) | ||||
|         b1, b2 = self.betas | ||||
|         weight_decay = self.weight_decay | ||||
|  | ||||
|         m = state.get("m", gradient) | ||||
|         m = state["m"] | ||||
|         c = b1 * m + (1 - b1) * gradient | ||||
|         state["m"] = b2 * m + (1 - b2) * gradient | ||||
|         if weight_decay > 0: | ||||
| @@ -552,7 +595,8 @@ class Adafactor(Optimizer): | ||||
|         warmup_init: bool = False, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.learning_rate = learning_rate | ||||
|         if learning_rate is not None: | ||||
|             self.learning_rate = learning_rate | ||||
|         self.eps = eps | ||||
|         self.clip_threshold = clip_threshold | ||||
|         self.decay_rate = decay_rate | ||||
| @@ -562,14 +606,29 @@ class Adafactor(Optimizer): | ||||
|         self.relative_step = relative_step | ||||
|         self.warmup_init = warmup_init | ||||
|  | ||||
|     def init_single(self, parameter: mx.array, state: dict): | ||||
|         """Initialize optimizer state""" | ||||
|         state["step"] = 0 | ||||
|         if parameter.ndim >= 2: | ||||
|             shape = parameter.shape | ||||
|             dtype = parameter.dtype | ||||
|             state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype) | ||||
|             state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype) | ||||
|         else: | ||||
|             state["exp_avg_sq"] = mx.zeros_like(parameter) | ||||
|  | ||||
|         if self.beta_1 is not None: | ||||
|             state["exp_avg"] = mx.zeros_like(parameter) | ||||
|  | ||||
|     def _compute_rms(self, inputs): | ||||
|         return mx.sqrt(mx.mean(mx.square(inputs))) | ||||
|  | ||||
|     def _compute_learning_rate(self, step, parameter_rms): | ||||
|         relative_step_size = self.learning_rate | ||||
|         if self.relative_step: | ||||
|             min_step = 1e-6 * step if self.warmup_init else 1e-2 | ||||
|             relative_step_size = min(min_step, 1 / math.sqrt(step)) | ||||
|         else: | ||||
|             relative_step_size = self.learning_rate.astype(parameter_rms.dtype) | ||||
|  | ||||
|         parameter_scale = 1.0 | ||||
|         if self.scale_parameter: | ||||
| @@ -585,13 +644,11 @@ class Adafactor(Optimizer): | ||||
|             mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) | ||||
|         ) | ||||
|  | ||||
|     def apply_single( | ||||
|         self, gradient: mx.array, parameter: mx.array, state: OptimizerState | ||||
|     ): | ||||
|     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): | ||||
|         """Performs the Adafactor parameter and state update.""" | ||||
|         gradient_shape = gradient.shape | ||||
|         factored = len(gradient_shape) >= 2 | ||||
|         step = state.get("step", 0) + 1 | ||||
|         factored = gradient.ndim >= 2 | ||||
|  | ||||
|         step = state["step"] + 1 | ||||
|         state["step"] = step | ||||
|         use_first_moment = self.beta_1 is not None | ||||
|  | ||||
| @@ -601,15 +658,8 @@ class Adafactor(Optimizer): | ||||
|         update = mx.square(gradient) + self.eps[0] | ||||
|  | ||||
|         if factored: | ||||
|             exp_avg_sq_row = state.get( | ||||
|                 "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) | ||||
|             ) | ||||
|             exp_avg_sq_col = state.get( | ||||
|                 "exp_avg_sq_col", | ||||
|                 mx.zeros( | ||||
|                     gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype | ||||
|                 ), | ||||
|             ) | ||||
|             exp_avg_sq_row = state["exp_avg_sq_row"] | ||||
|             exp_avg_sq_col = state["exp_avg_sq_col"] | ||||
|             exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( | ||||
|                 (1 - beta_2) * mx.mean(update, axis=-1) | ||||
|             ) | ||||
| @@ -621,7 +671,7 @@ class Adafactor(Optimizer): | ||||
|             update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) | ||||
|             update = update * gradient | ||||
|         else: | ||||
|             exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) | ||||
|             exp_avg_sq = state["exp_avg_sq"] | ||||
|             exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) | ||||
|             state["exp_avg_sq"] = exp_avg_sq | ||||
|             update = mx.rsqrt(exp_avg_sq) * gradient | ||||
| @@ -632,7 +682,7 @@ class Adafactor(Optimizer): | ||||
|         update = learning_rate * update | ||||
|  | ||||
|         if use_first_moment: | ||||
|             exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) | ||||
|             exp_avg = state["exp_avg"] | ||||
|             exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) | ||||
|             state["exp_avg"] = exp_avg | ||||
|             update = exp_avg | ||||
|   | ||||
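With OptimizerState removed, the optimizer state above is a plain nested dict that can be initialized explicitly and inspected before the first update, and the learning rate is stored in it as an mx.array so it can be swapped later. A sketch following the Optimizer.init docstring and the new tests:

    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(2, 2)
    optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)

    # Optional explicit initialization: momentum buffers exist before the first update
    optimizer.init(model.trainable_parameters())
    print(optimizer.state)  # {'learning_rate': array(0.1, ...), 'weight': {'v': ...}, 'bias': {'v': ...}}

    # The learning rate lives in the state as an mx.array and can be changed in place
    optimizer.learning_rate = 1e-2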
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
| #include <chrono> | ||||
|  | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| @@ -13,13 +14,55 @@ using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
| using namespace mlx::core::random; | ||||
|  | ||||
| class PyKeySequence { | ||||
|  public: | ||||
|   explicit PyKeySequence(uint64_t seed) { | ||||
|     state_.append(key(seed)); | ||||
|   } | ||||
|  | ||||
|   void seed(uint64_t seed) { | ||||
|     state_[0] = key(seed); | ||||
|   } | ||||
|  | ||||
|   array next() { | ||||
|     auto out = split(py::cast<array>(state_[0])); | ||||
|     state_[0] = out.first; | ||||
|     return out.second; | ||||
|   } | ||||
|  | ||||
|   py::list state() { | ||||
|     return state_; | ||||
|   } | ||||
|  | ||||
|   void release() { | ||||
|     py::gil_scoped_acquire gil; | ||||
|     state_.release().dec_ref(); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   py::list state_; | ||||
| }; | ||||
|  | ||||
| PyKeySequence& default_key() { | ||||
|   auto get_current_time_seed = []() { | ||||
|     auto now = std::chrono::system_clock::now(); | ||||
|     return std::chrono::duration_cast<std::chrono::milliseconds>( | ||||
|                now.time_since_epoch()) | ||||
|         .count(); | ||||
|   }; | ||||
|   static PyKeySequence ks(get_current_time_seed()); | ||||
|   return ks; | ||||
| } | ||||
|  | ||||
| void init_random(py::module_& parent_module) { | ||||
|   auto m = parent_module.def_submodule( | ||||
|       "random", | ||||
|       "mlx.core.random: functionality related to random number generation"); | ||||
|  | ||||
|   m.attr("state") = default_key().state(); | ||||
|   m.def( | ||||
|       "seed", | ||||
|       &seed, | ||||
|       [](uint64_t seed) { default_key().seed(seed); }, | ||||
|       "seed"_a, | ||||
|       R"pbdoc( | ||||
|         Seed the global PRNG. | ||||
| @@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) { | ||||
|          const ScalarOrArray& high, | ||||
|          const std::vector<int>& shape, | ||||
|          std::optional<Dtype> type, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         return uniform( | ||||
|             to_array(low), | ||||
|             to_array(high), | ||||
| @@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) { | ||||
|          std::optional<Dtype> type, | ||||
|          float loc, | ||||
|          float scale, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         return normal(shape, type.value_or(float32), loc, scale, key, s); | ||||
|       }, | ||||
|  | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a = std::optional{float32}, | ||||
|       "loc"_a = 0.0, | ||||
| @@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) { | ||||
|          const ScalarOrArray& high, | ||||
|          const std::vector<int>& shape, | ||||
|          std::optional<Dtype> type, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         return randint( | ||||
|             to_array(low), to_array(high), shape, type.value_or(int32), key, s); | ||||
|       }, | ||||
| @@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) { | ||||
|       "bernoulli", | ||||
|       [](const ScalarOrArray& p_, | ||||
|          const std::optional<std::vector<int>> shape, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         auto p = to_array(p_); | ||||
|         if (shape.has_value()) { | ||||
|           return bernoulli(p, shape.value(), key, s); | ||||
| @@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) { | ||||
|          const ScalarOrArray& upper_, | ||||
|          const std::optional<std::vector<int>> shape_, | ||||
|          std::optional<Dtype> type, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         auto lower = to_array(lower_); | ||||
|         auto upper = to_array(upper_); | ||||
|         auto t = type.value_or(float32); | ||||
| @@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) { | ||||
|       "gumbel", | ||||
|       [](const std::vector<int>& shape, | ||||
|          std::optional<Dtype> type, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         return gumbel(shape, type.value_or(float32), key, s); | ||||
|       }, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
| @@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) { | ||||
|          int axis, | ||||
|          const std::optional<std::vector<int>> shape, | ||||
|          const std::optional<int> num_samples, | ||||
|          const std::optional<array>& key, | ||||
|          const std::optional<array>& key_, | ||||
|          StreamOrDevice s) { | ||||
|         auto key = key_ ? key_.value() : default_key().next(); | ||||
|         if (shape.has_value() && num_samples.has_value()) { | ||||
|           throw std::invalid_argument( | ||||
|               "[categorical] At most one of shape or num_samples can be specified."); | ||||
| @@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) { | ||||
|         Returns: | ||||
|             array: The ``shape``-sized output array with type ``uint32``. | ||||
|       )pbdoc"); | ||||
|   // Register static Python object cleanup before the interpreter exits | ||||
|   auto atexit = py::module_::import("atexit"); | ||||
|   atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); | ||||
| } | ||||
|   | ||||
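On the C++ side, mlx.core.random now keeps a global key sequence exposed as mx.random.state, seeded from the current time and split on every draw when no explicit key is passed. Capturing that state lets compiled functions produce fresh samples on each call, as exercised by the new test_compile_rng:

    from functools import partial

    import mlx.core as mx

    @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
    def noise():
        return mx.random.uniform(shape=(10, 10))

    a = noise()
    b = noise()  # different samples: the captured PRNG key advances between calls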
| @@ -135,6 +135,64 @@ py::object tree_map( | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void tree_visit_update( | ||||
|     py::object tree, | ||||
|     std::function<py::object(py::handle)> visitor) { | ||||
|   std::function<py::object(py::handle)> recurse; | ||||
|   recurse = [&](py::handle subtree) { | ||||
|     if (py::isinstance<py::list>(subtree)) { | ||||
|       auto l = py::cast<py::list>(subtree); | ||||
|       for (int i = 0; i < l.size(); ++i) { | ||||
|         l[i] = recurse(l[i]); | ||||
|       } | ||||
|       return py::cast<py::object>(l); | ||||
|     } else if (py::isinstance<py::tuple>(subtree)) { | ||||
|       for (auto item : subtree) { | ||||
|         recurse(item); | ||||
|       } | ||||
|       return py::cast<py::object>(subtree); | ||||
|     } else if (py::isinstance<py::dict>(subtree)) { | ||||
|       auto d = py::cast<py::dict>(subtree); | ||||
|       for (auto item : d) { | ||||
|         d[item.first] = recurse(item.second); | ||||
|       } | ||||
|       return py::cast<py::object>(d); | ||||
|     } else if (py::isinstance<array>(subtree)) { | ||||
|       return visitor(subtree); | ||||
|     } else { | ||||
|       return py::cast<py::object>(subtree); | ||||
|     } | ||||
|   }; | ||||
|   recurse(tree); | ||||
| } | ||||
|  | ||||
| // Fill a pytree (recursive dict or list of dict or list) | ||||
| // in place with the given arrays | ||||
| // Non dict or list nodes are ignored | ||||
| void tree_fill(py::object& tree, const std::vector<array>& values) { | ||||
|   size_t index = 0; | ||||
|   tree_visit_update( | ||||
|       tree, [&](py::handle node) { return py::cast(values[index++]); }); | ||||
| } | ||||
|  | ||||
| // Replace all the arrays from the src values with the dst values in the tree | ||||
| void tree_replace( | ||||
|     py::object& tree, | ||||
|     const std::vector<array>& src, | ||||
|     const std::vector<array>& dst) { | ||||
|   std::unordered_map<uintptr_t, array> src_to_dst; | ||||
|   for (int i = 0; i < src.size(); ++i) { | ||||
|     src_to_dst.insert({src[i].id(), dst[i]}); | ||||
|   } | ||||
|   tree_visit_update(tree, [&](py::handle node) { | ||||
|     auto arr = py::cast<array>(node); | ||||
|     if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { | ||||
|       return py::cast(it->second); | ||||
|     } | ||||
|     return py::cast(arr); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| std::vector<array> tree_flatten(py::object tree, bool strict = true) { | ||||
|   std::vector<array> flat_tree; | ||||
|  | ||||
| @@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() { | ||||
| struct PyCompiledFun { | ||||
|   py::function fun; | ||||
|   size_t fun_id; | ||||
|   py::object captured_inputs; | ||||
|   py::object captured_outputs; | ||||
|   size_t num_outputs{0}; | ||||
|  | ||||
|   PyCompiledFun(const py::function& fun) | ||||
|       : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {} | ||||
|   PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) | ||||
|       : fun(fun), | ||||
|         fun_id(reinterpret_cast<size_t>(fun.ptr())), | ||||
|         captured_inputs(inputs), | ||||
|         captured_outputs(outputs) {} | ||||
|  | ||||
|   PyCompiledFun(const PyCompiledFun&) = delete; | ||||
|   PyCompiledFun& operator=(const PyCompiledFun&) = delete; | ||||
| @@ -505,23 +569,61 @@ struct PyCompiledFun { | ||||
|   PyCompiledFun(PyCompiledFun&& other) | ||||
|       : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) { | ||||
|     other.fun_id = 0; | ||||
|     captured_inputs = std::move(other.captured_inputs); | ||||
|     captured_outputs = std::move(other.captured_outputs); | ||||
|     num_outputs = other.num_outputs; | ||||
|   }; | ||||
|  | ||||
|   py::object operator()(const py::args& args) { | ||||
|     auto compile_fun = [this, &args](const std::vector<array>& a) { | ||||
|       // Call the python function and flatten the outputs | ||||
|       auto [outputs, py_outputs] = tree_flatten_with_structure( | ||||
|           std::move(this->fun(*tree_unflatten(args, a))), true); | ||||
|       // Put tracers into captured inputs | ||||
|       std::vector<array> flat_in_captures; | ||||
|       std::vector<array> trace_captures; | ||||
|       if (!py::isinstance<py::none>(captured_inputs)) { | ||||
|         flat_in_captures = tree_flatten(captured_inputs, false); | ||||
|         trace_captures.insert( | ||||
|             trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); | ||||
|         tree_fill(captured_inputs, trace_captures); | ||||
|       } | ||||
|  | ||||
|       tree_cache().insert({this->fun_id, py_outputs}); | ||||
|       auto [outputs, py_outputs] = tree_flatten_with_structure( | ||||
|           std::move(fun(*tree_unflatten(args, a))), false); | ||||
|  | ||||
|       tree_cache().insert({fun_id, py_outputs}); | ||||
|  | ||||
|       num_outputs = outputs.size(); | ||||
|       if (!py::isinstance<py::none>(captured_outputs)) { | ||||
|         auto flat_out_captures = tree_flatten(captured_outputs, false); | ||||
|         outputs.insert( | ||||
|             outputs.end(), | ||||
|             std::make_move_iterator(flat_out_captures.begin()), | ||||
|             std::make_move_iterator(flat_out_captures.end())); | ||||
|       } | ||||
|  | ||||
|       // Replace tracers with originals in captured inputs | ||||
|       if (!py::isinstance<py::none>(captured_inputs)) { | ||||
|         tree_replace(captured_inputs, trace_captures, flat_in_captures); | ||||
|       } | ||||
|       return outputs; | ||||
|     }; | ||||
|  | ||||
|     // Inputs must be array or tree of arrays | ||||
|     auto inputs = tree_flatten(args, true); | ||||
|     auto inputs = tree_flatten(args, false); | ||||
|     if (!py::isinstance<py::none>(captured_inputs)) { | ||||
|       auto flat_in_captures = tree_flatten(captured_inputs, false); | ||||
|       inputs.insert( | ||||
|           inputs.end(), | ||||
|           std::make_move_iterator(flat_in_captures.begin()), | ||||
|           std::make_move_iterator(flat_in_captures.end())); | ||||
|     } | ||||
|  | ||||
|     // Compile and call | ||||
|     auto outputs = detail::compile(compile_fun, fun_id)(inputs); | ||||
|     if (!py::isinstance<py::none>(captured_outputs)) { | ||||
|       std::vector<array> captures( | ||||
|           std::make_move_iterator(outputs.begin() + num_outputs), | ||||
|           std::make_move_iterator(outputs.end())); | ||||
|       tree_fill(captured_outputs, captures); | ||||
|     } | ||||
|  | ||||
|     // Put the outputs back in the container | ||||
|     py::object py_outputs = tree_cache().at(fun_id); | ||||
| @@ -534,6 +636,8 @@ struct PyCompiledFun { | ||||
|     tree_cache().erase(fun_id); | ||||
|     detail::compile_erase(fun_id); | ||||
|     fun.release().dec_ref(); | ||||
|     captured_inputs.release().dec_ref(); | ||||
|     captured_outputs.release().dec_ref(); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| @@ -601,7 +705,7 @@ void init_transforms(py::module_& m) { | ||||
|   m.def( | ||||
|       "eval", | ||||
|       [](const py::args& args) { | ||||
|         std::vector<array> arrays = tree_flatten(args); | ||||
|         std::vector<array> arrays = tree_flatten(args, false); | ||||
|         { | ||||
|           py::gil_scoped_release nogil; | ||||
|           eval(arrays); | ||||
| @@ -615,8 +719,8 @@ void init_transforms(py::module_& m) { | ||||
|         Args: | ||||
|             *args (arrays or trees of arrays): Each argument can be a single array | ||||
|               or a tree of arrays. If a tree is given the nodes can be a Python | ||||
|               :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be | ||||
|               an :class:`array`. | ||||
|               :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not | ||||
|               arrays are ignored. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "jvp", | ||||
| @@ -859,10 +963,14 @@ void init_transforms(py::module_& m) { | ||||
|       "file"_a); | ||||
|   m.def( | ||||
|       "compile", | ||||
|       [](const py::function& fun) { | ||||
|         return py::cpp_function(PyCompiledFun{fun}); | ||||
|       [](const py::function& fun, | ||||
|          const py::object& inputs, | ||||
|          const py::object& outputs) { | ||||
|         return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); | ||||
|       }, | ||||
|       "fun"_a, | ||||
|       "inputs"_a = std::nullopt, | ||||
|       "outputs"_a = std::nullopt, | ||||
|       R"pbdoc( | ||||
|         compile(fun: function) -> function | ||||
|  | ||||
| @@ -872,6 +980,16 @@ void init_transforms(py::module_& m) { | ||||
|             fun (function): A function which takes a variable number of | ||||
|               :class:`array` or trees of :class:`array` and returns | ||||
|               a variable number of :class:`array` or trees of :class:`array`. | ||||
|             inputs (list or dict, optional): These inputs will be captured during | ||||
|               the function compilation along with the inputs to ``fun``. The ``inputs`` | ||||
|               can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested | ||||
|               lists, dictionaries, or arrays. Leaf nodes that are not | ||||
|               :obj:`array` are ignored. Default: ``None`` | ||||
|             outputs (list or dict, optional): These outputs will be captured and | ||||
|               updated in a compiled function. The ``outputs`` can be a | ||||
|               :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, | ||||
|               dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. | ||||
|               Default: ``None`` | ||||
|  | ||||
|         Returns: | ||||
|             function: A compiled function which has the same input arguments | ||||
|   | ||||
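mx.eval (and the tree flattening used by compile) now ignores leaves that are not arrays instead of raising, which makes it convenient to evaluate heterogeneous state containers. A short example, matching the new test_eval_mixed below:

    import mlx.core as mx

    x = mx.array(1) + 1 + 1
    state = [x, 0, "hello"]  # non-array leaves are simply skipped
    mx.eval(state)
    assert x.item() == 3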
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| import io | ||||
| import unittest | ||||
| from functools import partial | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| @@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         cdfdx = mx.grad(outer)(x) | ||||
|         self.assertTrue(mx.allclose(dfdx, cdfdx)) | ||||
|  | ||||
|     def test_compile_capture(self): | ||||
|         # Test updating captured state outside the compiled function | ||||
|         state = {"y": mx.array(2)} | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state["y"] | ||||
|             return x | ||||
|  | ||||
|         test_state(mx.array(1)) | ||||
|         # Check the state is unchanged | ||||
|         self.assertEqual(state["y"], 2) | ||||
|  | ||||
|         # Check the updated state is used | ||||
|         state["y"] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Capture list | ||||
|         state = [mx.array(2)] | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state[0] | ||||
|             return x | ||||
|  | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 3) | ||||
|         state[0] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Capture tuple of list | ||||
|         state = ([mx.array(2)],) | ||||
|  | ||||
|         @partial(mx.compile, inputs=state) | ||||
|         def test_state(x): | ||||
|             x = x + state[0][0] | ||||
|             return x | ||||
|  | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 3) | ||||
|         state[0][0] = mx.array(3) | ||||
|         out = test_state(mx.array(1)) | ||||
|         self.assertEqual(out.item(), 4) | ||||
|  | ||||
|         # Test state updated inside compiled function | ||||
|         state = {} | ||||
|  | ||||
|         @partial(mx.compile, outputs=state) | ||||
|         def test_state(x): | ||||
|             state["y"] = x + 3 | ||||
|             return mx.abs(x) | ||||
|  | ||||
|         test_state(mx.array(-1)) | ||||
|         self.assertEqual(state["y"].item(), 2) | ||||
|  | ||||
|         # Test state changed inside compiled function | ||||
|         # triggers recompile | ||||
|         state = {} | ||||
|  | ||||
|         @partial(mx.compile, inputs=state, outputs=state) | ||||
|         def test_state(x): | ||||
|             y = state.get("y", mx.array(0)) | ||||
|             state["y"] = x + y | ||||
|             return x + 2 * y | ||||
|  | ||||
|         test_state(mx.array(1)) | ||||
|         self.assertEqual(state["y"].item(), 1) | ||||
|         test_state(mx.array(1)) | ||||
|         self.assertEqual(state["y"].item(), 2) | ||||
|  | ||||
|     def test_compile_rng(self): | ||||
|         @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) | ||||
|         def fun(): | ||||
|             return mx.random.uniform(shape=(10, 10)) | ||||
|  | ||||
|         self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|         y = dfun_dx(mx.array(1.0)) | ||||
|         self.assertEqual(y.item(), 6.0) | ||||
|  | ||||
|     def test_eval_mixed(self): | ||||
|         x = mx.array(1) + 1 + 1 | ||||
|         y = 0 | ||||
|         z = "hello" | ||||
|         state = [x, y, z] | ||||
|         mx.eval(state) | ||||
|         self.assertEqual(x.item(), 3) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|                 ] | ||||
|             ) | ||||
|  | ||||
|     def test_module_state(self): | ||||
|         m = nn.Linear(10, 1) | ||||
|         m.state["hello"] = "world" | ||||
|         self.assertEqual(m.state["hello"], "world") | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|   | ||||
| @@ -2,47 +2,209 @@ | ||||
|  | ||||
| import inspect | ||||
| import unittest | ||||
| from functools import partial | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx.optimizers as opt | ||||
| import mlx.utils | ||||
| import mlx_tests | ||||
| from mlx.utils import tree_flatten, tree_map | ||||
|  | ||||
|  | ||||
| def get_all_optimizers(): | ||||
|     classes = dict() | ||||
|     for name, obj in inspect.getmembers(opt): | ||||
|         if inspect.isclass(obj): | ||||
|             if obj.__name__ not in ["OptimizerState", "Optimizer"]: | ||||
|             if obj.__name__ not in ["Optimizer"]: | ||||
|                 classes[name] = obj | ||||
|     return classes | ||||
|  | ||||
|  | ||||
| def tree_equal(fn, *args): | ||||
|     return all(v for _, v in tree_flatten(tree_map(fn, *args))) | ||||
|  | ||||
|  | ||||
| optimizers_dict = get_all_optimizers() | ||||
|  | ||||
|  | ||||
| class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|     def test_optimizer_state(self): | ||||
|         optim = opt.SGD(0.1) | ||||
|         optim.state["hello"] = "world" | ||||
|         self.assertEqual(optim.state["hello"], "world") | ||||
|  | ||||
|         optim.state = {0: 1} | ||||
|         self.assertEqual(optim.state, {0: 1}) | ||||
|  | ||||
|     def test_optimizers(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         for optim_class in optimizers_dict.values(): | ||||
|             optim = optim_class(0.1) | ||||
|             update = optim.apply_gradients(grads, params) | ||||
|             mx.eval(update) | ||||
|             equal_shape = mlx.utils.tree_map( | ||||
|                 lambda x, y: x.shape == y.shape, params, update | ||||
|             ) | ||||
|             equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) | ||||
|             all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) | ||||
|             self.assertTrue(all_equal) | ||||
|  | ||||
|     def test_types_conserved(self): | ||||
|         params = {"w": mx.ones((5, 5), mx.float16)} | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|         for optim_class in optimizers_dict.values(): | ||||
|             optim = optim_class(0.1) | ||||
|             update = optim.apply_gradients(grads, params) | ||||
|             self.assertEqual(update["w"].dtype, mx.float16) | ||||
|  | ||||
|     def test_sgd(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Implicit init | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         optim.apply_gradients(grads, params) | ||||
|         self.assertTrue( | ||||
|             tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) | ||||
|         ) | ||||
|  | ||||
|     def test_rmsprop(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.RMSprop(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Implicit init | ||||
|         alpha = 0.99 | ||||
|         optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) | ||||
|         optim.apply_gradients(grads, params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adagrad(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.Adagrad(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adadelta(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.AdaDelta(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adam(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: | ||||
|             optim = optimizer(learning_rate=1e-2) | ||||
|             optim.init(params) | ||||
|             self.assertTrue( | ||||
|                 tree_equal( | ||||
|                     lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), | ||||
|                     params, | ||||
|                     optim.state, | ||||
|                 ) | ||||
|             ) | ||||
|             self.assertTrue( | ||||
|                 tree_equal( | ||||
|                     lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||
|                     params, | ||||
|                     optim.state, | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     def test_lion(self): | ||||
|         params = { | ||||
|             "first": [mx.zeros((10,)), mx.zeros((1,))], | ||||
|             "second": mx.zeros((1,)), | ||||
|         } | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|  | ||||
|         # Explicit init | ||||
|         optim = opt.Lion(learning_rate=1e-2) | ||||
|         optim.init(params) | ||||
|         self.assertTrue( | ||||
|             tree_equal( | ||||
|                 lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), | ||||
|                 params, | ||||
|                 optim.state, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def test_adafactor(self): | ||||
|         x = mx.zeros((5, 5)) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
| @@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase): | ||||
|         x = mx.zeros((5, 5), mx.float16) | ||||
|         grad = mx.ones_like(x) | ||||
|         optimizer = opt.Adafactor() | ||||
|         optimizer.init(x) | ||||
|         for _ in range(2): | ||||
|             xp = optimizer.apply_single(grad, x, optimizer.state) | ||||
|             self.assertEqual(xp.dtype, x.dtype) | ||||
|             self.assertEqual(xp.shape, x.shape) | ||||
|         self.assertEqual(optimizer.state["step"], 2) | ||||
|  | ||||
|     def test_compiled_optimizer(self): | ||||
|         model = nn.Linear(10, 10) | ||||
|         x = mx.random.uniform(shape=(2, 10)) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|  | ||||
|         orig_params = model.parameters() | ||||
|  | ||||
|         def loss(model, x): | ||||
|             return model(x).sum() | ||||
|  | ||||
|         # Uncompiled version | ||||
|         def step(x): | ||||
|             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||
|             optim.update(model, grad) | ||||
|  | ||||
|         step(x) | ||||
|         uncompiled_params = model.parameters() | ||||
|  | ||||
|         # Pure version | ||||
|         def loss(params, x): | ||||
|             model.update(params) | ||||
|             return model(x).sum() | ||||
|  | ||||
|         model.update(orig_params) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|  | ||||
|         @mx.compile | ||||
|         def step(params, opt_state, x): | ||||
|             grad = mx.grad(loss)(params, x) | ||||
|             optim.state = opt_state | ||||
|             params = optim.apply_gradients(grad, params) | ||||
|             return params, optim.state | ||||
|  | ||||
|         optim.init(model.parameters()) | ||||
|         pure_params, _ = step(model.parameters(), optim.state, x) | ||||
|         self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) | ||||
|         self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) | ||||
|  | ||||
|         # Impure version | ||||
|         def loss(model, x): | ||||
|             return model(x).sum() | ||||
|  | ||||
|         model.update(orig_params) | ||||
|         optim = opt.SGD(learning_rate=1e-2, momentum=0.9) | ||||
|         state = [model.state, optim.state] | ||||
|  | ||||
|         @partial(mx.compile, inputs=state, outputs=state) | ||||
|         def step(x): | ||||
|             _, grad = nn.value_and_grad(model, loss)(model, x) | ||||
|             optim.update(model, grad) | ||||
|  | ||||
|         step(x) | ||||
|         impure_params = model.parameters() | ||||
|         self.assertTrue( | ||||
|             mx.allclose(impure_params["weight"], uncompiled_params["weight"]) | ||||
|         ) | ||||
|         self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) | ||||
|  | ||||
|     def test_update_lr_compiled(self): | ||||
|         params = {"w": mx.ones((5, 5))} | ||||
|         grads = tree_map(lambda x: mx.ones_like(x), params) | ||||
|         optim = opt.SGD(-1.0) | ||||
|  | ||||
|         @partial(mx.compile, inputs=optim.state) | ||||
|         def update(grads): | ||||
|             return optim.apply_gradients(grads, params) | ||||
|  | ||||
|         result = update(grads) | ||||
|         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) | ||||
|         optim.learning_rate = -2.0 | ||||
|         result = update(grads) | ||||
|         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
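Putting the pieces together, a full training step can be compiled while keeping the familiar impure style: the module's state and the optimizer's state are declared as captured inputs and outputs. A sketch following the "impure version" in test_compiled_optimizer above:

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(10, 10)
    optimizer = optim.SGD(learning_rate=1e-2, momentum=0.9)
    state = [model.state, optimizer.state]

    def loss(model, x):
        return model(x).sum()

    # Both the parameters (reachable through model.state) and the optimizer's
    # buffers are captured, so their in-place updates survive compilation.
    @partial(mx.compile, inputs=state, outputs=state)
    def step(x):
        _, grads = nn.value_and_grad(model, loss)(model, x)
        optimizer.update(model, grads)

    step(mx.random.uniform(shape=(2, 10)))
    mx.eval(model.parameters(), optimizer.state)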
Awni Hannun