Compile with capture (#629)

* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-07 17:29:22 -08:00
committed by GitHub
parent e5e816a5ef
commit 1b97b2958b
13 changed files with 723 additions and 157 deletions

View File

@@ -66,6 +66,19 @@ class Module(dict):
"""Boolean indicating if the model is in training mode."""
return self._training
@property
def state(self):
"""The module's state dictionary
The module's state dictionary contains any attribute set on the
module including parameters in :meth:`Module.parameters`
Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
a reference to the module's state. Updates to it will be reflected in
the original module.
"""
return self
def _extra_repr(self):
return ""

View File

@@ -7,39 +7,14 @@ import mlx.core as mx
from mlx.utils import tree_map
class OptimizerState(dict):
"""The optimizer state implements a recursively defined
:class:`collections.defaultdict`, namely a missing key in an optimizer
state is an :class:`OptimizerState`.
.. note::
:meth:`OptimizerState.get` in contrast to a normal dictionary also sets
the key to the ``default`` value if the ``key`` was not present in the
dictionary.
"""
def __getitem__(self, key):
if key not in self:
self[key] = OptimizerState()
return super().__getitem__(key)
def get(self, key, default):
"""If ``key`` doesn't exist set its value to ``default`` and then return it."""
if key not in self:
self[key] = default
return super().__getitem__(key)
class Optimizer:
"""The base class for all optimizers. It allows us to implement an
optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes:
state (OptimizerState): It holds the optimizer's state dictionary.
"""
def __init__(self):
self.state = OptimizerState()
self._initialized = False
self._state = {}
def update(self, model: "mlx.nn.Module", gradients: dict):
"""Apply the gradients to the parameters of the model and update the
@@ -52,7 +27,41 @@ class Optimizer:
"""
model.update(self.apply_gradients(gradients, model))
def apply_gradients(self, gradients: dict, model: dict):
def init(self, parameters: dict):
"""Initialize the optimizer's state
This function can be used to initialize optimizers which have state
(like momentum in :class:`SGD`). Using this method is optional as the
optimizer will initialize itself if the state is not yet set. However,
there are some cases where explicit initialization is useful in order
to have access to the :attr:`Optimizer.state` before the first call to
:meth:`Optimizer.update`.
Args:
parameters (dict): A Python tree of parameters.
Example:
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
>>> model = nn.Linear(2, 2)
>>> optimizer.init(model.trainable_parameters())
>>> optimizer.state
{'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
[0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)
self._initialized = True
def init_single(self, parameter: mx.array, state: dict):
"""To be extended by the children classes to implement each optimizer's
state initialization.
Args:
parameter (mx.array): A single parameter that will be optimized.
"""
raise NotImplementedError()
def apply_gradients(self, gradients: dict, parameters: dict):
"""Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via
@@ -61,19 +70,41 @@ class Optimizer:
Args:
gradients (dict): A Python tree of gradients.
model (dict): A Python tree of parameters. It can be a superset of
the gradients. In that case the returned python tree
will be of the same structure as the gradients.
parameters (dict): A Python tree of parameters. It can be a
superset of the gradients. In that case the returned python
tree will be of the same structure as the gradients.
"""
return tree_map(self.apply_single, gradients, model, self.state)
if not self._initialized:
self.init(gradients)
return tree_map(self.apply_single, gradients, parameters, self.state)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
"""To be extended by the children classes to implement each optimizer's
update."""
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""To be extended by derived classes to implement the optimizer's update.
Args:
gradient (mx.array): The ``parameter`` gradient.
parameter (mx.array): The ``parameter`` to update.
state (dict): The optimizer's state.
"""
raise NotImplementedError()
@property
def state(self):
"""The optimizer's state dictionary."""
return self._state
@state.setter
def state(self, state: dict):
self._state = state
@property
def learning_rate(self):
return self.state["learning_rate"]
@learning_rate.setter
def learning_rate(self, learning_rate: mx.array):
self.state["learning_rate"] = mx.array(learning_rate)
class SGD(Optimizer):
r"""The stochastic gradient descent optimizer.
@@ -113,9 +144,11 @@ class SGD(Optimizer):
self.dampening = dampening
self.nesterov = nesterov
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the SGD parameter update and stores :math:`v` in the
optimizer state."""
@@ -123,24 +156,21 @@ class SGD(Optimizer):
gradient += self.weight_decay * parameter
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
return parameter - self.learning_rate.astype(gradient.dtype) * gradient
v = self.momentum * state.get("v")
if self.dampening > 0:
v = (
state.get("v", (self.dampening / self.momentum) * gradient)
* self.momentum
)
v += (1 - self.dampening) * gradient
else:
v = state.get("v", mx.zeros_like(gradient)) * self.momentum
v += gradient
if self.nesterov:
update = gradient + self.momentum * v
else:
update = v
state["v"] = v
return parameter - self.learning_rate * update
return parameter - self.learning_rate.astype(gradient.dtype) * update
class RMSprop(Optimizer):
@@ -177,15 +207,17 @@ class RMSprop(Optimizer):
f"RMSprop epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
alpha = self.alpha
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
v = state["v"]
v = alpha * v + (1 - alpha) * mx.square(gradient)
state["v"] = v
@@ -222,16 +254,17 @@ class Adagrad(Optimizer):
f"Adagrad epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adagrad parameter update and stores :math:`v` in the
optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
v = v + mx.square(gradient)
v = state["v"] + mx.square(gradient)
state["v"] = v
return parameter - lr * gradient / (mx.sqrt(v) + eps)
@@ -274,17 +307,20 @@ class AdaDelta(Optimizer):
f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
state["u"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdaDelta parameter update and stores :math:`v` and
:math:`u` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
rho = self.rho
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
u = state.get("u", mx.zeros_like(gradient))
v = state["v"]
u = state["u"]
v = rho * v + (1 - rho) * mx.square(gradient)
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
@@ -329,17 +365,20 @@ class Adam(Optimizer):
self.betas = betas
self.eps = eps
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adam parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
eps = self.eps
m = state.get("m", gradient)
v = state.get("v", mx.square(gradient))
m = state["m"]
v = state["v"]
m = b1 * m + (1 - b1) * gradient
v = b2 * v + (1 - b2) * mx.square(gradient)
state["m"] = m
@@ -385,15 +424,14 @@ class AdamW(Adam):
super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
self.weight_decay = weight_decay
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdamW parameter update by modifying the parameters
passed into Adam.
"""
lr = self.learning_rate.astype(gradient.dtype)
return super().apply_single(
gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
gradient, parameter * (1 - lr * self.weight_decay), state
)
@@ -430,17 +468,20 @@ class Adamax(Adam):
f"Epsilon value should be >=0, {self.eps} was provided instead"
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
state["v"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adamax parameter update and stores :math:`v` and
:math:`m` in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
eps = self.eps
m = state.get("m", mx.zeros_like(gradient))
v = state.get("v", mx.zeros_like(gradient))
m = state["m"]
v = state["v"]
m = b1 * m + (1 - b1) * gradient
v = mx.maximum(b2 * v, mx.abs(gradient))
@@ -489,16 +530,18 @@ class Lion(Optimizer):
self.betas = betas
self.weight_decay = weight_decay
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Lion parameter update and stores :math:`m`
in the optimizer state."""
lr = self.learning_rate
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
weight_decay = self.weight_decay
m = state.get("m", gradient)
m = state["m"]
c = b1 * m + (1 - b1) * gradient
state["m"] = b2 * m + (1 - b2) * gradient
if weight_decay > 0:
@@ -552,7 +595,8 @@ class Adafactor(Optimizer):
warmup_init: bool = False,
):
super().__init__()
self.learning_rate = learning_rate
if learning_rate is not None:
self.learning_rate = learning_rate
self.eps = eps
self.clip_threshold = clip_threshold
self.decay_rate = decay_rate
@@ -562,14 +606,29 @@ class Adafactor(Optimizer):
self.relative_step = relative_step
self.warmup_init = warmup_init
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["step"] = 0
if parameter.ndim >= 2:
shape = parameter.shape
dtype = parameter.dtype
state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
else:
state["exp_avg_sq"] = mx.zeros_like(parameter)
if self.beta_1 is not None:
state["exp_avg"] = mx.zeros_like(parameter)
def _compute_rms(self, inputs):
return mx.sqrt(mx.mean(mx.square(inputs)))
def _compute_learning_rate(self, step, parameter_rms):
relative_step_size = self.learning_rate
if self.relative_step:
min_step = 1e-6 * step if self.warmup_init else 1e-2
relative_step_size = min(min_step, 1 / math.sqrt(step))
else:
relative_step_size = self.learning_rate.astype(parameter_rms)
parameter_scale = 1.0
if self.scale_parameter:
@@ -585,13 +644,11 @@ class Adafactor(Optimizer):
mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
)
def apply_single(
self, gradient: mx.array, parameter: mx.array, state: OptimizerState
):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Adafactor parameter and state update."""
gradient_shape = gradient.shape
factored = len(gradient_shape) >= 2
step = state.get("step", 0) + 1
factored = gradient.ndim >= 2
step = state["step"] + 1
state["step"] = step
use_first_moment = self.beta_1 is not None
@@ -601,15 +658,8 @@ class Adafactor(Optimizer):
update = mx.square(gradient) + self.eps[0]
if factored:
exp_avg_sq_row = state.get(
"exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
)
exp_avg_sq_col = state.get(
"exp_avg_sq_col",
mx.zeros(
gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
),
)
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
(1 - beta_2) * mx.mean(update, axis=-1)
)
@@ -621,7 +671,7 @@ class Adafactor(Optimizer):
update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
update = update * gradient
else:
exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
state["exp_avg_sq"] = exp_avg_sq
update = mx.rsqrt(exp_avg_sq) * gradient
@@ -632,7 +682,7 @@ class Adafactor(Optimizer):
update = learning_rate * update
if use_first_moment:
exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
exp_avg = state["exp_avg"]
exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
state["exp_avg"] = exp_avg
update = exp_avg