Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
parent e5e816a5ef
commit 1b97b2958b
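For orientation: the diff below adds input/output state capture to mx.compile and moves all optimizer state (including the learning rate) into Optimizer.state. A minimal sketch of the resulting usage, closely mirroring the new tests in python/tests/test_optimizers.py (the loss_fn/step names are illustrative, not part of the API):

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 10)
opt = optim.SGD(learning_rate=1e-2, momentum=0.9)

def loss_fn(model, x):
    return model(x).sum()

# Capture the mutable state so the compiled step can read and update it
# in place, as exercised by the tests added in this commit.
state = [model.state, opt.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    _, grads = nn.value_and_grad(model, loss_fn)(model, x)
    opt.update(model, grads)

step(mx.random.uniform(shape=(2, 10)))
mx.eval(model.parameters(), opt.state)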
@@ -1,19 +0,0 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}

{#{% block methods %}

{% if methods %}
.. rubric:: {{ _('Methods') }}

.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}
@@ -11,6 +11,7 @@ Module
:toctree: _autosummary

Module.training
Module.state

.. rubric:: Methods
docs/src/python/optimizer.rst (new file, 23 lines)
@@ -0,0 +1,23 @@
Optimizer
=========

.. currentmodule:: mlx.optimizers

.. autoclass:: Optimizer


.. rubric:: Attributes

.. autosummary::
:toctree: _autosummary

Optimizer.state

.. rubric:: Methods

.. autosummary::
:toctree: _autosummary

Optimizer.apply_gradients
Optimizer.init
Optimizer.update
@@ -29,14 +29,16 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)

.. toctree::

optimizer

.. currentmodule:: mlx.optimizers

.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst

OptimizerState
Optimizer
SGD
RMSprop
Adagrad
@@ -191,10 +191,7 @@ struct CompilerCache {
  auto is_match = [](const std::vector<array>& in1,
                     const std::vector<array>& in2) {
    if (in1.size() != in2.size()) {
      std::ostringstream msg;
      msg << "[compiler] Unexpected number of inputs to compiled function:"
          << " expected " << in2.size() << " got " << in1.size() << ".";
      throw std::invalid_argument(msg.str());
      return false;
    }
    for (int i = 0; i < in1.size(); ++i) {
      if (in1[i].shape() != in2[i].shape()) {
@@ -603,7 +600,7 @@ void compile_fuse(
      shapes,
      types,
      std::make_shared<Compiled>(
          outputs.back().primitive().stream(),
          old_outputs.back().primitive().stream(),
          inputs,
          old_outputs,
          std::move(fused_tape),
@@ -66,6 +66,19 @@ class Module(dict):
        """Boolean indicating if the model is in training mode."""
        return self._training

    @property
    def state(self):
        """The module's state dictionary

        The module's state dictionary contains any attribute set on the
        module including parameters in :meth:`Module.parameters`

        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
        a reference to the module's state. Updates to it will be reflected in
        the original module.
        """
        return self

    def _extra_repr(self):
        return ""
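The new Module.state property is a live reference to the module, so writes through it are reflected on the module itself; a tiny illustration matching the test added in python/tests/test_nn.py further down:

import mlx.nn as nn

m = nn.Linear(10, 1)

# Module.state returns the module itself, so arbitrary attributes and
# parameters are both visible and writable through it.
m.state["hello"] = "world"
assert m.state["hello"] == "world"
print(m.state["weight"])  # parameters appear in the state as well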
@ -7,39 +7,14 @@ import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
|
||||
|
||||
class OptimizerState(dict):
|
||||
"""The optimizer state implements a recursively defined
|
||||
:class:`collections.defaultdict`, namely a missing key in an optimizer
|
||||
state is an :class:`OptimizerState`.
|
||||
|
||||
.. note::
|
||||
:meth:`OptimizerState.get` in contrast to a normal dictionary also sets
|
||||
the key to the ``default`` value if the ``key`` was not present in the
|
||||
dictionary.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self:
|
||||
self[key] = OptimizerState()
|
||||
return super().__getitem__(key)
|
||||
|
||||
def get(self, key, default):
|
||||
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
|
||||
if key not in self:
|
||||
self[key] = default
|
||||
return super().__getitem__(key)
|
||||
|
||||
|
||||
class Optimizer:
|
||||
"""The base class for all optimizers. It allows us to implement an
|
||||
optimizer on a per-parameter basis and apply it to a parameter tree.
|
||||
|
||||
Attributes:
|
||||
state (OptimizerState): It holds the optimizer's state dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.state = OptimizerState()
|
||||
self._initialized = False
|
||||
self._state = {}
|
||||
|
||||
def update(self, model: "mlx.nn.Module", gradients: dict):
|
||||
"""Apply the gradients to the parameters of the model and update the
|
||||
@ -52,7 +27,41 @@ class Optimizer:
|
||||
"""
|
||||
model.update(self.apply_gradients(gradients, model))
|
||||
|
||||
def apply_gradients(self, gradients: dict, model: dict):
|
||||
def init(self, parameters: dict):
|
||||
"""Initialize the optimizer's state
|
||||
|
||||
This function can be used to initialize optimizers which have state
|
||||
(like momentum in :class:`SGD`). Using this method is optional as the
|
||||
optimizer will initialize itself if the state is not yet set. However,
|
||||
there are some cases where explicit initialization is useful in order
|
||||
to have access to the :attr:`Optimizer.state` before the first call to
|
||||
:meth:`Optimizer.update`.
|
||||
|
||||
Args:
|
||||
parameters (dict): A Python tree of parameters.
|
||||
|
||||
Example:
|
||||
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
|
||||
>>> model = nn.Linear(2, 2)
|
||||
>>> optimizer.init(model.trainable_parameters())
|
||||
>>> optimizer.state
|
||||
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
|
||||
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
|
||||
"""
|
||||
self._state.update(tree_map(lambda x: {}, parameters))
|
||||
tree_map(self.init_single, parameters, self._state)
|
||||
self._initialized = True
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""To be extended by the children classes to implement each optimizer's
|
||||
state initialization.
|
||||
|
||||
Args:
|
||||
parameter (mx.array): A single parameter that will be optimized.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_gradients(self, gradients: dict, parameters: dict):
|
||||
"""Apply the gradients to the parameters and return the updated parameters.
|
||||
|
||||
Can be used to update a model via
|
||||
@ -61,19 +70,41 @@ class Optimizer:
|
||||
|
||||
Args:
|
||||
gradients (dict): A Python tree of gradients.
|
||||
model (dict): A Python tree of parameters. It can be a superset of
|
||||
the gradients. In that case the returned python tree
|
||||
will be of the same structure as the gradients.
|
||||
parameters (dict): A Python tree of parameters. It can be a
|
||||
superset of the gradients. In that case the returned python
|
||||
tree will be of the same structure as the gradients.
|
||||
"""
|
||||
return tree_map(self.apply_single, gradients, model, self.state)
|
||||
if not self._initialized:
|
||||
self.init(gradients)
|
||||
return tree_map(self.apply_single, gradients, parameters, self.state)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
"""To be extended by the children classes to implement each optimizer's
|
||||
update."""
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""To be extended by derived classes to implement the optimizer's update.
|
||||
|
||||
Args:
|
||||
gradient (mx.array): The ``parameter`` gradient.
|
||||
parameter (mx.array): The ``parameter`` to update.
|
||||
state (dict): The optimizer's state.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""The optimizer's state dictionary."""
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, state: dict):
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def learning_rate(self):
|
||||
return self.state["learning_rate"]
|
||||
|
||||
@learning_rate.setter
|
||||
def learning_rate(self, learning_rate: mx.array):
|
||||
self.state["learning_rate"] = mx.array(learning_rate)
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
r"""The stochastic gradient descent optimizer.
|
||||
@ -113,9 +144,11 @@ class SGD(Optimizer):
|
||||
self.dampening = dampening
|
||||
self.nesterov = nesterov
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the SGD parameter update and stores :math:`v` in the
|
||||
optimizer state."""
|
||||
|
||||
@ -123,24 +156,21 @@ class SGD(Optimizer):
|
||||
gradient += self.weight_decay * parameter
|
||||
|
||||
if self.momentum <= 0:
|
||||
return parameter - self.learning_rate * gradient
|
||||
return parameter - self.learning_rate.astype(gradient.dtype) * gradient
|
||||
|
||||
v = self.momentum * state.get("v")
|
||||
if self.dampening > 0:
|
||||
v = (
|
||||
state.get("v", (self.dampening / self.momentum) * gradient)
|
||||
* self.momentum
|
||||
)
|
||||
v += (1 - self.dampening) * gradient
|
||||
else:
|
||||
v = state.get("v", mx.zeros_like(gradient)) * self.momentum
|
||||
v += gradient
|
||||
|
||||
if self.nesterov:
|
||||
update = gradient + self.momentum * v
|
||||
else:
|
||||
update = v
|
||||
|
||||
state["v"] = v
|
||||
return parameter - self.learning_rate * update
|
||||
return parameter - self.learning_rate.astype(gradient.dtype) * update
|
||||
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
@ -177,15 +207,17 @@ class RMSprop(Optimizer):
|
||||
f"RMSprop epsilon should be >0, {self.eps} was provided instead"
|
||||
)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
alpha = self.alpha
|
||||
eps = self.eps
|
||||
|
||||
v = state.get("v", mx.zeros_like(gradient))
|
||||
v = state["v"]
|
||||
v = alpha * v + (1 - alpha) * mx.square(gradient)
|
||||
state["v"] = v
|
||||
|
||||
@ -222,16 +254,17 @@ class Adagrad(Optimizer):
|
||||
f"Adagrad epsilon should be >0, {self.eps} was provided instead"
|
||||
)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Adagrad parameter update and stores :math:`v` in the
|
||||
optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
eps = self.eps
|
||||
|
||||
v = state.get("v", mx.zeros_like(gradient))
|
||||
v = v + mx.square(gradient)
|
||||
v = state["v"] + mx.square(gradient)
|
||||
state["v"] = v
|
||||
|
||||
return parameter - lr * gradient / (mx.sqrt(v) + eps)
|
||||
@ -274,17 +307,20 @@ class AdaDelta(Optimizer):
|
||||
f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
|
||||
)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
state["u"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the AdaDelta parameter update and stores :math:`v` and
|
||||
:math:`u` in the optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
rho = self.rho
|
||||
eps = self.eps
|
||||
|
||||
v = state.get("v", mx.zeros_like(gradient))
|
||||
u = state.get("u", mx.zeros_like(gradient))
|
||||
v = state["v"]
|
||||
u = state["u"]
|
||||
|
||||
v = rho * v + (1 - rho) * mx.square(gradient)
|
||||
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
|
||||
@ -329,17 +365,20 @@ class Adam(Optimizer):
|
||||
self.betas = betas
|
||||
self.eps = eps
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["m"] = mx.zeros_like(parameter)
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Adam parameter update and stores :math:`v` and
|
||||
:math:`m` in the optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
b1, b2 = self.betas
|
||||
eps = self.eps
|
||||
|
||||
m = state.get("m", gradient)
|
||||
v = state.get("v", mx.square(gradient))
|
||||
m = state["m"]
|
||||
v = state["v"]
|
||||
m = b1 * m + (1 - b1) * gradient
|
||||
v = b2 * v + (1 - b2) * mx.square(gradient)
|
||||
state["m"] = m
|
||||
@ -385,15 +424,14 @@ class AdamW(Adam):
|
||||
super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the AdamW parameter update by modifying the parameters
|
||||
passed into Adam.
|
||||
"""
|
||||
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
return super().apply_single(
|
||||
gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
|
||||
gradient, parameter * (1 - lr * self.weight_decay), state
|
||||
)
|
||||
|
||||
|
||||
@ -430,17 +468,20 @@ class Adamax(Adam):
|
||||
f"Epsilon value should be >=0, {self.eps} was provided instead"
|
||||
)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["m"] = mx.zeros_like(parameter)
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Adamax parameter update and stores :math:`v` and
|
||||
:math:`m` in the optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
b1, b2 = self.betas
|
||||
eps = self.eps
|
||||
|
||||
m = state.get("m", mx.zeros_like(gradient))
|
||||
v = state.get("v", mx.zeros_like(gradient))
|
||||
m = state["m"]
|
||||
v = state["v"]
|
||||
|
||||
m = b1 * m + (1 - b1) * gradient
|
||||
v = mx.maximum(b2 * v, mx.abs(gradient))
|
||||
@ -489,16 +530,18 @@ class Lion(Optimizer):
|
||||
self.betas = betas
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["m"] = mx.zeros_like(parameter)
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Lion parameter update and stores :math:`m`
|
||||
in the optimizer state."""
|
||||
lr = self.learning_rate
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
b1, b2 = self.betas
|
||||
weight_decay = self.weight_decay
|
||||
|
||||
m = state.get("m", gradient)
|
||||
m = state["m"]
|
||||
c = b1 * m + (1 - b1) * gradient
|
||||
state["m"] = b2 * m + (1 - b2) * gradient
|
||||
if weight_decay > 0:
|
||||
@ -552,7 +595,8 @@ class Adafactor(Optimizer):
|
||||
warmup_init: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learning_rate = learning_rate
|
||||
if learning_rate is not None:
|
||||
self.learning_rate = learning_rate
|
||||
self.eps = eps
|
||||
self.clip_threshold = clip_threshold
|
||||
self.decay_rate = decay_rate
|
||||
@ -562,14 +606,29 @@ class Adafactor(Optimizer):
|
||||
self.relative_step = relative_step
|
||||
self.warmup_init = warmup_init
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["step"] = 0
|
||||
if parameter.ndim >= 2:
|
||||
shape = parameter.shape
|
||||
dtype = parameter.dtype
|
||||
state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
|
||||
state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
|
||||
else:
|
||||
state["exp_avg_sq"] = mx.zeros_like(parameter)
|
||||
|
||||
if self.beta_1 is not None:
|
||||
state["exp_avg"] = mx.zeros_like(parameter)
|
||||
|
||||
def _compute_rms(self, inputs):
|
||||
return mx.sqrt(mx.mean(mx.square(inputs)))
|
||||
|
||||
def _compute_learning_rate(self, step, parameter_rms):
|
||||
relative_step_size = self.learning_rate
|
||||
if self.relative_step:
|
||||
min_step = 1e-6 * step if self.warmup_init else 1e-2
|
||||
relative_step_size = min(min_step, 1 / math.sqrt(step))
|
||||
else:
|
||||
relative_step_size = self.learning_rate.astype(parameter_rms)
|
||||
|
||||
parameter_scale = 1.0
|
||||
if self.scale_parameter:
|
||||
@ -585,13 +644,11 @@ class Adafactor(Optimizer):
|
||||
mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
|
||||
)
|
||||
|
||||
def apply_single(
|
||||
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
|
||||
):
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Adafactor parameter and state update."""
|
||||
gradient_shape = gradient.shape
|
||||
factored = len(gradient_shape) >= 2
|
||||
step = state.get("step", 0) + 1
|
||||
factored = gradient.ndim >= 2
|
||||
|
||||
step = state["step"] + 1
|
||||
state["step"] = step
|
||||
use_first_moment = self.beta_1 is not None
|
||||
|
||||
@ -601,15 +658,8 @@ class Adafactor(Optimizer):
|
||||
update = mx.square(gradient) + self.eps[0]
|
||||
|
||||
if factored:
|
||||
exp_avg_sq_row = state.get(
|
||||
"exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
|
||||
)
|
||||
exp_avg_sq_col = state.get(
|
||||
"exp_avg_sq_col",
|
||||
mx.zeros(
|
||||
gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
|
||||
),
|
||||
)
|
||||
exp_avg_sq_row = state["exp_avg_sq_row"]
|
||||
exp_avg_sq_col = state["exp_avg_sq_col"]
|
||||
exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
|
||||
(1 - beta_2) * mx.mean(update, axis=-1)
|
||||
)
|
||||
@ -621,7 +671,7 @@ class Adafactor(Optimizer):
|
||||
update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update = update * gradient
|
||||
else:
|
||||
exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
|
||||
state["exp_avg_sq"] = exp_avg_sq
|
||||
update = mx.rsqrt(exp_avg_sq) * gradient
|
||||
@ -632,7 +682,7 @@ class Adafactor(Optimizer):
|
||||
update = learning_rate * update
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
|
||||
state["exp_avg"] = exp_avg
|
||||
update = exp_avg
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
@ -13,13 +14,55 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::random;
|
||||
|
||||
class PyKeySequence {
|
||||
public:
|
||||
explicit PyKeySequence(uint64_t seed) {
|
||||
state_.append(key(seed));
|
||||
}
|
||||
|
||||
void seed(uint64_t seed) {
|
||||
state_[0] = key(seed);
|
||||
}
|
||||
|
||||
array next() {
|
||||
auto out = split(py::cast<array>(state_[0]));
|
||||
state_[0] = out.first;
|
||||
return out.second;
|
||||
}
|
||||
|
||||
py::list state() {
|
||||
return state_;
|
||||
}
|
||||
|
||||
void release() {
|
||||
py::gil_scoped_acquire gil;
|
||||
state_.release().dec_ref();
|
||||
}
|
||||
|
||||
private:
|
||||
py::list state_;
|
||||
};
|
||||
|
||||
PyKeySequence& default_key() {
|
||||
auto get_current_time_seed = []() {
|
||||
auto now = std::chrono::system_clock::now();
|
||||
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch())
|
||||
.count();
|
||||
};
|
||||
static PyKeySequence ks(get_current_time_seed());
|
||||
return ks;
|
||||
}
|
||||
|
||||
void init_random(py::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"random",
|
||||
"mlx.core.random: functionality related to random number generation");
|
||||
|
||||
m.attr("state") = default_key().state();
|
||||
m.def(
|
||||
"seed",
|
||||
&seed,
|
||||
[](uint64_t seed) { default_key().seed(seed); },
|
||||
"seed"_a,
|
||||
R"pbdoc(
|
||||
Seed the global PRNG.
|
||||
@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return uniform(
|
||||
to_array(low),
|
||||
to_array(high),
|
||||
@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
|
||||
std::optional<Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return normal(shape, type.value_or(float32), loc, scale, key, s);
|
||||
},
|
||||
|
||||
"shape"_a = std::vector<int>{},
|
||||
"dtype"_a = std::optional{float32},
|
||||
"loc"_a = 0.0,
|
||||
@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return randint(
|
||||
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
|
||||
},
|
||||
@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto p = to_array(p_);
|
||||
if (shape.has_value()) {
|
||||
return bernoulli(p, shape.value(), key, s);
|
||||
@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
auto lower = to_array(lower_);
|
||||
auto upper = to_array(upper_);
|
||||
auto t = type.value_or(float32);
|
||||
@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
std::optional<Dtype> type,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return gumbel(shape, type.value_or(float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<array>& key,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (shape.has_value() && num_samples.has_value()) {
|
||||
throw std::invalid_argument(
|
||||
"[categorical] At most one of shape or num_samples can be specified.");
|
||||
@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
|
||||
Returns:
|
||||
array: The ``shape``-sized output array with type ``uint32``.
|
||||
)pbdoc");
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
|
||||
}
|
||||
|
@ -135,6 +135,64 @@ py::object tree_map(
|
||||
});
|
||||
}
|
||||
|
||||
void tree_visit_update(
|
||||
py::object tree,
|
||||
std::function<py::object(py::handle)> visitor) {
|
||||
std::function<py::object(py::handle)> recurse;
|
||||
recurse = [&](py::handle subtree) {
|
||||
if (py::isinstance<py::list>(subtree)) {
|
||||
auto l = py::cast<py::list>(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return py::cast<py::object>(l);
|
||||
} else if (py::isinstance<py::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
return py::cast<py::object>(subtree);
|
||||
} else if (py::isinstance<py::dict>(subtree)) {
|
||||
auto d = py::cast<py::dict>(subtree);
|
||||
for (auto item : d) {
|
||||
d[item.first] = recurse(item.second);
|
||||
}
|
||||
return py::cast<py::object>(d);
|
||||
} else if (py::isinstance<array>(subtree)) {
|
||||
return visitor(subtree);
|
||||
} else {
|
||||
return py::cast<py::object>(subtree);
|
||||
}
|
||||
};
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
// Fill a pytree (recursive dict or list of dict or list)
|
||||
// in place with the given arrays
|
||||
// Non dict or list nodes are ignored
|
||||
void tree_fill(py::object& tree, const std::vector<array>& values) {
|
||||
size_t index = 0;
|
||||
tree_visit_update(
|
||||
tree, [&](py::handle node) { return py::cast(values[index++]); });
|
||||
}
|
||||
|
||||
// Replace all the arrays from the src values with the dst values in the tree
|
||||
void tree_replace(
|
||||
py::object& tree,
|
||||
const std::vector<array>& src,
|
||||
const std::vector<array>& dst) {
|
||||
std::unordered_map<uintptr_t, array> src_to_dst;
|
||||
for (int i = 0; i < src.size(); ++i) {
|
||||
src_to_dst.insert({src[i].id(), dst[i]});
|
||||
}
|
||||
tree_visit_update(tree, [&](py::handle node) {
|
||||
auto arr = py::cast<array>(node);
|
||||
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
|
||||
return py::cast(it->second);
|
||||
}
|
||||
return py::cast(arr);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
|
||||
std::vector<array> flat_tree;
|
||||
|
||||
@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
size_t fun_id;
|
||||
py::object captured_inputs;
|
||||
py::object captured_outputs;
|
||||
size_t num_outputs{0};
|
||||
|
||||
PyCompiledFun(const py::function& fun)
|
||||
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
|
||||
: fun(fun),
|
||||
fun_id(reinterpret_cast<size_t>(fun.ptr())),
|
||||
captured_inputs(inputs),
|
||||
captured_outputs(outputs) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
@ -505,23 +569,61 @@ struct PyCompiledFun {
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
captured_inputs = std::move(other.captured_inputs);
|
||||
captured_outputs = std::move(other.captured_outputs);
|
||||
num_outputs = other.num_outputs;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function and flatten the outputs
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(this->fun(*tree_unflatten(args, a))), true);
|
||||
// Put tracers into captured inputs
|
||||
std::vector<array> flat_in_captures;
|
||||
std::vector<array> trace_captures;
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
trace_captures.insert(
|
||||
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
|
||||
tree_fill(captured_inputs, trace_captures);
|
||||
}
|
||||
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
auto [outputs, py_outputs] = tree_flatten_with_structure(
|
||||
std::move(fun(*tree_unflatten(args, a))), false);
|
||||
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
|
||||
num_outputs = outputs.size();
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
auto flat_out_captures = tree_flatten(captured_outputs, false);
|
||||
outputs.insert(
|
||||
outputs.end(),
|
||||
std::make_move_iterator(flat_out_captures.begin()),
|
||||
std::make_move_iterator(flat_out_captures.end()));
|
||||
}
|
||||
|
||||
// Replace tracers with originals in captured inputs
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
tree_replace(captured_inputs, trace_captures, flat_in_captures);
|
||||
}
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto inputs = tree_flatten(args, false);
|
||||
if (!py::isinstance<py::none>(captured_inputs)) {
|
||||
auto flat_in_captures = tree_flatten(captured_inputs, false);
|
||||
inputs.insert(
|
||||
inputs.end(),
|
||||
std::make_move_iterator(flat_in_captures.begin()),
|
||||
std::make_move_iterator(flat_in_captures.end()));
|
||||
}
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
if (!py::isinstance<py::none>(captured_outputs)) {
|
||||
std::vector<array> captures(
|
||||
std::make_move_iterator(outputs.begin() + num_outputs),
|
||||
std::make_move_iterator(outputs.end()));
|
||||
tree_fill(captured_outputs, captures);
|
||||
}
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
@ -534,6 +636,8 @@ struct PyCompiledFun {
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
fun.release().dec_ref();
|
||||
captured_inputs.release().dec_ref();
|
||||
captured_outputs.release().dec_ref();
|
||||
}
|
||||
};
|
||||
|
||||
@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
std::vector<array> arrays = tree_flatten(args, false);
|
||||
{
|
||||
py::gil_scoped_release nogil;
|
||||
eval(arrays);
|
||||
@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
|
||||
Args:
|
||||
*args (arrays or trees of arrays): Each argument can be a single array
|
||||
or a tree of arrays. If a tree is given the nodes can be a Python
|
||||
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
|
||||
an :class:`array`.
|
||||
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
|
||||
arrays are ignored.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"jvp",
|
||||
@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) {
|
||||
return py::cpp_function(PyCompiledFun{fun});
|
||||
[](const py::function& fun,
|
||||
const py::object& inputs,
|
||||
const py::object& outputs) {
|
||||
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
|
||||
},
|
||||
"fun"_a,
|
||||
"inputs"_a = std::nullopt,
|
||||
"outputs"_a = std::nullopt,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
inputs (list or dict, optional): These inputs will be captured during
|
||||
the function compilation along with the inputs to ``fun``. The ``inputs``
|
||||
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
|
||||
lists, dictionaries, or arrays. Leaf nodes that are not
|
||||
:obj:`array` are ignored. Default: ``None``
|
||||
outputs (list or dict, optional): These outputs will be captured and
|
||||
updated in a compiled function. The ``outputs`` can be a
|
||||
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
|
||||
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
|
||||
Default: ``None``
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
cdfdx = mx.grad(outer)(x)
|
||||
self.assertTrue(mx.allclose(dfdx, cdfdx))
|
||||
|
||||
def test_compile_capture(self):
|
||||
# Test update captured state outside compiled function
|
||||
state = {"y": mx.array(2)}
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state["y"]
|
||||
return x
|
||||
|
||||
test_state(mx.array(1))
|
||||
# Check the state is unchanged
|
||||
self.assertEqual(state["y"], 2)
|
||||
|
||||
# Check the updated state is used
|
||||
state["y"] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture list
|
||||
state = [mx.array(2)]
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Capture tuple of list
|
||||
state = ([mx.array(2)],)
|
||||
|
||||
@partial(mx.compile, inputs=state)
|
||||
def test_state(x):
|
||||
x = x + state[0][0]
|
||||
return x
|
||||
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 3)
|
||||
state[0][0] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# Test state updated inside compiled function
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, outputs=state)
|
||||
def test_state(x):
|
||||
state["y"] = x + 3
|
||||
return mx.abs(x)
|
||||
|
||||
test_state(mx.array(-1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
# Test state changed inside compiled function
|
||||
# triggers recompile
|
||||
state = {}
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def test_state(x):
|
||||
y = state.get("y", mx.array(0))
|
||||
state["y"] = x + y
|
||||
return x + 2 * y
|
||||
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 1)
|
||||
test_state(mx.array(1))
|
||||
self.assertEqual(state["y"].item(), 2)
|
||||
|
||||
def test_compile_rng(self):
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def fun():
|
||||
return mx.random.uniform(shape=(10, 10))
|
||||
|
||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
        y = dfun_dx(mx.array(1.0))
        self.assertEqual(y.item(), 6.0)

    def test_eval_mixed(self):
        x = mx.array(1) + 1 + 1
        y = 0
        z = "hello"
        state = [x, y, z]
        mx.eval(state)
        self.assertEqual(x.item(), 3)


if __name__ == "__main__":
    unittest.main()
@@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
            ]
        )

    def test_module_state(self):
        m = nn.Linear(10, 1)
        m.state["hello"] = "world"
        self.assertEqual(m.state["hello"], "world")


class TestLayers(mlx_tests.MLXTestCase):
    def test_identity(self):
@ -2,47 +2,209 @@
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
import mlx.utils
|
||||
import mlx_tests
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
|
||||
def get_all_optimizers():
|
||||
classes = dict()
|
||||
for name, obj in inspect.getmembers(opt):
|
||||
if inspect.isclass(obj):
|
||||
if obj.__name__ not in ["OptimizerState", "Optimizer"]:
|
||||
if obj.__name__ not in ["Optimizer"]:
|
||||
classes[name] = obj
|
||||
return classes
|
||||
|
||||
|
||||
def tree_equal(fn, *args):
|
||||
return all(v for _, v in tree_flatten(tree_map(fn, *args)))
|
||||
|
||||
|
||||
optimizers_dict = get_all_optimizers()
|
||||
|
||||
|
||||
class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
def test_optimizer_state(self):
|
||||
optim = opt.SGD(0.1)
|
||||
optim.state["hello"] = "world"
|
||||
self.assertEqual(optim.state["hello"], "world")
|
||||
|
||||
optim.state = {0: 1}
|
||||
self.assertEqual(optim.state, {0: 1})
|
||||
|
||||
def test_optimizers(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
for optim_class in optimizers_dict.values():
|
||||
optim = optim_class(0.1)
|
||||
update = optim.apply_gradients(grads, params)
|
||||
mx.eval(update)
|
||||
equal_shape = mlx.utils.tree_map(
|
||||
lambda x, y: x.shape == y.shape, params, update
|
||||
)
|
||||
equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
|
||||
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
|
||||
self.assertTrue(all_equal)
|
||||
|
||||
def test_types_conserved(self):
|
||||
params = {"w": mx.ones((5, 5), mx.float16)}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
for optim_class in optimizers_dict.values():
|
||||
optim = optim_class(0.1)
|
||||
update = optim.apply_gradients(grads, params)
|
||||
self.assertEqual(update["w"].dtype, mx.float16)
|
||||
|
||||
def test_sgd(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Implicit init
|
||||
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
|
||||
optim.apply_gradients(grads, params)
|
||||
self.assertTrue(
|
||||
tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
|
||||
)
|
||||
|
||||
def test_rmsprop(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.RMSprop(learning_rate=1e-2)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Implicit init
|
||||
alpha = 0.99
|
||||
optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
|
||||
optim.apply_gradients(grads, params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
|
||||
)
|
||||
)
|
||||
|
||||
def test_adagrad(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Adagrad(learning_rate=1e-2)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
def test_adadelta(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.AdaDelta(learning_rate=1e-2)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
def test_adam(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
|
||||
optim = optimizer(learning_rate=1e-2)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
def test_lion(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10,)), mx.zeros((1,))],
|
||||
"second": mx.zeros((1,)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Lion(learning_rate=1e-2)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
def test_adafactor(self):
|
||||
x = mx.zeros((5, 5))
|
||||
grad = mx.ones_like(x)
|
||||
optimizer = opt.Adafactor()
|
||||
optimizer.init(x)
|
||||
for _ in range(2):
|
||||
xp = optimizer.apply_single(grad, x, optimizer.state)
|
||||
self.assertEqual(xp.dtype, x.dtype)
|
||||
@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
x = mx.zeros((5, 5), mx.float16)
|
||||
grad = mx.ones_like(x)
|
||||
optimizer = opt.Adafactor()
|
||||
optimizer.init(x)
|
||||
for _ in range(2):
|
||||
xp = optimizer.apply_single(grad, x, optimizer.state)
|
||||
self.assertEqual(xp.dtype, x.dtype)
|
||||
self.assertEqual(xp.shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
|
||||
|
||||
orig_params = model.parameters()
|
||||
|
||||
def loss(model, x):
|
||||
return model(x).sum()
|
||||
|
||||
# Uncompiled version
|
||||
def step(x):
|
||||
_, grad = nn.value_and_grad(model, loss)(model, x)
|
||||
optim.update(model, grad)
|
||||
|
||||
step(x)
|
||||
uncompiled_params = model.parameters()
|
||||
|
||||
# Pure version
|
||||
def loss(params, x):
|
||||
model.update(params)
|
||||
return model(x).sum()
|
||||
|
||||
model.update(orig_params)
|
||||
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
|
||||
|
||||
@mx.compile
|
||||
def step(params, opt_state, x):
|
||||
grad = mx.grad(loss)(params, x)
|
||||
optim.state = opt_state
|
||||
params = optim.apply_gradients(grad, params)
|
||||
return params, optim.state
|
||||
|
||||
optim.init(model.parameters())
|
||||
pure_params, _ = step(model.parameters(), optim.state, x)
|
||||
self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
|
||||
self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))
|
||||
|
||||
# Impure version
|
||||
def loss(model, x):
|
||||
return model(x).sum()
|
||||
|
||||
model.update(orig_params)
|
||||
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
|
||||
state = [model.state, optim.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x):
|
||||
_, grad = nn.value_and_grad(model, loss)(model, x)
|
||||
optim.update(model, grad)
|
||||
|
||||
step(x)
|
||||
impure_params = model.parameters()
|
||||
self.assertTrue(
|
||||
mx.allclose(impure_params["weight"], uncompiled_params["weight"])
|
||||
)
|
||||
self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))
|
||||
|
||||
def test_update_lr_compiled(self):
|
||||
params = {"w": mx.ones((5, 5))}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
optim = opt.SGD(-1.0)
|
||||
|
||||
@partial(mx.compile, inputs=optim.state)
|
||||
def update(grads):
|
||||
return optim.apply_gradients(grads, params)
|
||||
|
||||
result = update(grads)
|
||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
|
||||
optim.learning_rate = -2.0
|
||||
result = update(grads)
|
||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
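To complement the tests above, a short sketch of the explicit-initialization path added in this commit, based on the Optimizer.init docstring example (printed values are illustrative):

import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)

# Explicitly create the per-parameter state (the momentum buffers "v")
# so that optimizer.state can be inspected before the first update;
# otherwise apply_gradients() initializes it lazily on first use.
optimizer.init(model.trainable_parameters())
print(optimizer.state["learning_rate"])  # array(0.1, dtype=float32)
print(optimizer.state["weight"]["v"])    # zeros with the weight's shape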