Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 18:28:12 +08:00)
Compile with capture (#629)
* Simple kernel generation
* Remove the generate kernel from graph_utils
* fix multi-output with compile
* fuse with stopgrad
* v1 input, output capture in compile
* cleanup tree update with visitor update
* nit
* remove todo
* state for model, optional explicit init and more pure optimizer steps
* move learning rate to state
* add lr to opt state, some fixes in capture
* fix optim
* update tuple of containers as well
* fix stream for compiled output
* rng state for compile
* nit
* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
@@ -66,6 +66,19 @@ class Module(dict):
         """Boolean indicating if the model is in training mode."""
         return self._training

+    @property
+    def state(self):
+        """The module's state dictionary
+
+        The module's state dictionary contains any attribute set on the
+        module including parameters in :meth:`Module.parameters`
+
+        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
+        a reference to the module's state. Updates to it will be reflected in
+        the original module.
+        """
+        return self
+
     def _extra_repr(self):
         return ""

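Because ``state`` returns the module itself rather than a copy, writes through the returned mapping are visible on the module; this is what lets outside machinery (such as the compile capture added in this PR) hold a live reference to a model's parameters. A minimal sketch using only APIs that appear in this diff:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(2, 2)
state = model.state             # a live reference, not a snapshot
state["bias"] = mx.zeros((2,))  # writes straight through to the module
# model.state["bias"] and state["bias"] are the same array; no copy was made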
@@ -7,39 +7,14 @@ import mlx.core as mx
 from mlx.utils import tree_map


-class OptimizerState(dict):
-    """The optimizer state implements a recursively defined
-    :class:`collections.defaultdict`, namely a missing key in an optimizer
-    state is an :class:`OptimizerState`.
-
-    .. note::
-       :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
-       the key to the ``default`` value if the ``key`` was not present in the
-       dictionary.
-    """
-
-    def __getitem__(self, key):
-        if key not in self:
-            self[key] = OptimizerState()
-        return super().__getitem__(key)
-
-    def get(self, key, default):
-        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
-        if key not in self:
-            self[key] = default
-        return super().__getitem__(key)
-
-
 class Optimizer:
     """The base class for all optimizers. It allows us to implement an
     optimizer on a per-parameter basis and apply it to a parameter tree.
-
-    Attributes:
-        state (OptimizerState): It holds the optimizer's state dictionary.
     """

     def __init__(self):
-        self.state = OptimizerState()
+        self._initialized = False
+        self._state = {}

     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
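With ``OptimizerState`` removed, an optimizer's state is now a plain nested ``dict`` that mirrors the parameter tree and is populated by ``init``/``init_single`` (added later in this diff). A small sketch based on the doctest that appears below (``optim`` and ``nn`` are assumed aliases for ``mlx.optimizers`` and ``mlx.nn``):

import mlx.nn as nn
import mlx.optimizers as optim

optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)
optimizer.init(nn.Linear(2, 2).trainable_parameters())

optimizer.state["learning_rate"]  # stored as an mx.array by the learning_rate setter
optimizer.state["weight"]["v"]    # per-parameter momentum buffer created by init_single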
@@ -52,7 +27,41 @@ class Optimizer:
         """
         model.update(self.apply_gradients(gradients, model))

-    def apply_gradients(self, gradients: dict, model: dict):
+    def init(self, parameters: dict):
+        """Initialize the optimizer's state
+
+        This function can be used to initialize optimizers which have state
+        (like momentum in :class:`SGD`). Using this method is optional as the
+        optimizer will initialize itself if the state is not yet set. However,
+        there are some cases where explicit initialization is useful in order
+        to have access to the :attr:`Optimizer.state` before the first call to
+        :meth:`Optimizer.update`.
+
+        Args:
+            model (dict): A Python tree of parameters.
+
+        Example:
+            >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
+            >>> model = nn.Linear(2, 2)
+            >>> optimizer.init(model.trainable_parameters())
+            >>> optimizer.state
+            {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
+            [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+        """
+        self._state.update(tree_map(lambda x: {}, parameters))
+        tree_map(self.init_single, parameters, self._state)
+        self._initialized = True
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """To be extended by the children classes to implement each optimizer's
+        state initialization.
+
+        Args:
+            parameter (mx.array): A single parameter that will be optimized.
+        """
+        raise NotImplementedError()
+
+    def apply_gradients(self, gradients: dict, parameters: dict):
         """Apply the gradients to the parameters and return the updated parameters.

         Can be used to update a model via
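Explicit ``init`` is optional: as the docstring notes, the optimizer initializes itself on first use (see the ``_initialized`` check in ``apply_gradients`` in the next hunk). A runnable sketch of the lazy path, using zero gradients purely for illustration:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)

# No explicit optimizer.init(...) here: a gradient tree shaped like the parameters
# is enough, and apply_gradients (next hunk) calls init(gradients) on first use.
grads = {k: mx.zeros_like(v) for k, v in model.trainable_parameters().items()}
optimizer.update(model, grads)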
@@ -61,19 +70,41 @@ class Optimizer:

         Args:
             gradients (dict): A Python tree of gradients.
-            model (dict): A Python tree of parameters. It can be a superset of
-                          the gradients. In that case the returned python tree
-                          will be of the same structure as the gradients.
+            parameters (dict): A Python tree of parameters. It can be a
+                superset of the gradients. In that case the returned python
+                tree will be of the same structure as the gradients.
         """
-        return tree_map(self.apply_single, gradients, model, self.state)
+        if not self._initialized:
+            self.init(gradients)
+        return tree_map(self.apply_single, gradients, parameters, self.state)

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
-        """To be extended by the children classes to implement each optimizer's
-        update."""
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """To be extended by derived classes to implement the optimizer's update.
+
+        Args:
+            gradient (mx.array): The ``parameter`` gradient.
+            parameter (mx.array): The ``parameter`` to update.
+            state (dict): The optimizer's state.
+        """
         raise NotImplementedError()

+    @property
+    def state(self):
+        """The optimizer's state dictionary."""
+        return self._state
+
+    @state.setter
+    def state(self, state: dict):
+        self._state = state
+
+    @property
+    def learning_rate(self):
+        return self.state["learning_rate"]
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate: mx.array):
+        self.state["learning_rate"] = mx.array(learning_rate)
+

 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
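Pushing everything, including the learning rate, into ``Optimizer.state`` (and exposing ``Module.state``) is what lets ``mx.compile`` capture that state as implicit inputs/outputs of a compiled training step. The sketch below is an assumption pieced together from the commit message ("v1 input, output capture in compile", "state for model", "move learning rate to state"); the exact ``mx.compile`` capture keywords are not shown in this diff.

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
optimizer.init(model.trainable_parameters())

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

state = [model.state, optimizer.state]  # captured so parameter/optimizer updates persist

@partial(mx.compile, inputs=state, outputs=state)  # assumed capture signature; see commit notes
def step(x, y):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss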
@@ -113,9 +144,11 @@ class SGD(Optimizer):
         self.dampening = dampening
         self.nesterov = nesterov

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the SGD parameter update and stores :math:`v` in the
         optimizer state."""

@@ -123,24 +156,21 @@ class SGD(Optimizer):
             gradient += self.weight_decay * parameter

         if self.momentum <= 0:
-            return parameter - self.learning_rate * gradient
+            return parameter - self.learning_rate.astype(gradient.dtype) * gradient

+        v = self.momentum * state.get("v")
         if self.dampening > 0:
-            v = (
-                state.get("v", (self.dampening / self.momentum) * gradient)
-                * self.momentum
-            )
             v += (1 - self.dampening) * gradient
         else:
-            v = state.get("v", mx.zeros_like(gradient)) * self.momentum
             v += gradient

         if self.nesterov:
             update = gradient + self.momentum * v
         else:
             update = v

         state["v"] = v
-        return parameter - self.learning_rate * update
+        return parameter - self.learning_rate.astype(gradient.dtype) * update


 class RMSprop(Optimizer):
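For reference, with learning rate :math:`\lambda`, momentum :math:`\mu`, and dampening :math:`\tau` (and weight decay already folded into :math:`g_t`), the branches above implement:

.. math::
   v_{t+1} = \mu\, v_t + (1 - \tau)\, g_t, \qquad
   w_{t+1} = w_t - \lambda \begin{cases} g_t + \mu\, v_{t+1} & \text{(Nesterov)} \\ v_{t+1} & \text{(otherwise)} \end{cases}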
@@ -177,15 +207,17 @@ class RMSprop(Optimizer):
                 f"RMSprop epsilon should be >0, {self.eps} was provided instead"
             )

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         alpha = self.alpha
         eps = self.eps

-        v = state.get("v", mx.zeros_like(gradient))
+        v = state["v"]
         v = alpha * v + (1 - alpha) * mx.square(gradient)
         state["v"] = v

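The accumulator maintained above is the usual RMSprop second moment; the parameter step itself sits outside this hunk, but assuming it is unchanged apart from the ``lr`` cast, the rule is:

.. math::
   v_{t+1} = \alpha\, v_t + (1 - \alpha)\, g_t^2, \qquad
   w_{t+1} = w_t - \lambda\, \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}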
@@ -222,16 +254,17 @@ class Adagrad(Optimizer):
                 f"Adagrad epsilon should be >0, {self.eps} was provided instead"
             )

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Adagrad parameter update and stores :math:`v` in the
         optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         eps = self.eps

-        v = state.get("v", mx.zeros_like(gradient))
-        v = v + mx.square(gradient)
+        v = state["v"] + mx.square(gradient)
         state["v"] = v

         return parameter - lr * gradient / (mx.sqrt(v) + eps)
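In equation form, the Adagrad step implemented above is:

.. math::
   v_{t+1} = v_t + g_t^2, \qquad
   w_{t+1} = w_t - \lambda\, \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}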
@@ -274,17 +307,20 @@ class AdaDelta(Optimizer):
                 f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
             )

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+        state["u"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the AdaDelta parameter update and stores :math:`v` and
         :math:`u` in the optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         rho = self.rho
         eps = self.eps

-        v = state.get("v", mx.zeros_like(gradient))
-        u = state.get("u", mx.zeros_like(gradient))
+        v = state["v"]
+        u = state["u"]

         v = rho * v + (1 - rho) * mx.square(gradient)
         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
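The visible part of the AdaDelta update is (the :math:`u` accumulator refresh and the final parameter step follow below the hunk):

.. math::
   v_{t+1} = \rho\, v_t + (1 - \rho)\, g_t^2, \qquad
   \Delta_t = \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}}\, g_t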
@@ -329,17 +365,20 @@ class Adam(Optimizer):
         self.betas = betas
         self.eps = eps

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+        state["v"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Adam parameter update and stores :math:`v` and
         :math:`m` in the optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps

-        m = state.get("m", gradient)
-        v = state.get("v", mx.square(gradient))
+        m = state["m"]
+        v = state["v"]
         m = b1 * m + (1 - b1) * gradient
         v = b2 * v + (1 - b2) * mx.square(gradient)
         state["m"] = m
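The Adam first- and second-moment accumulators maintained above are (the parameter step that consumes :math:`m` and :math:`v` is outside this hunk):

.. math::
   m_{t+1} = \beta_1\, m_t + (1 - \beta_1)\, g_t, \qquad
   v_{t+1} = \beta_2\, v_t + (1 - \beta_2)\, g_t^2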
@@ -385,15 +424,14 @@ class AdamW(Adam):
         super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
         self.weight_decay = weight_decay

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the AdamW parameter update by modifying the parameters
         passed into Adam.
         """

+        lr = self.learning_rate.astype(gradient.dtype)
         return super().apply_single(
-            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
+            gradient, parameter * (1 - lr * self.weight_decay), state
         )


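AdamW reuses Adam's ``apply_single`` but first shrinks the parameter, which is decoupled weight decay with decay :math:`\omega` and learning rate :math:`\lambda` (``AdamStep`` here is just shorthand for the parent class's update):

.. math::
   w_t' = w_t\,(1 - \lambda\,\omega), \qquad
   w_{t+1} = \operatorname{AdamStep}(g_t, w_t')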
@@ -430,17 +468,20 @@ class Adamax(Adam):
                 f"Epsilon value should be >=0, {self.eps} was provided instead"
             )

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+        state["v"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Adamax parameter update and stores :math:`v` and
         :math:`m` in the optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps

-        m = state.get("m", mx.zeros_like(gradient))
-        v = state.get("v", mx.zeros_like(gradient))
+        m = state["m"]
+        v = state["v"]

         m = b1 * m + (1 - b1) * gradient
         v = mx.maximum(b2 * v, mx.abs(gradient))
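The Adamax accumulators above swap Adam's second moment for an infinity-norm style running maximum:

.. math::
   m_{t+1} = \beta_1\, m_t + (1 - \beta_1)\, g_t, \qquad
   v_{t+1} = \max(\beta_2\, v_t, |g_t|)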
@@ -489,16 +530,18 @@ class Lion(Optimizer):
         self.betas = betas
         self.weight_decay = weight_decay

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Lion parameter update and stores :math:`m`
         in the optimizer state."""
-        lr = self.learning_rate
+        lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         weight_decay = self.weight_decay

-        m = state.get("m", gradient)
+        m = state["m"]
         c = b1 * m + (1 - b1) * gradient
         state["m"] = b2 * m + (1 - b2) * gradient
         if weight_decay > 0:
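Lion stores a single buffer :math:`m`; the code above computes the interpolation :math:`c_t` that drives the (sign-based, in the Lion formulation) parameter step outside this hunk, and refreshes :math:`m` with the second beta:

.. math::
   c_t = \beta_1\, m_t + (1 - \beta_1)\, g_t, \qquad
   m_{t+1} = \beta_2\, m_t + (1 - \beta_2)\, g_t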
@@ -552,7 +595,8 @@ class Adafactor(Optimizer):
         warmup_init: bool = False,
     ):
         super().__init__()
-        self.learning_rate = learning_rate
+        if learning_rate is not None:
+            self.learning_rate = learning_rate
         self.eps = eps
         self.clip_threshold = clip_threshold
         self.decay_rate = decay_rate
@@ -562,14 +606,29 @@ class Adafactor(Optimizer):
         self.relative_step = relative_step
         self.warmup_init = warmup_init

+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["step"] = 0
+        if parameter.ndim >= 2:
+            shape = parameter.shape
+            dtype = parameter.dtype
+            state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
+            state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
+        else:
+            state["exp_avg_sq"] = mx.zeros_like(parameter)
+
+        if self.beta_1 is not None:
+            state["exp_avg"] = mx.zeros_like(parameter)
+
     def _compute_rms(self, inputs):
         return mx.sqrt(mx.mean(mx.square(inputs)))

     def _compute_learning_rate(self, step, parameter_rms):
-        relative_step_size = self.learning_rate
         if self.relative_step:
             min_step = 1e-6 * step if self.warmup_init else 1e-2
             relative_step_size = min(min_step, 1 / math.sqrt(step))
+        else:
+            relative_step_size = self.learning_rate.astype(parameter_rms)

         parameter_scale = 1.0
         if self.scale_parameter:
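A quick shape sketch of the factored state that ``init_single`` above creates (the variable names here are only for illustration):

import mlx.core as mx

w = mx.zeros((4, 3))   # a matrix parameter
# state["exp_avg_sq_row"]: shape (4,) = shape[:-1]
# state["exp_avg_sq_col"]: shape (3,) = shape[:-2] + shape[-1:]

b = mx.zeros((3,))     # a vector parameter falls back to
# state["exp_avg_sq"]:     shape (3,)  (un-factored)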
@@ -585,13 +644,11 @@ class Adafactor(Optimizer):
             mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
         )

-    def apply_single(
-        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
-    ):
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Adafactor parameter and state update."""
-        gradient_shape = gradient.shape
-        factored = len(gradient_shape) >= 2
-        step = state.get("step", 0) + 1
+        factored = gradient.ndim >= 2
+
+        step = state["step"] + 1
         state["step"] = step
         use_first_moment = self.beta_1 is not None

@@ -601,15 +658,8 @@ class Adafactor(Optimizer):
         update = mx.square(gradient) + self.eps[0]

         if factored:
-            exp_avg_sq_row = state.get(
-                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
-            )
-            exp_avg_sq_col = state.get(
-                "exp_avg_sq_col",
-                mx.zeros(
-                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
-                ),
-            )
+            exp_avg_sq_row = state["exp_avg_sq_row"]
+            exp_avg_sq_col = state["exp_avg_sq_col"]
             exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                 (1 - beta_2) * mx.mean(update, axis=-1)
             )
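The row statistic updated above is the Adafactor factored second moment, with :math:`u_t = g_t^2 + \epsilon_1` and the decay ``beta_2`` derived from ``decay_rate``; the column statistic is presumably updated the same way over the other axis just after this hunk:

.. math::
   R_t = \beta_2\, R_{t-1} + (1 - \beta_2)\, \operatorname{mean}_{\text{last axis}}(u_t)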
@@ -621,7 +671,7 @@ class Adafactor(Optimizer):
             update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
             update = update * gradient
         else:
-            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
+            exp_avg_sq = state["exp_avg_sq"]
             exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
             state["exp_avg_sq"] = exp_avg_sq
             update = mx.rsqrt(exp_avg_sq) * gradient
@@ -632,7 +682,7 @@ class Adafactor(Optimizer):
         update = learning_rate * update

         if use_first_moment:
-            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
+            exp_avg = state["exp_avg"]
             exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
             state["exp_avg"] = exp_avg
             update = exp_avg