diff --git a/docs/src/_templates/nn-module-template.rst b/docs/src/_templates/nn-module-template.rst deleted file mode 100644 index 49f018eb5..000000000 --- a/docs/src/_templates/nn-module-template.rst +++ /dev/null @@ -1,19 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - - {#{% block methods %} - - {% if methods %} - .. rubric:: {{ _('Methods') }} - - .. autosummary:: - {% for item in methods %} - {%- if item not in inherited_members and item != '__init__' %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} - {% endblock %}#} diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst index 042a88028..c3a4dfa62 100644 --- a/docs/src/python/nn/module.rst +++ b/docs/src/python/nn/module.rst @@ -11,6 +11,7 @@ Module :toctree: _autosummary Module.training + Module.state .. rubric:: Methods diff --git a/docs/src/python/optimizer.rst b/docs/src/python/optimizer.rst new file mode 100644 index 000000000..cf6034dee --- /dev/null +++ b/docs/src/python/optimizer.rst @@ -0,0 +1,23 @@ +Optimizer +========= + +.. currentmodule:: mlx.optimizers + +.. autoclass:: Optimizer + + + .. rubric:: Attributes + + .. autosummary:: + :toctree: _autosummary + + Optimizer.state + + .. rubric:: Methods + + .. autosummary:: + :toctree: _autosummary + + Optimizer.apply_gradients + Optimizer.init + Optimizer.update diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst index fe8632a7e..4ef43d50f 100644 --- a/docs/src/python/optimizers.rst +++ b/docs/src/python/optimizers.rst @@ -29,14 +29,16 @@ model's parameters and the **optimizer state**. # Compute the new parameters but also the optimizer state. mx.eval(model.parameters(), optimizer.state) +.. toctree:: + + optimizer + .. currentmodule:: mlx.optimizers .. autosummary:: :toctree: _autosummary :template: optimizers-template.rst - OptimizerState - Optimizer SGD RMSprop Adagrad diff --git a/mlx/compile.cpp b/mlx/compile.cpp index c8ee3b0da..e69c442f2 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -191,10 +191,7 @@ struct CompilerCache { auto is_match = [](const std::vector& in1, const std::vector& in2) { if (in1.size() != in2.size()) { - std::ostringstream msg; - msg << "[compiler] Unexpected number of inputs to compiled function:" - << " expected " << in2.size() << " got " << in1.size() << "."; - throw std::invalid_argument(msg.str()); + return false; } for (int i = 0; i < in1.size(); ++i) { if (in1[i].shape() != in2[i].shape()) { @@ -603,7 +600,7 @@ void compile_fuse( shapes, types, std::make_shared( - outputs.back().primitive().stream(), + old_outputs.back().primitive().stream(), inputs, old_outputs, std::move(fused_tape), diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index febbafa78..de7097673 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -66,6 +66,19 @@ class Module(dict): """Boolean indicating if the model is in training mode.""" return self._training + @property + def state(self): + """The module's state dictionary + + The module's state dictionary contains any attribute set on the + module including parameters in :meth:`Module.parameters` + + Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is + a reference to the module's state. Updates to it will be reflected in + the original module. 
+        """
+        return self
+
     def _extra_repr(self):
         return ""
 
diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index b659ec5cf..4a53d4681 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -7,39 +7,14 @@ import mlx.core as mx
 from mlx.utils import tree_map
 
 
-class OptimizerState(dict):
-    """The optimizer state implements a recursively defined
-    :class:`collections.defaultdict`, namely a missing key in an optimizer
-    state is an :class:`OptimizerState`.
-
-    .. note::
-        :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
-        the key to the ``default`` value if the ``key`` was not present in the
-        dictionary.
-    """
-
-    def __getitem__(self, key):
-        if key not in self:
-            self[key] = OptimizerState()
-        return super().__getitem__(key)
-
-    def get(self, key, default):
-        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
-        if key not in self:
-            self[key] = default
-        return super().__getitem__(key)
-
-
 class Optimizer:
     """The base class for all optimizers. It allows us to implement an
     optimizer on a per-parameter basis and apply it to a parameter tree.
-
-    Attributes:
-        state (OptimizerState): It holds the optimizer's state dictionary.
     """
 
     def __init__(self):
-        self.state = OptimizerState()
+        self._initialized = False
+        self._state = {}
 
     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
@@ -52,7 +27,41 @@ class Optimizer:
         """
         model.update(self.apply_gradients(gradients, model))
 
-    def apply_gradients(self, gradients: dict, model: dict):
+    def init(self, parameters: dict):
+        """Initialize the optimizer's state
+
+        This function can be used to initialize optimizers which have state
+        (like momentum in :class:`SGD`). Using this method is optional as the
+        optimizer will initialize itself if the state is not yet set. However,
+        there are some cases where explicit initialization is useful in order
+        to have access to the :attr:`Optimizer.state` before the first call to
+        :meth:`Optimizer.update`.
+
+        Args:
+            parameters (dict): A Python tree of parameters.
+
+        Example:
+            >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
+            >>> model = nn.Linear(2, 2)
+            >>> optimizer.init(model.trainable_parameters())
+            >>> optimizer.state
+            {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
+            [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+        """
+        self._state.update(tree_map(lambda x: {}, parameters))
+        tree_map(self.init_single, parameters, self._state)
+        self._initialized = True
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """To be extended by the children classes to implement each optimizer's
+        state initialization.
+
+        Args:
+            parameter (mx.array): A single parameter that will be optimized.
+        """
+        raise NotImplementedError()
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
         """Apply the gradients to the parameters and return the updated
         parameters. Can be used to update a model via
@@ -61,19 +70,41 @@ class Optimizer:
 
         Args:
             gradients (dict): A Python tree of gradients.
-            model (dict): A Python tree of parameters. It can be a superset of
-                  the gradients. In that case the returned python tree
-                  will be of the same structure as the gradients.
+            parameters (dict): A Python tree of parameters. It can be a
+                superset of the gradients. In that case the returned python
+                tree will be of the same structure as the gradients.
""" - return tree_map(self.apply_single, gradients, model, self.state) + if not self._initialized: + self.init(gradients) + return tree_map(self.apply_single, gradients, parameters, self.state) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): - """To be extended by the children classes to implement each optimizer's - update.""" + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): + """To be extended by derived classes to implement the optimizer's update. + + Args: + gradient (mx.array): The ``parameter`` gradient. + parameter (mx.array): The ``parameter`` to update. + state (dict): The optimizer's state. + """ raise NotImplementedError() + @property + def state(self): + """The optimizer's state dictionary.""" + return self._state + + @state.setter + def state(self, state: dict): + self._state = state + + @property + def learning_rate(self): + return self.state["learning_rate"] + + @learning_rate.setter + def learning_rate(self, learning_rate: mx.array): + self.state["learning_rate"] = mx.array(learning_rate) + class SGD(Optimizer): r"""The stochastic gradient descent optimizer. @@ -113,9 +144,11 @@ class SGD(Optimizer): self.dampening = dampening self.nesterov = nesterov - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the SGD parameter update and stores :math:`v` in the optimizer state.""" @@ -123,24 +156,21 @@ class SGD(Optimizer): gradient += self.weight_decay * parameter if self.momentum <= 0: - return parameter - self.learning_rate * gradient + return parameter - self.learning_rate.astype(gradient.dtype) * gradient + v = self.momentum * state.get("v") if self.dampening > 0: - v = ( - state.get("v", (self.dampening / self.momentum) * gradient) - * self.momentum - ) v += (1 - self.dampening) * gradient else: - v = state.get("v", mx.zeros_like(gradient)) * self.momentum v += gradient if self.nesterov: update = gradient + self.momentum * v else: update = v + state["v"] = v - return parameter - self.learning_rate * update + return parameter - self.learning_rate.astype(gradient.dtype) * update class RMSprop(Optimizer): @@ -177,15 +207,17 @@ class RMSprop(Optimizer): f"RMSprop epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) alpha = self.alpha eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) + v = state["v"] v = alpha * v + (1 - alpha) * mx.square(gradient) state["v"] = v @@ -222,16 +254,17 @@ class Adagrad(Optimizer): f"Adagrad epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, 
parameter: mx.array, state: dict): """Performs the Adagrad parameter update and stores :math:`v` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - v = v + mx.square(gradient) + v = state["v"] + mx.square(gradient) state["v"] = v return parameter - lr * gradient / (mx.sqrt(v) + eps) @@ -274,17 +307,20 @@ class AdaDelta(Optimizer): f"AdaDelta epsilon should be >0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["v"] = mx.zeros_like(parameter) + state["u"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdaDelta parameter update and stores :math:`v` and :math:`u` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) rho = self.rho eps = self.eps - v = state.get("v", mx.zeros_like(gradient)) - u = state.get("u", mx.zeros_like(gradient)) + v = state["v"] + u = state["u"] v = rho * v + (1 - rho) * mx.square(gradient) d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient @@ -329,17 +365,20 @@ class Adam(Optimizer): self.betas = betas self.eps = eps - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adam parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", gradient) - v = state.get("v", mx.square(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = b2 * v + (1 - b2) * mx.square(gradient) state["m"] = m @@ -385,15 +424,14 @@ class AdamW(Adam): super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the AdamW parameter update by modifying the parameters passed into Adam. 
""" + lr = self.learning_rate.astype(gradient.dtype) return super().apply_single( - gradient, parameter * (1 - self.learning_rate * self.weight_decay), state + gradient, parameter * (1 - lr * self.weight_decay), state ) @@ -430,17 +468,20 @@ class Adamax(Adam): f"Epsilon value should be >=0, {self.eps} was provided instead" ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + state["v"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adamax parameter update and stores :math:`v` and :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas eps = self.eps - m = state.get("m", mx.zeros_like(gradient)) - v = state.get("v", mx.zeros_like(gradient)) + m = state["m"] + v = state["v"] m = b1 * m + (1 - b1) * gradient v = mx.maximum(b2 * v, mx.abs(gradient)) @@ -489,16 +530,18 @@ class Lion(Optimizer): self.betas = betas self.weight_decay = weight_decay - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["m"] = mx.zeros_like(parameter) + + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Lion parameter update and stores :math:`m` in the optimizer state.""" - lr = self.learning_rate + lr = self.learning_rate.astype(gradient.dtype) b1, b2 = self.betas weight_decay = self.weight_decay - m = state.get("m", gradient) + m = state["m"] c = b1 * m + (1 - b1) * gradient state["m"] = b2 * m + (1 - b2) * gradient if weight_decay > 0: @@ -552,7 +595,8 @@ class Adafactor(Optimizer): warmup_init: bool = False, ): super().__init__() - self.learning_rate = learning_rate + if learning_rate is not None: + self.learning_rate = learning_rate self.eps = eps self.clip_threshold = clip_threshold self.decay_rate = decay_rate @@ -562,14 +606,29 @@ class Adafactor(Optimizer): self.relative_step = relative_step self.warmup_init = warmup_init + def init_single(self, parameter: mx.array, state: dict): + """Initialize optimizer state""" + state["step"] = 0 + if parameter.ndim >= 2: + shape = parameter.shape + dtype = parameter.dtype + state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype) + state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype) + else: + state["exp_avg_sq"] = mx.zeros_like(parameter) + + if self.beta_1 is not None: + state["exp_avg"] = mx.zeros_like(parameter) + def _compute_rms(self, inputs): return mx.sqrt(mx.mean(mx.square(inputs))) def _compute_learning_rate(self, step, parameter_rms): - relative_step_size = self.learning_rate if self.relative_step: min_step = 1e-6 * step if self.warmup_init else 1e-2 relative_step_size = min(min_step, 1 / math.sqrt(step)) + else: + relative_step_size = self.learning_rate.astype(parameter_rms) parameter_scale = 1.0 if self.scale_parameter: @@ -585,13 +644,11 @@ class Adafactor(Optimizer): mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) ) - def apply_single( - self, gradient: mx.array, parameter: mx.array, state: OptimizerState - ): + def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): """Performs the Adafactor parameter and state update.""" - gradient_shape = gradient.shape - factored = 
len(gradient_shape) >= 2 - step = state.get("step", 0) + 1 + factored = gradient.ndim >= 2 + + step = state["step"] + 1 state["step"] = step use_first_moment = self.beta_1 is not None @@ -601,15 +658,8 @@ class Adafactor(Optimizer): update = mx.square(gradient) + self.eps[0] if factored: - exp_avg_sq_row = state.get( - "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) - ) - exp_avg_sq_col = state.get( - "exp_avg_sq_col", - mx.zeros( - gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype - ), - ) + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( (1 - beta_2) * mx.mean(update, axis=-1) ) @@ -621,7 +671,7 @@ class Adafactor(Optimizer): update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) update = update * gradient else: - exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) + exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) state["exp_avg_sq"] = exp_avg_sq update = mx.rsqrt(exp_avg_sq) * gradient @@ -632,7 +682,7 @@ class Adafactor(Optimizer): update = learning_rate * update if use_first_moment: - exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) + exp_avg = state["exp_avg"] exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) state["exp_avg"] = exp_avg update = exp_avg diff --git a/python/src/random.cpp b/python/src/random.cpp index e9140e7d9..bbcb7a2c8 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "python/src/utils.h" @@ -13,13 +14,55 @@ using namespace py::literals; using namespace mlx::core; using namespace mlx::core::random; +class PyKeySequence { + public: + explicit PyKeySequence(uint64_t seed) { + state_.append(key(seed)); + } + + void seed(uint64_t seed) { + state_[0] = key(seed); + } + + array next() { + auto out = split(py::cast(state_[0])); + state_[0] = out.first; + return out.second; + } + + py::list state() { + return state_; + } + + void release() { + py::gil_scoped_acquire gil; + state_.release().dec_ref(); + } + + private: + py::list state_; +}; + +PyKeySequence& default_key() { + auto get_current_time_seed = []() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + }; + static PyKeySequence ks(get_current_time_seed()); + return ks; +} + void init_random(py::module_& parent_module) { auto m = parent_module.def_submodule( "random", "mlx.core.random: functionality related to random number generation"); + + m.attr("state") = default_key().state(); m.def( "seed", - &seed, + [](uint64_t seed) { default_key().seed(seed); }, "seed"_a, R"pbdoc( Seed the global PRNG. @@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return uniform( to_array(low), to_array(high), @@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) { std::optional type, float loc, float scale, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? 
key_.value() : default_key().next(); return normal(shape, type.value_or(float32), loc, scale, key, s); }, - "shape"_a = std::vector{}, "dtype"_a = std::optional{float32}, "loc"_a = 0.0, @@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& high, const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return randint( to_array(low), to_array(high), shape, type.value_or(int32), key, s); }, @@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) { "bernoulli", [](const ScalarOrArray& p_, const std::optional> shape, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto p = to_array(p_); if (shape.has_value()) { return bernoulli(p, shape.value(), key, s); @@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) { const ScalarOrArray& upper_, const std::optional> shape_, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); auto lower = to_array(lower_); auto upper = to_array(upper_); auto t = type.value_or(float32); @@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) { "gumbel", [](const std::vector& shape, std::optional type, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); return gumbel(shape, type.value_or(float32), key, s); }, "shape"_a = std::vector{}, @@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) { int axis, const std::optional> shape, const std::optional num_samples, - const std::optional& key, + const std::optional& key_, StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); if (shape.has_value() && num_samples.has_value()) { throw std::invalid_argument( "[categorical] At most one of shape or num_samples can be specified."); @@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) { Returns: array: The ``shape``-sized output array with type ``uint32``. 
)pbdoc"); + // Register static Python object cleanup before the interpreter exits + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { default_key().release(); })); } diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 78f867876..77170414a 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -135,6 +135,64 @@ py::object tree_map( }); } +void tree_visit_update( + py::object tree, + std::function visitor) { + std::function recurse; + recurse = [&](py::handle subtree) { + if (py::isinstance(subtree)) { + auto l = py::cast(subtree); + for (int i = 0; i < l.size(); ++i) { + l[i] = recurse(l[i]); + } + return py::cast(l); + } else if (py::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + return py::cast(subtree); + } else if (py::isinstance(subtree)) { + auto d = py::cast(subtree); + for (auto item : d) { + d[item.first] = recurse(item.second); + } + return py::cast(d); + } else if (py::isinstance(subtree)) { + return visitor(subtree); + } else { + return py::cast(subtree); + } + }; + recurse(tree); +} + +// Fill a pytree (recursive dict or list of dict or list) +// in place with the given arrays +// Non dict or list nodes are ignored +void tree_fill(py::object& tree, const std::vector& values) { + size_t index = 0; + tree_visit_update( + tree, [&](py::handle node) { return py::cast(values[index++]); }); +} + +// Replace all the arrays from the src values with the dst values in the tree +void tree_replace( + py::object& tree, + const std::vector& src, + const std::vector& dst) { + std::unordered_map src_to_dst; + for (int i = 0; i < src.size(); ++i) { + src_to_dst.insert({src[i].id(), dst[i]}); + } + tree_visit_update(tree, [&](py::handle node) { + auto arr = py::cast(node); + if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { + return py::cast(it->second); + } + return py::cast(arr); + }); +} + std::vector tree_flatten(py::object tree, bool strict = true) { std::vector flat_tree; @@ -495,9 +553,15 @@ std::unordered_map& tree_cache() { struct PyCompiledFun { py::function fun; size_t fun_id; + py::object captured_inputs; + py::object captured_outputs; + size_t num_outputs{0}; - PyCompiledFun(const py::function& fun) - : fun(fun), fun_id(reinterpret_cast(fun.ptr())) {} + PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs) + : fun(fun), + fun_id(reinterpret_cast(fun.ptr())), + captured_inputs(inputs), + captured_outputs(outputs) {} PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete; @@ -505,23 +569,61 @@ struct PyCompiledFun { PyCompiledFun(PyCompiledFun&& other) : fun(std::move(other.fun)), fun_id(reinterpret_cast(fun.ptr())) { other.fun_id = 0; + captured_inputs = std::move(other.captured_inputs); + captured_outputs = std::move(other.captured_outputs); + num_outputs = other.num_outputs; }; py::object operator()(const py::args& args) { auto compile_fun = [this, &args](const std::vector& a) { - // Call the python function and flatten the outputs - auto [outputs, py_outputs] = tree_flatten_with_structure( - std::move(this->fun(*tree_unflatten(args, a))), true); + // Put tracers into captured inputs + std::vector flat_in_captures; + std::vector trace_captures; + if (!py::isinstance(captured_inputs)) { + flat_in_captures = tree_flatten(captured_inputs, false); + trace_captures.insert( + trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); + tree_fill(captured_inputs, trace_captures); + 
} - tree_cache().insert({this->fun_id, py_outputs}); + auto [outputs, py_outputs] = tree_flatten_with_structure( + std::move(fun(*tree_unflatten(args, a))), false); + + tree_cache().insert({fun_id, py_outputs}); + + num_outputs = outputs.size(); + if (!py::isinstance(captured_outputs)) { + auto flat_out_captures = tree_flatten(captured_outputs, false); + outputs.insert( + outputs.end(), + std::make_move_iterator(flat_out_captures.begin()), + std::make_move_iterator(flat_out_captures.end())); + } + + // Replace tracers with originals in captured inputs + if (!py::isinstance(captured_inputs)) { + tree_replace(captured_inputs, trace_captures, flat_in_captures); + } return outputs; }; - // Inputs must be array or tree of arrays - auto inputs = tree_flatten(args, true); + auto inputs = tree_flatten(args, false); + if (!py::isinstance(captured_inputs)) { + auto flat_in_captures = tree_flatten(captured_inputs, false); + inputs.insert( + inputs.end(), + std::make_move_iterator(flat_in_captures.begin()), + std::make_move_iterator(flat_in_captures.end())); + } // Compile and call auto outputs = detail::compile(compile_fun, fun_id)(inputs); + if (!py::isinstance(captured_outputs)) { + std::vector captures( + std::make_move_iterator(outputs.begin() + num_outputs), + std::make_move_iterator(outputs.end())); + tree_fill(captured_outputs, captures); + } // Put the outputs back in the container py::object py_outputs = tree_cache().at(fun_id); @@ -534,6 +636,8 @@ struct PyCompiledFun { tree_cache().erase(fun_id); detail::compile_erase(fun_id); fun.release().dec_ref(); + captured_inputs.release().dec_ref(); + captured_outputs.release().dec_ref(); } }; @@ -601,7 +705,7 @@ void init_transforms(py::module_& m) { m.def( "eval", [](const py::args& args) { - std::vector arrays = tree_flatten(args); + std::vector arrays = tree_flatten(args, false); { py::gil_scoped_release nogil; eval(arrays); @@ -615,8 +719,8 @@ void init_transforms(py::module_& m) { Args: *args (arrays or trees of arrays): Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python - :class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be - an :class:`array`. + :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not + arrays are ignored. )pbdoc"); m.def( "jvp", @@ -859,10 +963,14 @@ void init_transforms(py::module_& m) { "file"_a); m.def( "compile", - [](const py::function& fun) { - return py::cpp_function(PyCompiledFun{fun}); + [](const py::function& fun, + const py::object& inputs, + const py::object& outputs) { + return py::cpp_function(PyCompiledFun{fun, inputs, outputs}); }, "fun"_a, + "inputs"_a = std::nullopt, + "outputs"_a = std::nullopt, R"pbdoc( compile(fun: function) -> function @@ -872,6 +980,16 @@ void init_transforms(py::module_& m) { fun (function): A function which takes a variable number of :class:`array` or trees of :class:`array` and returns a variable number of :class:`array` or trees of :class:`array`. + inputs (list or dict, optional): These inputs will be captured during + the function compilation along with the inputs to ``fun``. The ``inputs`` + can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested + lists, dictionaries, or arrays. Leaf nodes that are not + :obj:`array` are ignored. Default: ``None`` + outputs (list or dict, optional): These outputs will be captured and + updated in a compiled function. The ``outputs`` can be a + :obj:`list` or a :obj:`dict` containing arbitrarily nested lists, + dictionaries, or arrays. 
Leaf nodes that are not :obj:`array` are ignored.
+            Default: ``None``
 
         Returns:
             function: A compiled function which has the same input arguments
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index 56dff8b3d..2e0bb1d7f 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -2,6 +2,7 @@
 
 import io
 import unittest
+from functools import partial
 
 import mlx.core as mx
 import mlx_tests
@@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
         cdfdx = mx.grad(outer)(x)
         self.assertTrue(mx.allclose(dfdx, cdfdx))
 
+    def test_compile_capture(self):
+        # Test update captured state outside compiled function
+        state = {"y": mx.array(2)}
+
+        @partial(mx.compile, inputs=state)
+        def test_state(x):
+            x = x + state["y"]
+            return x
+
+        test_state(mx.array(1))
+        # Check the state is unchanged
+        self.assertEqual(state["y"], 2)
+
+        # Check the updated state is used
+        state["y"] = mx.array(3)
+        out = test_state(mx.array(1))
+        self.assertEqual(out.item(), 4)
+
+        # Capture list
+        state = [mx.array(2)]
+
+        @partial(mx.compile, inputs=state)
+        def test_state(x):
+            x = x + state[0]
+            return x
+
+        out = test_state(mx.array(1))
+        self.assertEqual(out.item(), 3)
+        state[0] = mx.array(3)
+        out = test_state(mx.array(1))
+        self.assertEqual(out.item(), 4)
+
+        # Capture tuple of list
+        state = ([mx.array(2)],)
+
+        @partial(mx.compile, inputs=state)
+        def test_state(x):
+            x = x + state[0][0]
+            return x
+
+        out = test_state(mx.array(1))
+        self.assertEqual(out.item(), 3)
+        state[0][0] = mx.array(3)
+        out = test_state(mx.array(1))
+        self.assertEqual(out.item(), 4)
+
+        # Test state updated inside compiled function
+        state = {}
+
+        @partial(mx.compile, outputs=state)
+        def test_state(x):
+            state["y"] = x + 3
+            return mx.abs(x)
+
+        test_state(mx.array(-1))
+        self.assertEqual(state["y"].item(), 2)
+
+        # Test state changed inside compiled function
+        # triggers recompile
+        state = {}
+
+        @partial(mx.compile, inputs=state, outputs=state)
+        def test_state(x):
+            y = state.get("y", mx.array(0))
+            state["y"] = x + y
+            return x + 2 * y
+
+        test_state(mx.array(1))
+        self.assertEqual(state["y"].item(), 1)
+        test_state(mx.array(1))
+        self.assertEqual(state["y"].item(), 2)
+
+    def test_compile_rng(self):
+        @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
+        def fun():
+            return mx.random.uniform(shape=(10, 10))
+
+        self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 6619afa67..dc986a19a 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
         y = dfun_dx(mx.array(1.0))
         self.assertEqual(y.item(), 6.0)
 
+    def test_eval_mixed(self):
+        x = mx.array(1) + 1 + 1
+        y = 0
+        z = "hello"
+        state = [x, y, z]
+        mx.eval(state)
+        self.assertEqual(x.item(), 3)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index d7b84bbf6..201665f7f 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
             ]
         )
 
+    def test_module_state(self):
+        m = nn.Linear(10, 1)
+        m.state["hello"] = "world"
+        self.assertEqual(m.state["hello"], "world")
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 59046184f..f894a7510 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -2,47 +2,209 @@ import inspect import unittest +from functools import partial import mlx.core as mx +import mlx.nn as nn import mlx.optimizers as opt import mlx.utils import mlx_tests +from mlx.utils import tree_flatten, tree_map def get_all_optimizers(): classes = dict() for name, obj in inspect.getmembers(opt): if inspect.isclass(obj): - if obj.__name__ not in ["OptimizerState", "Optimizer"]: + if obj.__name__ not in ["Optimizer"]: classes[name] = obj return classes +def tree_equal(fn, *args): + return all(v for _, v in tree_flatten(tree_map(fn, *args))) + + optimizers_dict = get_all_optimizers() class TestOptimizers(mlx_tests.MLXTestCase): + def test_optimizer_state(self): + optim = opt.SGD(0.1) + optim.state["hello"] = "world" + self.assertEqual(optim.state["hello"], "world") + + optim.state = {0: 1} + self.assertEqual(optim.state, {0: 1}) + def test_optimizers(self): params = { "first": [mx.zeros((10,)), mx.zeros((1,))], "second": mx.zeros((1,)), } - grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) + grads = tree_map(lambda x: mx.ones_like(x), params) for optim_class in optimizers_dict.values(): optim = optim_class(0.1) update = optim.apply_gradients(grads, params) mx.eval(update) - equal_shape = mlx.utils.tree_map( - lambda x, y: x.shape == y.shape, params, update - ) + equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update) all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) self.assertTrue(all_equal) + def test_types_conserved(self): + params = {"w": mx.ones((5, 5), mx.float16)} + grads = tree_map(lambda x: mx.ones_like(x), params) + for optim_class in optimizers_dict.values(): + optim = optim_class(0.1) + update = optim.apply_gradients(grads, params) + self.assertEqual(update["w"].dtype, mx.float16) + + def test_sgd(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state) + ) + + def test_rmsprop(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.RMSprop(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + # Implicit init + alpha = 0.99 + optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha) + optim.apply_gradients(grads, params) + self.assertTrue( + tree_equal( + lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state + ) + ) + + def test_adagrad(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Adagrad(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adadelta(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + 
# Explicit init + optim = opt.AdaDelta(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_adam(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]: + optim = optimizer(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + + def test_lion(self): + params = { + "first": [mx.zeros((10,)), mx.zeros((1,))], + "second": mx.zeros((1,)), + } + grads = tree_map(lambda x: mx.ones_like(x), params) + + # Explicit init + optim = opt.Lion(learning_rate=1e-2) + optim.init(params) + self.assertTrue( + tree_equal( + lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)), + params, + optim.state, + ) + ) + def test_adafactor(self): x = mx.zeros((5, 5)) grad = mx.ones_like(x) optimizer = opt.Adafactor() + optimizer.init(x) for _ in range(2): xp = optimizer.apply_single(grad, x, optimizer.state) self.assertEqual(xp.dtype, x.dtype) @@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase): x = mx.zeros((5, 5), mx.float16) grad = mx.ones_like(x) optimizer = opt.Adafactor() + optimizer.init(x) for _ in range(2): xp = optimizer.apply_single(grad, x, optimizer.state) self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.shape, x.shape) self.assertEqual(optimizer.state["step"], 2) + def test_compiled_optimizer(self): + model = nn.Linear(10, 10) + x = mx.random.uniform(shape=(2, 10)) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + orig_params = model.parameters() + + def loss(model, x): + return model(x).sum() + + # Uncompiled version + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + uncompiled_params = model.parameters() + + # Pure version + def loss(params, x): + model.update(params) + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + + @mx.compile + def step(params, opt_state, x): + grad = mx.grad(loss)(params, x) + optim.state = opt_state + params = optim.apply_gradients(grad, params) + return params, optim.state + + optim.init(model.parameters()) + pure_params, _ = step(model.parameters(), optim.state, x) + self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"])) + self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"])) + + # Impure version + def loss(model, x): + return model(x).sum() + + model.update(orig_params) + optim = opt.SGD(learning_rate=1e-2, momentum=0.9) + state = [model.state, optim.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x): + _, grad = nn.value_and_grad(model, loss)(model, x) + optim.update(model, grad) + + step(x) + impure_params = model.parameters() + self.assertTrue( + mx.allclose(impure_params["weight"], uncompiled_params["weight"]) + ) + self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"])) + + def test_update_lr_compiled(self): + params = {"w": mx.ones((5, 5))} + grads = tree_map(lambda x: mx.ones_like(x), params) + 
optim = opt.SGD(-1.0) + + @partial(mx.compile, inputs=optim.state) + def update(grads): + return optim.apply_gradients(grads, params) + + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0))) + optim.learning_rate = -2.0 + result = update(grads) + self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0))) + if __name__ == "__main__": unittest.main()
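
A minimal end-to-end sketch of the captured-state API introduced by this patch, assembled from the tests above (test_compiled_optimizer and test_compile_rng); the helper names loss_fn and step are illustrative, not part of the change itself:

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 10)
optimizer = optim.SGD(learning_rate=1e-2, momentum=0.9)
# Explicit init gives access to optimizer.state before the first update.
optimizer.init(model.trainable_parameters())


def loss_fn(model, x):
    return model(x).sum()


# Capture module, optimizer, and RNG state so the compiled step reads the
# latest values and writes its updates back without recompiling.
state = [model.state, optimizer.state, mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x)
    optimizer.update(model, grads)
    return loss


x = mx.random.uniform(shape=(2, 10))
loss = step(x)
mx.eval(loss, state)

Because the state trees are passed as both inputs and outputs, a change to a captured array (for example assigning optimizer.learning_rate, as in test_update_lr_compiled) is picked up on the next call, while a change to the tree's structure triggers a recompile rather than an error, matching the relaxed is_match check in mlx/compile.cpp.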