Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments
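
A minimal sketch of the new capture API, mirroring the tests added in this change (the function and the ``state`` dict are illustrative):

from functools import partial
import mlx.core as mx

state = {"counter": mx.array(0)}

# Capture `state` as both an input and an output of the compiled function,
# so reads see the latest values and writes are reflected outside of it.
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
    state["counter"] = state["counter"] + 1
    return x + state["counter"]

out = step(mx.array(1))   # out == 2, state["counter"] == 1
out = step(mx.array(1))   # out == 3, state["counter"] == 2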

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Awni Hannun 2024-02-07 17:29:22 -08:00 committed by GitHub
parent e5e816a5ef
commit 1b97b2958b
13 changed files with 723 additions and 157 deletions

View File

@ -1,19 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@ -11,6 +11,7 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods

View File

@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@ -29,14 +29,16 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. toctree::
optimizer
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad

View File

@ -191,10 +191,7 @@ struct CompilerCache {
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
std::ostringstream msg;
msg << "[compiler] Unexpected number of inputs to compiled function:"
<< " expected " << in2.size() << " got " << in1.size() << ".";
throw std::invalid_argument(msg.str());
return false;
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
@ -603,7 +600,7 @@ void compile_fuse(
shapes,
types,
std::make_shared<Compiled>(
outputs.back().primitive().stream(),
old_outputs.back().primitive().stream(),
inputs,
old_outputs,
std::move(fused_tape),

View File

@ -66,6 +66,19 @@ class Module(dict):
"""Boolean indicating if the model is in training mode."""
return self._training
@property
def state(self):
"""The module's state dictionary
The module's state dictionary contains any attribute set on the
module, including parameters in :meth:`Module.parameters`.
Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
a reference to the module's state. Updates to it will be reflected in
the original module.
"""
return self
def _extra_repr(self):
return ""

View File

@ -7,39 +7,14 @@ import mlx.core as mx
from mlx.utils import tree_map
class OptimizerState(dict):
"""The optimizer state implements a recursively defined
:class:`collections.defaultdict`, namely a missing key in an optimizer
state is an :class:`OptimizerState`.
.. note::
:meth:`OptimizerState.get` in contrast to a normal dictionary also sets
the key to the ``default`` value if the ``key`` was not present in the
dictionary.
"""
def __getitem__(self, key):
if key not in self:
self[key] = OptimizerState()
return super().__getitem__(key)
def get(self, key, default):
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
if key not in self:
self[key] = default
return super().__getitem__(key)
class Optimizer:
"""The base class for all optimizers. It allows us to implement an
optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes:
state (OptimizerState): It holds the optimizer's state dictionary.
"""
def __init__(self):
self.state = OptimizerState()
self._initialized = False
self._state = {}
def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the
@ -52,7 +27,41 @@ class Optimizer:
"""
model.update(self.apply_gradients(gradients, model))
def apply_gradients(self, gradients: dict, model: dict):
def init(self, parameters: dict):
"""Initialize the optimizer's state
This function can be used to initialize optimizers which have state
(like momentum in :class:`SGD`). Using this method is optional as the
optimizer will initialize itself if the state is not yet set. However,
there are some cases where explicit initialization is useful in order
to have access to the :attr:`Optimizer.state` before the first call to
:meth:`Optimizer.update`.
Args:
parameters (dict): A Python tree of parameters.
Example:
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
>>> model = nn.Linear(2, 2)
>>> optimizer.init(model.trainable_parameters())
>>> optimizer.state
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)
self._initialized = True
def init_single(self, parameter: mx.array, state: dict):
"""To be extended by the children classes to implement each optimizer's
state initialization.
Args:
parameter (mx.array): A single parameter that will be optimized.
"""
raise NotImplementedError()
def apply_gradients(self, gradients: dict, parameters: dict):
"""Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via
@ -61,19 +70,41 @@ class Optimizer:
Args:
gradients (dict): A Python tree of gradients.
model (dict): A Python tree of parameters. It can be a superset of
the gradients. In that case the returned python tree
will be of the same structure as the gradients.
parameters (dict): A Python tree of parameters. It can be a
superset of the gradients. In that case the returned python
tree will be of the same structure as the gradients.
"""
return tree_map(self.apply_single, gradients, model, self.state)
if not self._initialized:
self.init(gradients)
return tree_map(self.apply_single, gradients, parameters, self.state)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""To be extended by the children classes to implement each optimizer's
update."""
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""To be extended by derived classes to implement the optimizer's update.
Args:
gradient (mx.array): The ``parameter`` gradient.
parameter (mx.array): The ``parameter`` to update.
state (dict): The optimizer's state.
"""
raise NotImplementedError()
@property
def state(self):
"""The optimizer's state dictionary."""
return self._state
@state.setter
def state(self, state: dict):
self._state = state
@property
def learning_rate(self):
return self.state["learning_rate"]
@learning_rate.setter
def learning_rate(self, learning_rate: mx.array):
self.state["learning_rate"] = mx.array(learning_rate)
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.
@ -113,9 +144,11 @@ class SGD(Optimizer):
self.dampening = dampening
self.nesterov = nesterov
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the SGD parameter update and stores :math:`v` in the
optimizer state."""
@ -123,24 +156,21 @@ class SGD(Optimizer):
gradient += self.weight_decay * parameter
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
return parameter - self.learning_rate.astype(gradient.dtype) * gradient
v = self.momentum * state.get("v")
if self.dampening > 0:
v = (
state.get("v", (self.dampening / self.momentum) * gradient)
* self.momentum
)
v += (1 - self.dampening) * gradient
else:
v = state.get("v", mx.zeros_like(gradient)) * self.momentum
v += gradient
if self.nesterov:
update = gradient + self.momentum * v
else:
update = v
state["v"] = v
return parameter - self.learning_rate * update
return parameter - self.learning_rate.astype(gradient.dtype) * update
class RMSprop(Optimizer):
@ -177,15 +207,17 @@ class RMSprop(Optimizer):
f"RMSprop epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
alpha = self.alpha
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
v = state["v"]
v = alpha * v + (1 - alpha) * mx.square(gradient)
state["v"] = v
@ -222,16 +254,17 @@ class Adagrad(Optimizer):
f"Adagrad epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adagrad parameter update and stores :math:`v` in the
optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
v = v + mx.square(gradient)
v = state["v"] + mx.square(gradient)
state["v"] = v
return parameter - lr * gradient / (mx.sqrt(v) + eps)
@ -274,17 +307,20 @@ class AdaDelta(Optimizer):
f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
state["u"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdaDelta parameter update and stores :math:`v` and
:math:`u` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
rho = self.rho
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
u = state.get("u", mx.zeros_like(gradient))
v = state["v"]
u = state["u"]
v = rho * v + (1 - rho) * mx.square(gradient)
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
@ -329,17 +365,20 @@ class Adam(Optimizer):
self.betas = betas
self.eps = eps
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adam parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
eps = self.eps
m = state.get("m", gradient)
v = state.get("v", mx.square(gradient))
m = state["m"]
v = state["v"]
m = b1 * m + (1 - b1) * gradient
v = b2 * v + (1 - b2) * mx.square(gradient)
state["m"] = m
@ -385,15 +424,14 @@ class AdamW(Adam):
super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
self.weight_decay = weight_decay
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdamW parameter update by modifying the parameters
passed into Adam.
"""
lr = self.learning_rate.astype(gradient.dtype)
return super().apply_single(
gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
gradient, parameter * (1 - lr * self.weight_decay), state
)
@ -430,17 +468,20 @@ class Adamax(Adam):
f"Epsilon value should be >=0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adamax parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
eps = self.eps
m = state.get("m", mx.zeros_like(gradient))
v = state.get("v", mx.zeros_like(gradient))
m = state["m"]
v = state["v"]
m = b1 * m + (1 - b1) * gradient
v = mx.maximum(b2 * v, mx.abs(gradient))
@ -489,16 +530,18 @@ class Lion(Optimizer):
self.betas = betas
self.weight_decay = weight_decay
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Lion parameter update and stores :math:`m`
in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
weight_decay = self.weight_decay
m = state.get("m", gradient)
m = state["m"]
c = b1 * m + (1 - b1) * gradient
state["m"] = b2 * m + (1 - b2) * gradient
if weight_decay > 0:
@ -552,7 +595,8 @@ class Adafactor(Optimizer):
warmup_init: bool = False,
):
super().__init__()
self.learning_rate = learning_rate
if learning_rate is not None:
self.learning_rate = learning_rate
self.eps = eps
self.clip_threshold = clip_threshold
self.decay_rate = decay_rate
@ -562,14 +606,29 @@ class Adafactor(Optimizer):
self.relative_step = relative_step
self.warmup_init = warmup_init
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["step"] = 0
if parameter.ndim >= 2:
shape = parameter.shape
dtype = parameter.dtype
state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
else:
state["exp_avg_sq"] = mx.zeros_like(parameter)
if self.beta_1 is not None:
state["exp_avg"] = mx.zeros_like(parameter)
def _compute_rms(self, inputs):
return mx.sqrt(mx.mean(mx.square(inputs)))
def _compute_learning_rate(self, step, parameter_rms):
relative_step_size = self.learning_rate
if self.relative_step:
min_step = 1e-6 * step if self.warmup_init else 1e-2
relative_step_size = min(min_step, 1 / math.sqrt(step))
else:
relative_step_size = self.learning_rate.astype(parameter_rms.dtype)
parameter_scale = 1.0
if self.scale_parameter:
@ -585,13 +644,11 @@ class Adafactor(Optimizer):
mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adafactor parameter and state update."""
gradient_shape = gradient.shape
factored = len(gradient_shape) >= 2
step = state.get("step", 0) + 1
factored = gradient.ndim >= 2
step = state["step"] + 1
state["step"] = step
use_first_moment = self.beta_1 is not None
@ -601,15 +658,8 @@ class Adafactor(Optimizer):
update = mx.square(gradient) + self.eps[0]
if factored:
exp_avg_sq_row = state.get(
"exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
)
exp_avg_sq_col = state.get(
"exp_avg_sq_col",
mx.zeros(
gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
),
)
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
(1 - beta_2) * mx.mean(update, axis=-1)
)
@ -621,7 +671,7 @@ class Adafactor(Optimizer):
update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
update = update * gradient
else:
exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
state["exp_avg_sq"] = exp_avg_sq
update = mx.rsqrt(exp_avg_sq) * gradient
@ -632,7 +682,7 @@ class Adafactor(Optimizer):
update = learning_rate * update
if use_first_moment:
exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
exp_avg = state["exp_avg"]
exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
state["exp_avg"] = exp_avg
update = exp_avg
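
With per-parameter state handled through ``init_single``/``apply_single``, a custom optimizer no longer needs ``OptimizerState`` defaults. A hypothetical example (the class and update rule are illustrative, not part of this change):

import mlx.core as mx
from mlx.optimizers import Optimizer


class SignSGD(Optimizer):
    """Hypothetical optimizer that steps in the sign of the gradient."""

    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate

    def init_single(self, parameter: mx.array, state: dict):
        # This update rule keeps no per-parameter state.
        pass

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        return parameter - lr * mx.sign(gradient)

Since ``apply_gradients`` calls ``init`` automatically on the first update, ``init_single`` only needs to create whatever arrays ``apply_single`` reads from ``state``.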

View File

@ -2,6 +2,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include "python/src/utils.h"
@ -13,13 +14,55 @@ using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::random;
class PyKeySequence {
public:
explicit PyKeySequence(uint64_t seed) {
state_.append(key(seed));
}
void seed(uint64_t seed) {
state_[0] = key(seed);
}
array next() {
auto out = split(py::cast<array>(state_[0]));
state_[0] = out.first;
return out.second;
}
py::list state() {
return state_;
}
void release() {
py::gil_scoped_acquire gil;
state_.release().dec_ref();
}
private:
py::list state_;
};
PyKeySequence& default_key() {
auto get_current_time_seed = []() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
};
static PyKeySequence ks(get_current_time_seed());
return ks;
}
void init_random(py::module_& parent_module) {
auto m = parent_module.def_submodule(
"random",
"mlx.core.random: functionality related to random number generation");
m.attr("state") = default_key().state();
m.def(
"seed",
&seed,
[](uint64_t seed) { default_key().seed(seed); },
"seed"_a,
R"pbdoc(
Seed the global PRNG.
@ -62,8 +105,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return uniform(
to_array(low),
to_array(high),
@ -101,11 +145,11 @@ void init_random(py::module_& parent_module) {
std::optional<Dtype> type,
float loc,
float scale,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return normal(shape, type.value_or(float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"dtype"_a = std::optional{float32},
"loc"_a = 0.0,
@ -131,8 +175,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& high,
const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return randint(
to_array(low), to_array(high), shape, type.value_or(int32), key, s);
},
@ -163,8 +208,9 @@ void init_random(py::module_& parent_module) {
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto p = to_array(p_);
if (shape.has_value()) {
return bernoulli(p, shape.value(), key, s);
@ -199,8 +245,9 @@ void init_random(py::module_& parent_module) {
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
auto lower = to_array(lower_);
auto upper = to_array(upper_);
auto t = type.value_or(float32);
@ -239,8 +286,9 @@ void init_random(py::module_& parent_module) {
"gumbel",
[](const std::vector<int>& shape,
std::optional<Dtype> type,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return gumbel(shape, type.value_or(float32), key, s);
},
"shape"_a = std::vector<int>{},
@ -267,8 +315,9 @@ void init_random(py::module_& parent_module) {
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<int> num_samples,
const std::optional<array>& key,
const std::optional<array>& key_,
StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
if (shape.has_value() && num_samples.has_value()) {
throw std::invalid_argument(
"[categorical] At most one of shape or num_samples can be specified.");
@ -309,4 +358,7 @@ void init_random(py::module_& parent_module) {
Returns:
array: The ``shape``-sized output array with type ``uint32``.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { default_key().release(); }));
}
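
On the Python side the key sequence above backs ``mx.random.state``, so random ops advance a global key that can be seeded or captured by ``mx.compile``. A short sketch mirroring the test added below:

from functools import partial
import mlx.core as mx

mx.random.seed(0)
a = mx.random.uniform(shape=(2,))   # consumes one split of the global key
b = mx.random.uniform(shape=(2,))   # a and b are (almost surely) different

# Capturing the RNG state lets compiled functions keep advancing it
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def draw():
    return mx.random.uniform(shape=(2,))

c = draw()
d = draw()                          # distinct draws across calls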

View File

@ -135,6 +135,64 @@ py::object tree_map(
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
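
For reference, a rough Python analogue of what ``tree_fill`` does (illustrative only; the real implementation is the C++ above, and this sketch ignores tuples):

import mlx.core as mx

def tree_fill(tree, values):
    # Replace the array leaves of a nested list/dict structure in place,
    # in traversal order; non-array leaves are left untouched.
    it = iter(values)

    def recurse(node):
        if isinstance(node, list):
            for i, v in enumerate(node):
                node[i] = recurse(v)
            return node
        if isinstance(node, dict):
            for k, v in node.items():
                node[k] = recurse(v)
            return node
        if isinstance(node, mx.array):
            return next(it)
        return node

    recurse(tree)

state = {"w": mx.zeros((2,)), "layers": [mx.zeros((3,))]}
tree_fill(state, [mx.ones((2,)), mx.ones((3,))])   # leaves replaced in place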
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
@ -495,9 +553,15 @@ std::unordered_map<size_t, py::object>& tree_cache() {
struct PyCompiledFun {
py::function fun;
size_t fun_id;
py::object captured_inputs;
py::object captured_outputs;
size_t num_outputs{0};
PyCompiledFun(const py::function& fun)
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
PyCompiledFun(const py::function& fun, py::object inputs, py::object outputs)
: fun(fun),
fun_id(reinterpret_cast<size_t>(fun.ptr())),
captured_inputs(inputs),
captured_outputs(outputs) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
@ -505,23 +569,61 @@ struct PyCompiledFun {
PyCompiledFun(PyCompiledFun&& other)
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
other.fun_id = 0;
captured_inputs = std::move(other.captured_inputs);
captured_outputs = std::move(other.captured_outputs);
num_outputs = other.num_outputs;
};
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function and flatten the outputs
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(this->fun(*tree_unflatten(args, a))), true);
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
std::vector<array> trace_captures;
if (!py::isinstance<py::none>(captured_inputs)) {
flat_in_captures = tree_flatten(captured_inputs, false);
trace_captures.insert(
trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
tree_fill(captured_inputs, trace_captures);
}
tree_cache().insert({this->fun_id, py_outputs});
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(fun(*tree_unflatten(args, a))), false);
tree_cache().insert({fun_id, py_outputs});
num_outputs = outputs.size();
if (!py::isinstance<py::none>(captured_outputs)) {
auto flat_out_captures = tree_flatten(captured_outputs, false);
outputs.insert(
outputs.end(),
std::make_move_iterator(flat_out_captures.begin()),
std::make_move_iterator(flat_out_captures.end()));
}
// Replace tracers with originals in captured inputs
if (!py::isinstance<py::none>(captured_inputs)) {
tree_replace(captured_inputs, trace_captures, flat_in_captures);
}
return outputs;
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
auto inputs = tree_flatten(args, false);
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_in_captures.begin()),
std::make_move_iterator(flat_in_captures.end()));
}
// Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
if (!py::isinstance<py::none>(captured_outputs)) {
std::vector<array> captures(
std::make_move_iterator(outputs.begin() + num_outputs),
std::make_move_iterator(outputs.end()));
tree_fill(captured_outputs, captures);
}
// Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id);
@ -534,6 +636,8 @@ struct PyCompiledFun {
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
fun.release().dec_ref();
captured_inputs.release().dec_ref();
captured_outputs.release().dec_ref();
}
};
@ -601,7 +705,7 @@ void init_transforms(py::module_& m) {
m.def(
"eval",
[](const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
std::vector<array> arrays = tree_flatten(args, false);
{
py::gil_scoped_release nogil;
eval(arrays);
@ -615,8 +719,8 @@ void init_transforms(py::module_& m) {
Args:
*args (arrays or trees of arrays): Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict` but the leafs must all be
an :class:`array`.
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
arrays are ignored.
)pbdoc");
m.def(
"jvp",
@ -859,10 +963,14 @@ void init_transforms(py::module_& m) {
"file"_a);
m.def(
"compile",
[](const py::function& fun) {
return py::cpp_function(PyCompiledFun{fun});
[](const py::function& fun,
const py::object& inputs,
const py::object& outputs) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs});
},
"fun"_a,
"inputs"_a = std::nullopt,
"outputs"_a = std::nullopt,
R"pbdoc(
compile(fun: function, inputs: Optional[list | dict] = None, outputs: Optional[list | dict] = None) -> function
@ -872,6 +980,16 @@ void init_transforms(py::module_& m) {
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
the function compilation along with the inputs to ``fun``. The ``inputs``
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
lists, dictionaries, or arrays. Leaf nodes that are not
:obj:`array` are ignored. Default: ``None``
outputs (list or dict, optional): These outputs will be captured and
updated in a compiled function. The ``outputs`` can be a
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
Default: ``None``
Returns:
function: A compiled function which has the same input arguments

View File

@ -2,6 +2,7 @@
import io
import unittest
from functools import partial
import mlx.core as mx
import mlx_tests
@ -301,6 +302,85 @@ class TestCompile(mlx_tests.MLXTestCase):
cdfdx = mx.grad(outer)(x)
self.assertTrue(mx.allclose(dfdx, cdfdx))
def test_compile_capture(self):
# Test update captured state outside compiled function
state = {"y": mx.array(2)}
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state["y"]
return x
test_state(mx.array(1))
# Check the state is unchanged
self.assertEqual(state["y"], 2)
# Check the updated state is used
state["y"] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture list
state = [mx.array(2)]
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Capture tuple of list
state = ([mx.array(2)],)
@partial(mx.compile, inputs=state)
def test_state(x):
x = x + state[0][0]
return x
out = test_state(mx.array(1))
self.assertEqual(out.item(), 3)
state[0][0] = mx.array(3)
out = test_state(mx.array(1))
self.assertEqual(out.item(), 4)
# Test state updated inside compiled function
state = {}
@partial(mx.compile, outputs=state)
def test_state(x):
state["y"] = x + 3
return mx.abs(x)
test_state(mx.array(-1))
self.assertEqual(state["y"].item(), 2)
# Test state changed inside compiled function
# triggers recompile
state = {}
@partial(mx.compile, inputs=state, outputs=state)
def test_state(x):
y = state.get("y", mx.array(0))
state["y"] = x + y
return x + 2 * y
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 1)
test_state(mx.array(1))
self.assertEqual(state["y"].item(), 2)
def test_compile_rng(self):
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def fun():
return mx.random.uniform(shape=(10, 10))
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
if __name__ == "__main__":
unittest.main()

View File

@ -24,6 +24,14 @@ class TestEval(mlx_tests.MLXTestCase):
y = dfun_dx(mx.array(1.0))
self.assertEqual(y.item(), 6.0)
def test_eval_mixed(self):
x = mx.array(1) + 1 + 1
y = 0
z = "hello"
state = [x, y, z]
mx.eval(state)
self.assertEqual(x.item(), 3)
if __name__ == "__main__":
unittest.main()

View File

@ -130,6 +130,11 @@ class TestBase(mlx_tests.MLXTestCase):
]
)
def test_module_state(self):
m = nn.Linear(10, 1)
m.state["hello"] = "world"
self.assertEqual(m.state["hello"], "world")
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):

View File

@ -2,47 +2,209 @@
import inspect
import unittest
from functools import partial
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
from mlx.utils import tree_flatten, tree_map
def get_all_optimizers():
classes = dict()
for name, obj in inspect.getmembers(opt):
if inspect.isclass(obj):
if obj.__name__ not in ["OptimizerState", "Optimizer"]:
if obj.__name__ not in ["Optimizer"]:
classes[name] = obj
return classes
def tree_equal(fn, *args):
return all(v for _, v in tree_flatten(tree_map(fn, *args)))
optimizers_dict = get_all_optimizers()
class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizer_state(self):
optim = opt.SGD(0.1)
optim.state["hello"] = "world"
self.assertEqual(optim.state["hello"], "world")
optim.state = {0: 1}
self.assertEqual(optim.state, {0: 1})
def test_optimizers(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
mx.eval(update)
equal_shape = mlx.utils.tree_map(
lambda x, y: x.shape == y.shape, params, update
)
equal_shape = tree_map(lambda x, y: x.shape == y.shape, params, update)
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
self.assertTrue(all_equal)
def test_types_conserved(self):
params = {"w": mx.ones((5, 5), mx.float16)}
grads = tree_map(lambda x: mx.ones_like(x), params)
for optim_class in optimizers_dict.values():
optim = optim_class(0.1)
update = optim.apply_gradients(grads, params)
self.assertEqual(update["w"].dtype, mx.float16)
def test_sgd(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(lambda g, s: mx.array_equal(s["v"], g), grads, optim.state)
)
def test_rmsprop(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.RMSprop(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Implicit init
alpha = 0.99
optim = opt.RMSprop(learning_rate=1e-2, alpha=alpha)
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(
lambda g, s: mx.allclose(s["v"], (1 - alpha) * g), grads, optim.state
)
)
def test_adagrad(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Adagrad(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adadelta(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.AdaDelta(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["u"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adam(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
for optimizer in [opt.Adam, opt.AdamW, opt.Adamax]:
optim = optimizer(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_lion(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Lion(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)
def test_adafactor(self):
x = mx.zeros((5, 5))
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state)
self.assertEqual(xp.dtype, x.dtype)
@ -51,12 +213,86 @@ class TestOptimizers(mlx_tests.MLXTestCase):
x = mx.zeros((5, 5), mx.float16)
grad = mx.ones_like(x)
optimizer = opt.Adafactor()
optimizer.init(x)
for _ in range(2):
xp = optimizer.apply_single(grad, x, optimizer.state)
self.assertEqual(xp.dtype, x.dtype)
self.assertEqual(xp.shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
orig_params = model.parameters()
def loss(model, x):
return model(x).sum()
# Uncompiled version
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
uncompiled_params = model.parameters()
# Pure version
def loss(params, x):
model.update(params)
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
@mx.compile
def step(params, opt_state, x):
grad = mx.grad(loss)(params, x)
optim.state = opt_state
params = optim.apply_gradients(grad, params)
return params, optim.state
optim.init(model.parameters())
pure_params, _ = step(model.parameters(), optim.state, x)
self.assertTrue(mx.allclose(pure_params["weight"], uncompiled_params["weight"]))
self.assertTrue(mx.allclose(pure_params["bias"], uncompiled_params["bias"]))
# Impure version
def loss(model, x):
return model(x).sum()
model.update(orig_params)
optim = opt.SGD(learning_rate=1e-2, momentum=0.9)
state = [model.state, optim.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
_, grad = nn.value_and_grad(model, loss)(model, x)
optim.update(model, grad)
step(x)
impure_params = model.parameters()
self.assertTrue(
mx.allclose(impure_params["weight"], uncompiled_params["weight"])
)
self.assertTrue(mx.allclose(impure_params["bias"], uncompiled_params["bias"]))
def test_update_lr_compiled(self):
params = {"w": mx.ones((5, 5))}
grads = tree_map(lambda x: mx.ones_like(x), params)
optim = opt.SGD(-1.0)
@partial(mx.compile, inputs=optim.state)
def update(grads):
return optim.apply_gradients(grads, params)
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 2.0)))
optim.learning_rate = -2.0
result = update(grads)
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
if __name__ == "__main__":
unittest.main()