Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-02-07 17:29:22 -08:00 committed by GitHub
parent e5e816a5ef
commit 1b97b2958b
13 changed files with 723 additions and 157 deletions

View File

@ -1,19 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@ -11,6 +11,7 @@ Module
:toctree: _autosummary :toctree: _autosummary
Module.training Module.training
Module.state
.. rubric:: Methods .. rubric:: Methods

View File

@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@ -29,14 +29,16 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state. # Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state) mx.eval(model.parameters(), optimizer.state)
.. toctree::
optimizer
.. currentmodule:: mlx.optimizers .. currentmodule:: mlx.optimizers
.. autosummary:: .. autosummary::
:toctree: _autosummary :toctree: _autosummary
:template: optimizers-template.rst :template: optimizers-template.rst
OptimizerState
Optimizer
SGD SGD
RMSprop RMSprop
Adagrad Adagrad
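For context, a minimal training-step sketch using the pattern described above, evaluating both the model parameters and the optimizer state; the model, loss, and data here are illustrative assumptions, not part of the diff:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Illustrative model and data; any nn.Module and loss work the same way.
model = nn.Linear(10, 1)
optimizer = optim.SGD(learning_rate=1e-2, momentum=0.9)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(4, 10))
y = mx.random.uniform(shape=(4, 1))

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)

# Evaluate the new parameters *and* the optimizer state (e.g. momentum buffers).
mx.eval(model.parameters(), optimizer.state)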

View File

@ -191,10 +191,7 @@ struct CompilerCache {
auto is_match = [](const std::vector<array>& in1, auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) { const std::vector<array>& in2) {
if (in1.size() != in2.size()) { if (in1.size() != in2.size()) {
std::ostringstream msg; return false;
msg << "[compiler] Unexpected number of inputs to compiled function:"
<< " expected " << in2.size() << " got " << in1.size() << ".";
throw std::invalid_argument(msg.str());
} }
for (int i = 0; i < in1.size(); ++i) { for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) { if (in1[i].shape() != in2[i].shape()) {
@ -603,7 +600,7 @@ void compile_fuse(
shapes, shapes,
types, types,
std::make_shared<Compiled>( std::make_shared<Compiled>(
outputs.back().primitive().stream(), old_outputs.back().primitive().stream(),
inputs, inputs,
old_outputs, old_outputs,
std::move(fused_tape), std::move(fused_tape),
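Note on the first hunk above: a mismatch in the number of cached inputs is now treated as a cache miss (the function is simply retraced and recompiled) rather than an error, which matters once captured state can change structure between calls. A minimal sketch of the behavior this enables, mirroring the recompile test later in this diff:

from functools import partial
import mlx.core as mx

state = {}

@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    # On the first call the captured state is empty; afterwards it holds "y",
    # so the set of captured inputs changes and triggers a retrace.
    y = state.get("y", mx.array(0))
    state["y"] = x + y
    return x + 2 * y

print(step(mx.array(1)))  # state["y"] is now 1
print(step(mx.array(1)))  # recompiled; state["y"] is now 2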

View File

@ -66,6 +66,19 @@ class Module(dict):
"""Boolean indicating if the model is in training mode.""" """Boolean indicating if the model is in training mode."""
return self._training return self._training
@property
def state(self):
"""The module's state dictionary
The module's state dictionary contains any attribute set on the
module including parameters in :meth:`Module.parameters`
Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
a reference to the module's state. Updates to it will be reflected in
the original module.
"""
return self
def _extra_repr(self): def _extra_repr(self):
return "" return ""

View File

@ -7,39 +7,14 @@ import mlx.core as mx
from mlx.utils import tree_map from mlx.utils import tree_map
class OptimizerState(dict):
"""The optimizer state implements a recursively defined
:class:`collections.defaultdict`, namely a missing key in an optimizer
state is an :class:`OptimizerState`.
.. note::
:meth:`OptimizerState.get` in contrast to a normal dictionary also sets
the key to the ``default`` value if the ``key`` was not present in the
dictionary.
"""
def __getitem__(self, key):
if key not in self:
self[key] = OptimizerState()
return super().__getitem__(key)
def get(self, key, default):
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
if key not in self:
self[key] = default
return super().__getitem__(key)
class Optimizer: class Optimizer:
"""The base class for all optimizers. It allows us to implement an """The base class for all optimizers. It allows us to implement an
optimizer on a per-parameter basis and apply it to a parameter tree. optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes:
state (OptimizerState): It holds the optimizer's state dictionary.
""" """
def __init__(self): def __init__(self):
self.state = OptimizerState() self._initialized = False
self._state = {}
def update(self, model: "mlx.nn.Module", gradients: dict): def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the """Apply the gradients to the parameters of the model and update the
@ -52,7 +27,41 @@ class Optimizer:
""" """
model.update(self.apply_gradients(gradients, model)) model.update(self.apply_gradients(gradients, model))
def apply_gradients(self, gradients: dict, model: dict): def init(self, parameters: dict):
"""Initialize the optimizer's state
This function can be used to initialize optimizers which have state
(like momentum in :class:`SGD`). Using this method is optional as the
optimizer will initialize itself if the state is not yet set. However,
there are some cases where explicit initialization is useful in order
to have access to the :attr:`Optimizer.state` before the first call to
:meth:`Optimizer.update`.
Args:
parameters (dict): A Python tree of parameters.
Example:
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
>>> model = nn.Linear(2, 2)
>>> optimizer.init(model.trainable_parameters())
>>> optimizer.state
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)
self._initialized = True
def init_single(self, parameter: mx.array, state: dict):
"""To be extended by the children classes to implement each optimizer's
state initialization.
Args:
parameter (mx.array): A single parameter that will be optimized.
"""
raise NotImplementedError()
def apply_gradients(self, gradients: dict, parameters: dict):
"""Apply the gradients to the parameters and return the updated parameters. """Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via Can be used to update a model via
@ -61,19 +70,41 @@ class Optimizer:
Args: Args:
gradients (dict): A Python tree of gradients. gradients (dict): A Python tree of gradients.
model (dict): A Python tree of parameters. It can be a superset of parameters (dict): A Python tree of parameters. It can be a
the gradients. In that case the returned python tree superset of the gradients. In that case the returned python
will be of the same structure as the gradients. tree will be of the same structure as the gradients.
""" """
return tree_map(self.apply_single, gradients, model, self.state) if not self._initialized:
self.init(gradients)
return tree_map(self.apply_single, gradients, parameters, self.state)
def apply_single( def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """To be extended by derived classes to implement the optimizer's update.
):
"""To be extended by the children classes to implement each optimizer's Args:
update.""" gradient (mx.array): The ``parameter`` gradient.
parameter (mx.array): The ``parameter`` to update.
state (dict): The optimizer's state.
"""
raise NotImplementedError() raise NotImplementedError()
@property
def state(self):
"""The optimizer's state dictionary."""
return self._state
@state.setter
def state(self, state: dict):
self._state = state
@property
def learning_rate(self):
return self.state["learning_rate"]
@learning_rate.setter
def learning_rate(self, learning_rate: mx.array):
self.state["learning_rate"] = mx.array(learning_rate)
class SGD(Optimizer): class SGD(Optimizer):
r"""The stochastic gradient descent optimizer. r"""The stochastic gradient descent optimizer.
@ -113,9 +144,11 @@ class SGD(Optimizer):
self.dampening = dampening self.dampening = dampening
self.nesterov = nesterov self.nesterov = nesterov
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the SGD parameter update and stores :math:`v` in the """Performs the SGD parameter update and stores :math:`v` in the
optimizer state.""" optimizer state."""
@ -123,24 +156,21 @@ class SGD(Optimizer):
gradient += self.weight_decay * parameter gradient += self.weight_decay * parameter
if self.momentum <= 0: if self.momentum <= 0:
return parameter - self.learning_rate * gradient return parameter - self.learning_rate.astype(gradient.dtype) * gradient
v = self.momentum * state.get("v")
if self.dampening > 0: if self.dampening > 0:
v = (
state.get("v", (self.dampening / self.momentum) * gradient)
* self.momentum
)
v += (1 - self.dampening) * gradient v += (1 - self.dampening) * gradient
else: else:
v = state.get("v", mx.zeros_like(gradient)) * self.momentum
v += gradient v += gradient
if self.nesterov: if self.nesterov:
update = gradient + self.momentum * v update = gradient + self.momentum * v
else: else:
update = v update = v
state["v"] = v state["v"] = v
return parameter - self.learning_rate * update return parameter - self.learning_rate.astype(gradient.dtype) * update
class RMSprop(Optimizer): class RMSprop(Optimizer):
@ -177,15 +207,17 @@ class RMSprop(Optimizer):
f"RMSprop epsilon should be >0, {self.eps} was provided instead" f"RMSprop epsilon should be >0, {self.eps} was provided instead"
) )
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the RMSprop parameter update and stores :math:`v` in the optimizer state.""" """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
alpha = self.alpha alpha = self.alpha
eps = self.eps eps = self.eps
v = state.get("v", mx.zeros_like(gradient)) v = state["v"]
v = alpha * v + (1 - alpha) * mx.square(gradient) v = alpha * v + (1 - alpha) * mx.square(gradient)
state["v"] = v state["v"] = v
@ -222,16 +254,17 @@ class Adagrad(Optimizer):
f"Adagrad epsilon should be >0, {self.eps} was provided instead" f"Adagrad epsilon should be >0, {self.eps} was provided instead"
) )
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adagrad parameter update and stores :math:`v` in the """Performs the Adagrad parameter update and stores :math:`v` in the
optimizer state.""" optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
eps = self.eps eps = self.eps
v = state.get("v", mx.zeros_like(gradient)) v = state["v"] + mx.square(gradient)
v = v + mx.square(gradient)
state["v"] = v state["v"] = v
return parameter - lr * gradient / (mx.sqrt(v) + eps) return parameter - lr * gradient / (mx.sqrt(v) + eps)
@ -274,17 +307,20 @@ class AdaDelta(Optimizer):
f"AdaDelta epsilon should be >0, {self.eps} was provided instead" f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
) )
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["v"] = mx.zeros_like(parameter)
state["u"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdaDelta parameter update and stores :math:`v` and """Performs the AdaDelta parameter update and stores :math:`v` and
:math:`u` in the optimizer state.""" :math:`u` in the optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
rho = self.rho rho = self.rho
eps = self.eps eps = self.eps
v = state.get("v", mx.zeros_like(gradient)) v = state["v"]
u = state.get("u", mx.zeros_like(gradient)) u = state["u"]
v = rho * v + (1 - rho) * mx.square(gradient) v = rho * v + (1 - rho) * mx.square(gradient)
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
@ -329,17 +365,20 @@ class Adam(Optimizer):
self.betas = betas self.betas = betas
self.eps = eps self.eps = eps
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adam parameter update and stores :math:`v` and """Performs the Adam parameter update and stores :math:`v` and
:math:`m` in the optimizer state.""" :math:`m` in the optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas b1, b2 = self.betas
eps = self.eps eps = self.eps
m = state.get("m", gradient) m = state["m"]
v = state.get("v", mx.square(gradient)) v = state["v"]
m = b1 * m + (1 - b1) * gradient m = b1 * m + (1 - b1) * gradient
v = b2 * v + (1 - b2) * mx.square(gradient) v = b2 * v + (1 - b2) * mx.square(gradient)
state["m"] = m state["m"] = m
@ -385,15 +424,14 @@ class AdamW(Adam):
super().__init__(learning_rate=learning_rate, betas=betas, eps=eps) super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
self.weight_decay = weight_decay self.weight_decay = weight_decay
def apply_single( def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the AdamW parameter update by modifying the parameters """Performs the AdamW parameter update by modifying the parameters
passed into Adam. passed into Adam.
""" """
lr = self.learning_rate.astype(gradient.dtype)
return super().apply_single( return super().apply_single(
gradient, parameter * (1 - self.learning_rate * self.weight_decay), state gradient, parameter * (1 - lr * self.weight_decay), state
) )
@ -430,17 +468,20 @@ class Adamax(Adam):
f"Epsilon value should be >=0, {self.eps} was provided instead" f"Epsilon value should be >=0, {self.eps} was provided instead"
) )
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adamax parameter update and stores :math:`v` and """Performs the Adamax parameter update and stores :math:`v` and
:math:`m` in the optimizer state.""" :math:`m` in the optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas b1, b2 = self.betas
eps = self.eps eps = self.eps
m = state.get("m", mx.zeros_like(gradient)) m = state["m"]
v = state.get("v", mx.zeros_like(gradient)) v = state["v"]
m = b1 * m + (1 - b1) * gradient m = b1 * m + (1 - b1) * gradient
v = mx.maximum(b2 * v, mx.abs(gradient)) v = mx.maximum(b2 * v, mx.abs(gradient))
@ -489,16 +530,18 @@ class Lion(Optimizer):
self.betas = betas self.betas = betas
self.weight_decay = weight_decay self.weight_decay = weight_decay
def apply_single( def init_single(self, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState """Initialize optimizer state"""
): state["m"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Lion parameter update and stores :math:`m` """Performs the Lion parameter update and stores :math:`m`
in the optimizer state.""" in the optimizer state."""
lr = self.learning_rate lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas b1, b2 = self.betas
weight_decay = self.weight_decay weight_decay = self.weight_decay
m = state.get("m", gradient) m = state["m"]
c = b1 * m + (1 - b1) * gradient c = b1 * m + (1 - b1) * gradient
state["m"] = b2 * m + (1 - b2) * gradient state["m"] = b2 * m + (1 - b2) * gradient
if weight_decay > 0: if weight_decay > 0:
@ -552,6 +595,7 @@ class Adafactor(Optimizer):
warmup_init: bool = False, warmup_init: bool = False,
): ):
super().__init__() super().__init__()
if learning_rate is not None:
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.eps = eps self.eps = eps
self.clip_threshold = clip_threshold self.clip_threshold = clip_threshold
@ -562,14 +606,29 @@ class Adafactor(Optimizer):
self.relative_step = relative_step self.relative_step = relative_step
self.warmup_init = warmup_init self.warmup_init = warmup_init
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["step"] = 0
if parameter.ndim >= 2:
shape = parameter.shape
dtype = parameter.dtype
state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
else:
state["exp_avg_sq"] = mx.zeros_like(parameter)
if self.beta_1 is not None:
state["exp_avg"] = mx.zeros_like(parameter)
def _compute_rms(self, inputs): def _compute_rms(self, inputs):
return mx.sqrt(mx.mean(mx.square(inputs))) return mx.sqrt(mx.mean(mx.square(inputs)))
def _compute_learning_rate(self, step, parameter_rms): def _compute_learning_rate(self, step, parameter_rms):
relative_step_size = self.learning_rate
if self.relative_step: if self.relative_step:
min_step = 1e-6 * step if self.warmup_init else 1e-2 min_step = 1e-6 * step if self.warmup_init else 1e-2
relative_step_size = min(min_step, 1 / math.sqrt(step)) relative_step_size = min(min_step, 1 / math.sqrt(step))
else:
relative_step_size = self.learning_rate.astype(parameter_rms.dtype)
parameter_scale = 1.0 parameter_scale = 1.0
if self.scale_parameter: if self.scale_parameter:
@ -585,13 +644,11 @@ class Adafactor(Optimizer):
mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0) mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
) )
def apply_single( def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""Performs the Adafactor parameter and state update.""" """Performs the Adafactor parameter and state update."""
gradient_shape = gradient.shape factored = gradient.ndim >= 2
factored = len(gradient_shape) >= 2
step = state.get("step", 0) + 1 step = state["step"] + 1
state["step"] = step state["step"] = step
use_first_moment = self.beta_1 is not None use_first_moment = self.beta_1 is not None
@ -601,15 +658,8 @@ class Adafactor(Optimizer):
update = mx.square(gradient) + self.eps[0] update = mx.square(gradient) + self.eps[0]
if factored: if factored:
exp_avg_sq_row = state.get( exp_avg_sq_row = state["exp_avg_sq_row"]
"exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype) exp_avg_sq_col = state["exp_avg_sq_col"]
)
exp_avg_sq_col = state.get(
"exp_avg_sq_col",
mx.zeros(
gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
),
)
exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + ( exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
(1 - beta_2) * mx.mean(update, axis=-1) (1 - beta_2) * mx.mean(update, axis=-1)
) )
@ -621,7 +671,7 @@ class Adafactor(Optimizer):
update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col) update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
update = update * gradient update = update * gradient
else: else:
exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient)) exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update) exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
state["exp_avg_sq"] = exp_avg_sq state["exp_avg_sq"] = exp_avg_sq
update = mx.rsqrt(exp_avg_sq) * gradient update = mx.rsqrt(exp_avg_sq) * gradient
@ -632,7 +682,7 @@ class Adafactor(Optimizer):
update = learning_rate * update update = learning_rate * update
if use_first_moment: if use_first_moment:
exp_avg = state.get("exp_avg", mx.zeros_like(gradient)) exp_avg = state["exp_avg"]
exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update) exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
state["exp_avg"] = exp_avg state["exp_avg"] = exp_avg
update = exp_avg update = exp_avg
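To summarize the optimizer protocol introduced above (state is a plain dict, initialized lazily or explicitly via init/init_single, per-parameter updates in apply_single, and the learning rate kept in the state as an mx.array), here is a hedged sketch of a custom optimizer; the class name and the decay hyper-parameter are illustrative, not part of the library:

import mlx.core as mx
from mlx.optimizers import Optimizer

class DecayedSGD(Optimizer):
    """Illustrative optimizer: plain SGD with optional weight decay."""

    def __init__(self, learning_rate: float, decay: float = 0.0):
        super().__init__()
        # The learning_rate setter stores an mx.array in self.state.
        self.learning_rate = learning_rate
        self.decay = decay

    def init_single(self, parameter: mx.array, state: dict):
        # Called once per parameter, lazily on the first update or
        # eagerly through optimizer.init(params).
        state["step"] = mx.array(0)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        state["step"] = state["step"] + 1
        lr = self.learning_rate.astype(gradient.dtype)
        if self.decay > 0:
            gradient = gradient + self.decay * parameter
        return parameter - lr * gradient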

View File

@ -2,6 +2,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h" #include "python/src/utils.h"
@ -13,13 +14,55 @@ using namespace py::literals;
using namespace mlx::core; using namespace mlx::core;
using namespace mlx::core::random; using namespace mlx::core::random;
class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
}
void seed(uint64_t seed) {
state_[0] = key(seed);
}
array next() {
auto out = split(py::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
};
PyKeySequence& default_key() {
auto get_current_time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
};
static PyKeySequence ks(get_current_time_seed());
return ks;
}
void init_random(py::module_& parent_module) { void init_random(py::module_& parent_module) {
auto m = parent_module.def_submodule( auto m = parent_module.def_submodule(
"random", "random",
"mlx.core.random: functionality related to random number generation"); "mlx.core.random: functionality related to random number generation");
m.attr("state") = default_key().state();
m.def( m.def(
"seed", "seed",
&seed, [](uint64_t seed) { default_key().seed(seed); },
"seed"_a, "seed"_a,
R"pbdoc( R"pbdoc(
Seed the global PRNG. Seed the global PRNG.
@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return uniform( return uniform(
to_array(low), to_array(low),
to_array(high), to_array(high),
@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
std::optional<Dtype> type, std::optional<Dtype> type,
float loc, float loc,
float scale, float scale,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s); return normal(shape, type.value_or(float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32}, "dtype"_a = std::optional{float32},
"loc"_a = 0.0, "loc"_a = 0.0,
@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint( return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s); to_array(low), to_array(high), shape, type.value_or(int32), key, s);
}, },
@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
"bernoulli", "bernoulli",
[](const ScalarOrArray& p_, [](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape, const std::optional<std::vector<int>> shape,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_); auto p = to_array(p_);
if (shape.has_value()) { if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s); return bernoulli(p, shape.value(), key, s);
@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& upper_, const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_, const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_); auto lower = to_array(lower_);
auto upper = to_array(upper_); auto upper = to_array(upper_);
auto t = type.value_or(float32); auto t = type.value_or(float32);
@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
"gumbel", "gumbel",
[](const std::vector<int>& shape, [](const std::vector<int>& shape,
std::optional<Dtype> type, std::optional<Dtype> type,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s); return gumbel(shape, type.value_or(float32), key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = std::vector<int>{},
@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
int axis, int axis,
const std::optional<std::vector<int>> shape, const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples, const std::optional<int> num_samples,
const std::optional<array>& key, const std::optional<array>& key_,
StreamOrDevice s) { StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) { if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument( throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified."); "[categorical] At most one of shape or num_samples can be specified.");
@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
Returns: Returns:
array: The ``shape``-sized output array with type ``uint32``. array: The ``shape``-sized output array with type ``uint32``.
)pbdoc"); )pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
} }
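The C++ above gives the Python random module a global, splittable key sequence exposed as mx.random.state: seed() resets the key, and every call without an explicit key splits it. A small sketch of the resulting behavior:

import mlx.core as mx

# Re-seeding resets mx.random.state, so the same sequence of keyless
# draws is reproduced.
mx.random.seed(0)
a = mx.random.uniform(shape=(2,))
mx.random.seed(0)
b = mx.random.uniform(shape=(2,))
assert mx.array_equal(a, b)

# Passing an explicit key bypasses, and leaves untouched, the global state.
k = mx.random.key(3)
c = mx.random.uniform(shape=(2,), key=k)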

View File

@ -135,6 +135,64 @@ py::object tree_map(
}); });
} }
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) { std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree; std::vector<array> flat_tree;
@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
struct PyCompiledFun { struct PyCompiledFun {
py::function fun; py::function fun;
size_t fun_id; size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
size_t num_outputs{0};
PyCompiledFun(const py::function& fun) PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {} : fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs) {}
PyCompiledFun(const PyCompiledFun&) = delete; PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete; PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@ -505,23 +569,61 @@ struct PyCompiledFun {
PyCompiledFun(PyCompiledFun&& other) PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) { : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
other.fun_id = 0; other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
num_outputs = other.num_outputs;
}; };
py::object operator()(const py::args& args) { py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) { auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function and flatten the outputs // Put tracers into captured inputs
auto [outputs, py_outputs] = tree_flatten_with_structure( std::vector<array> flat_in_captures;
std::move(this->fun(*tree_unflatten(args, a))), true); std::vector<array> trace_captures;
if (!py::isinstance<py::none>(captured_inputs)) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
tree_fill(captured_inputs, trace_captures);
}
tree_cache().insert({this->fun_id, py_outputs}); auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(fun(*tree_unflatten(args, a))), false);
tree_cache().insert({fun_id, py_outputs});
num_outputs = outputs.size();
if (!py::isinstance<py::none>(captured_outputs)) {
auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert(
outputs.end(),
std::make_move_iterator(flat_out_captures.begin()),
std::make_move_iterator(flat_out_captures.end()));
}
// Replace tracers with originals in captured inputs
if (!py::isinstance<py::none>(captured_inputs)) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs; return outputs;
}; };
// Inputs must be array or tree of arrays auto inputs = tree_flatten(args, false);
auto inputs = tree_flatten(args, true); if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_in_captures.begin()),
std::make_move_iterator(flat_in_captures.end()));
}
// Compile and call // Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs); auto outputs = detail::compile(compile_fun, fun_id)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures);
}
// Put the outputs back in the container // Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id); py::object py_outputs = tree_cache().at(fun_id);
@ -534,6 +636,8 @@ struct PyCompiledFun {
tree_cache().erase(fun_id); tree_cache().erase(fun_id);
detail::compile_erase(fun_id); detail::compile_erase(fun_id);
fun.release().dec_ref(); fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
} }
}; };
@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
m.def( m.def(
"eval", "eval",
[](const py::args& args) { [](const py::args& args) {
std::vector<array> arrays = tree_flatten(args); std::vector<array> arrays = tree_flatten(args, false);
{ {
py::gil_scoped_release nogil; py::gil_scoped_release nogil;
eval(arrays); eval(arrays);
@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
Args: Args:
*args (arrays or trees of arrays): Each argument can be a single array *args (arrays or trees of arrays): Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
an :class:`array`. arrays are ignored.
)pbdoc"); )pbdoc");
m.def( m.def(
"jvp", "jvp",
@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
"file"_a); "file"_a);
m.def( m.def(
"compile", "compile",
[](const py::function& fun) { [](const py::function& fun,
return py::cpp_function(PyCompiledFun{fun}); const py::object& inputs,
const py::object& outputs) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
}, },
"fun"_a, "fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
R"pbdoc( R"pbdoc(
compile(fun: function) -> function compile(fun: function) -> function
@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
fun (function): A function which takes a variable number of fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns :class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`. a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
the function compilation along with the inputs to ``fun``. The ``inputs``
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
lists, dictionaries, or arrays. Leaf nodes that are not
:obj:`array` are ignored. Default: ``None``
outputs (list or dict, optional): These outputs will be captured and
updated in a compiled function. The ``outputs`` can be a
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None``
Returns: Returns:
function: A compiled function which has the same input arguments function: A compiled function which has the same input arguments
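A minimal sketch of the inputs/outputs capture described in the docstring above (captured arrays are flattened and appended to the compiled graph's inputs and outputs rather than baked in as constants); it mirrors the tests that follow:

from functools import partial
import mlx.core as mx

state = {"y": mx.array(2)}

# "inputs" marks state read inside fun; "outputs" would mark state that
# fun writes back.
@partial(mx.compile, inputs=state)
def fun(x):
    return x + state["y"]

print(fun(mx.array(1)))  # 3
state["y"] = mx.array(3)
print(fun(mx.array(1)))  # 4 -- the updated captured value is used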

View File

@ -2,6 +2,7 @@
import io import io
import unittest import unittest
from functools import partial
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
cdfdx = mx.grad(outer)(x) cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx)) self.assertTrue(mx.allclose(dfdx, cdfdx))
def test_compile_capture(self):
# Test update captured state outside compiled function
state = {"y": mx.array(2)}
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state["y"]
return x
test_state(mx.array(1))
# Check the state is unchanged
self.assertEqual(state["y"], 2)
# Check the updated state is used
state["y"] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture list
state = [mx.array(2)]
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture tuple of list
state = ([mx.array(2)],)
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0][0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0][0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Test state updated inside compiled function
state = {}
@partial(mx.compile, outputs=state)
def test_state(x):
state["y"] = x + 3
return mx.abs(x)
test_state(mx.array(-1))
self.assertEqual(state["y"].item(), 2)
# Test state changed inside compiled function
# triggers recompile
state = {}
@partial(mx.compile, inputs=state, outputs=state)
def test_state(x):
y = state.get("y", mx.array(0))
state["y"] = x + y
return x + 2 * y
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 1)
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 2)
def test_compile_rng(self):
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def fun():
return mx.random.uniform(shape=(10, 10))
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
y = dfun_dx(mx.array(1.0)) y = dfun_dx(mx.array(1.0))
self.assertEqual(y.item(), 6.0) self.assertEqual(y.item(), 6.0)
def test_eval_mixed(self):
x = mx.array(1) + 1 + 1
y = 0
z = "hello"
state = [x, y, z]
mx.eval(state)
self.assertEqual(x.item(), 3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
] ]
) )
def test_module_state(self):
m = nn.Linear(10, 1)
m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world")
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):

View File

@ -2,47 +2,209 @@
import inspect import inspect
import unittest import unittest
from functools import partial
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt import mlx.optimizers as opt
import mlx.utils import mlx.utils
import mlx_tests import mlx_tests
from mlx.utils import tree_flatten, tree_map
def get_all_optimizers(): def get_all_optimizers():
classes = dict() classes = dict()
for name, obj in inspect.getmembers(opt): for name, obj in inspect.getmembers(opt):
if inspect.isclass(obj): if inspect.isclass(obj):
if obj.__name__ not in ["OptimizerState", "Optimizer"]: if obj.__name__ not in ["Optimizer"]:
classes[name] = obj classes[name] = obj
return classes return classes
def tree_equal(fn, *args):
return all(v for _, v in tree_flatten(tree_map(fn, *args)))
optimizers_dict = get_all_optimizers() optimizers_dict = get_all_optimizers()
class TestOptimizers(mlx_tests.MLXTestCase): class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizer_state(self):
optim = opt.SGD(0.1)
optim.state["hello"] = "world"
self.assertEqual(optim.state["hello"], "world")
optim.state = {0: 1}
self.assertEqual(optim.state, {0: 1})
def test_optimizers(self): def test_optimizers(self):
params = { params = {
"first": [mx.zeros((10,)), mx.zeros((1,))], "first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)), "second": mx.zeros((1,)),
} }
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values(): for optim_class in optimizers_dict.values():
optim = optim_class(0.1) optim = optim_class(0.1)
update = optim.apply_gradients(grads, params) update = optim.apply_gradients(grads, params)
mx.eval(update) mx.eval(update)
equal_shape = mlx.utils.tree_map( equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
lambda x, y: x.shape == y.shape, params, update
)
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
self.assertTrue(all_equal) self.assertTrue(all_equal)
def test_types_conserved(self):
params = {"w": mx.ones((5, 5), mx.float16)}
grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
self.assertEqual(update["w"].dtype, mx.float16)
def test_sgd(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
)
def test_rmsprop(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.RMSprop(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
alpha = 0.99
optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(
lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
)
)
def test_adagrad(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Adagrad(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adadelta(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.AdaDelta(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adam(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
optim = optimizer(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_lion(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Lion(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adafactor(self): def test_adafactor(self):
x = mx.zeros((5, 5)) x = mx.zeros((5, 5))
grad = mx.ones_like(x) grad = mx.ones_like(x)
optimizer = opt.Adafactor() optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2): for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state) xp = optimizer.apply_single(grad, x, optimizer.state)
self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.dtype, x.dtype)
@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase):
x = mx.zeros((5, 5), mx.float16) x = mx.zeros((5, 5), mx.float16)
grad = mx.ones_like(x) grad = mx.ones_like(x)
optimizer = opt.Adafactor() optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2): for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state) xp = optimizer.apply_single(grad, x, optimizer.state)
self.assertEqual(xp.dtype, x.dtype) self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape) self.assertEqual(xp.shape, x.shape)
self.assertEqual(optimizer.state["step"], 2) self.assertEqual(optimizer.state["step"], 2)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
orig_params = model.parameters()
def loss(model, x):
return model(x).sum()
# Uncompiled version
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
uncompiled_params = model.parameters()
# Pure version
def loss(params, x):
model.update(params)
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
@mx.compile
def step(params, opt_state, x):
grad = mx.grad(loss)(params, x)
optim.state = opt_state
params = optim.apply_gradients(grad, params)
return params, optim.state
optim.init(model.parameters())
pure_params, _ = step(model.parameters(), optim.state, x)
self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))
# Impure version
def loss(model, x):
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
state = [model.state, optim.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
impure_params = model.parameters()
self.assertTrue(
mx.allclose(impure_params["weight"], uncompiled_params["weight"])
)
self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))
def test_update_lr_compiled(self):
params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)
optim = opt.SGD(-1.0)
@partial(mx.compile, inputs=optim.state)
def update(grads):
return optim.apply_gradients(grads, params)
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
optim.learning_rate = -2.0
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()