Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 18:11:15 +08:00

Add optimizers (AdaMax, AdaDelta, RMSprop) and ordering optimizer classes (#142)

* Add AdaMax, AdaDelta, RMSprop

This commit is contained in:
parent a67bbfe745
commit eebd7c275d
@@ -38,4 +38,9 @@ model's parameters and the **optimizer state**.
    OptimizerState
    Optimizer
    SGD
+   RMSprop
+   Adagrad
+   AdaDelta
    Adam
+   AdamW
+   Adamax
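For context, here is a minimal usage sketch of the optimizer classes listed above. It assumes they are exposed under mlx.optimizers (the import path is not shown in this diff), and the hyperparameter values are arbitrary; the constructor signatures match the code added further down.

    import mlx.optimizers as optim

    # Constructor signatures as introduced in this commit; defaults written out explicitly.
    rmsprop = optim.RMSprop(learning_rate=1e-3, alpha=0.99, eps=1e-8)
    adagrad = optim.Adagrad(learning_rate=1e-2, eps=1e-8)
    adadelta = optim.AdaDelta(learning_rate=1.0, rho=0.9, eps=1e-6)
    adamax = optim.Adamax(learning_rate=2e-3, betas=[0.9, 0.999], eps=1e-8)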
@@ -82,15 +82,15 @@ class SGD(Optimizer):
 
     .. math::
 
-        v_{t+1} &= \mu v_t + g_t \\
+        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
-        learning_rate (float): The learning :math:`\lambda` for the update
-        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
-        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
-        nesterov (bool, optional): Enables Nesterov momentum (default: False)
+        learning_rate (float): The learning rate :math:`\lambda`.
+        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
+        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
+        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
+        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
     """
 
     def __init__(
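To make the corrected recurrence concrete, here is a one-step numeric sketch with momentum and dampening (plain Python, independent of the MLX implementation; the values are arbitrary):

    # One SGD-with-momentum step following the docstring:
    #   v <- mu * v + (1 - tau) * g
    #   w <- w - lr * v
    lr, mu, tau = 0.1, 0.9, 0.5
    w, v, g = 1.0, 0.0, 2.0

    v = mu * v + (1 - tau) * g   # 0.9 * 0.0 + 0.5 * 2.0 = 1.0
    w = w - lr * v               # 1.0 - 0.1 * 1.0 = 0.9
    print(w, v)                  # 0.9 1.0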
@@ -140,20 +140,181 @@ class SGD(Optimizer):
         return parameter - self.learning_rate * update
 
 
+class RMSprop(Optimizer):
+    r"""Implementation of the RMSprop optimizer [1].
+
+    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
+
+    .. math::
+
+        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        alpha (float, optional): The smoothing constant :math:`\alpha`.
+            Default: ``0.99``
+        eps (float, optional): The term :math:`\epsilon` added to the denominator
+            to improve numerical stability. Default: ``1e-8``
+    """
+
+    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.alpha = alpha
+        self.eps = eps
+
+        if self.alpha < 0.0:
+            raise ValueError(
+                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"RMSprop epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
+        lr = self.learning_rate
+        alpha = self.alpha
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = alpha * v + (1 - alpha) * mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
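As a rough usage sketch of the class above, apply_single can be exercised directly on a single parameter. A plain dict stands in for OptimizerState here (apply_single only calls .get and assigns by key), and the mlx.core / mlx.optimizers import paths are assumed rather than shown in this diff:

    import mlx.core as mx
    import mlx.optimizers as optim

    opt = optim.RMSprop(learning_rate=1e-2)
    state = {}                            # stands in for the per-parameter OptimizerState
    param = mx.array([1.0, -2.0, 3.0])
    grad = mx.array([0.1, -0.2, 0.3])

    # First step: v = 0.01 * g^2, so the update is close to 0.1 * sign(g).
    new_param = opt.apply_single(grad, param, state)
    print(new_param)                      # each entry moves by about -0.1 * sign(grad)
    print(state["v"])                     # running average of squared gradients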
+
+
+class Adagrad(Optimizer):
+    r"""Implementation of the Adagrad optimizer [1].
+
+    Our Adagrad implementation follows the original paper. In detail,
+
+    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+    for online learning and stochastic optimization. JMLR 2011.
+
+    .. math::
+
+        v_{t+1} &= v_t + g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-8``
+    """
+
+    def __init__(self, learning_rate: float, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.eps = eps
+
+        if self.eps < 0.0:
+            raise ValueError(
+                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adagrad parameter update and stores :math:`v` in the
+        optimizer state."""
+        lr = self.learning_rate
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = v + mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
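A tiny numeric trace of the Adagrad recurrence above (plain Python, independent of MLX; the values are arbitrary), showing how the accumulated v steadily shrinks the effective step size:

    lr, eps = 0.1, 1e-8
    w, v, g = 0.0, 0.0, 1.0

    # Repeated identical gradients yield progressively smaller updates.
    for step in range(3):
        v = v + g * g                      # 1.0, 2.0, 3.0
        w = w - lr * g / (v ** 0.5 + eps)  # steps of ~0.100, ~0.071, ~0.058
    print(round(w, 3))                     # about -0.228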
+
+
+class AdaDelta(Optimizer):
+    r"""Implementation of the AdaDelta optimizer with learning rate [1].
+
+    Our AdaDelta implementation follows the original paper. In detail,
+
+    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
+
+    .. math::
+
+        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
+        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
+        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
+        w_{t+1} &= w_t - \lambda \Delta w_{t+1}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        rho (float, optional): The coefficient :math:`\rho` used for computing a
+            running average of squared gradients. Default: ``0.9``
+        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
+            numerical stability. Default: ``1e-6``
+    """
+
+    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.rho = rho
+        self.eps = eps
+        if self.rho < 0.0:
+            raise ValueError(
+                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the AdaDelta parameter update and stores :math:`v` and
+        :math:`u` in the optimizer state."""
+        lr = self.learning_rate
+        rho = self.rho
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        u = state.get("u", mx.zeros_like(gradient))
+
+        v = rho * v + (1 - rho) * mx.square(gradient)
+        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
+        u = rho * u + (1 - rho) * mx.square(d)
+
+        state["v"] = v
+        state["u"] = u
+
+        return parameter - lr * d
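The AdaDelta step keeps two accumulators, v and u. A small standalone trace of the recurrence (plain Python, not the MLX code; values arbitrary) shows how u starting at zero keeps the very first step tiny:

    lr, rho, eps = 1.0, 0.9, 1e-6
    w, v, u, g = 0.0, 0.0, 0.0, 1.0

    v = rho * v + (1 - rho) * g * g              # 0.1
    d = (u + eps) ** 0.5 / (v + eps) ** 0.5 * g  # ~0.00316, small because u is still 0
    u = rho * u + (1 - rho) * d * d              # ~1e-6
    w = w - lr * d
    print(w)                                     # about -0.00316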
 
 class Adam(Optimizer):
     r"""Implementation of the Adam optimizer [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
 
+    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+    optimization. ICLR 2015.
+
     .. math::
 
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
 
-    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
-    optimization. ICLR 2015.
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+            :math:`(\beta_1, \beta_2)` used for computing running averages of the
+            gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-8``
     """
 
     def __init__(
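Since the docstring states that bias correction is omitted, the following plain-Python comparison (illustrative only, not the MLX code) contrasts the very first step as implemented here with the bias-corrected step from the original paper:

    lr, b1, b2, eps = 0.001, 0.9, 0.999, 1e-8
    g = 0.5
    m = (1 - b1) * g                  # 0.05    (first step, m_0 = 0)
    v = (1 - b2) * g * g              # 0.00025

    # As implemented here: no bias correction, eps inside the square root.
    step_plain = lr * m / ((v + eps) ** 0.5)              # ~0.00316

    # Original Adam: rescale m and v by (1 - beta^t) before the update (t = 1).
    m_hat, v_hat = m / (1 - b1), v / (1 - b2)             # 0.5, 0.25
    step_corrected = lr * m_hat / (v_hat ** 0.5 + eps)    # ~0.001
    print(step_plain, step_corrected)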
@@ -189,7 +350,10 @@ class AdamW(Adam):
 
     Following the above convention, in contrast with [1], we do not use bias
     correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (λ) value:
+    with a weight_decay (:math:`\lambda`) value:
+
+    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
+    regularization. ICLR 2019.
 
     .. math::
 
@@ -197,8 +361,15 @@ class AdamW(Adam):
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
 
-    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
-    regularization. ICLR 2019.
+    Args:
+        learning_rate (float): The learning rate :math:`\alpha`.
+        betas (Tuple[float, float], optional): The coefficients
+            :math:`(\beta_1, \beta_2)` used for computing running averages of the
+            gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-8``
+        weight_decay (float, optional): The weight decay :math:`\lambda`.
+            Default: ``0``.
     """
 
     def __init__(
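To make the decoupled weight decay in the update above concrete, a plain-Python sketch of one step (illustrative, not the MLX code; values arbitrary) compares the plain Adam step with the AdamW step that adds the λ·w term directly to the update:

    lr, b1, b2, eps, wd = 0.001, 0.9, 0.999, 1e-8, 0.01
    w, g = 2.0, 0.5

    m = (1 - b1) * g
    v = (1 - b2) * g * g
    adam_term = m / ((v + eps) ** 0.5)

    w_adam = w - lr * adam_term              # no weight decay
    w_adamw = w - lr * (adam_term + wd * w)  # decay applied to the weight itself
    print(w_adam, w_adamw)                   # AdamW ends up lr * wd * w = 2e-5 smaller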
@@ -223,45 +394,51 @@ class AdamW(Adam):
         )
 
 
-class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+class Adamax(Adam):
+    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
+    on the infinity norm [1].
 
-    Our Adagrad implementation follows the original paper. In detail,
+    Our Adam implementation follows the original paper and omits the bias
+    correction in the first and second moment estimates. In detail,
+
+    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+    optimization. ICLR 2015.
 
     .. math::
 
-        v_{t+1} &= v_t + g_t^2 \\
-        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1} + \epsilon}}
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
 
-    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
-    for online learning and stochastic optimization. JMLR 2011.
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+            :math:`(\beta_1, \beta_2)` used for computing running averages of the
+            gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-8``
     """
 
-    def __init__(self, learning_rate: float, eps: float = 1e-8):
-        super().__init__()
-
-        self.learning_rate = learning_rate
-        self.eps = eps
-
-        if self.learning_rate < 0.0:
-            raise ValueError(
-                f"Adagrad learning rate should be >=0, {self.learning_rate} was provided instead"
-            )
-        if self.eps < 0.0:
-            raise ValueError(
-                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
-            )
+    def __init__(
+        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+    ):
+        super().__init__(learning_rate, betas, eps)
 
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
-        """Performs the Adagrad parameter update and stores :math:`v` in the
-        optimizer state."""
+        """Performs the Adamax parameter update and stores :math:`v` and
+        :math:`m` in the optimizer state."""
         lr = self.learning_rate
+        b1, b2 = self.betas
         eps = self.eps
 
+        m = state.get("m", mx.zeros_like(gradient))
         v = state.get("v", mx.zeros_like(gradient))
-        v = v + mx.square(gradient)
+
+        m = b1 * m + (1 - b1) * gradient
+        v = mx.maximum(b2 * v, mx.abs(gradient))
+        state["m"] = m
         state["v"] = v
 
-        return parameter - lr * gradient / (mx.sqrt(v) + eps)
+        return parameter - lr * m / (v + eps)
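Finally, a short plain-Python trace of the Adamax recurrence above (independent of MLX; values arbitrary), showing that the denominator tracks a slowly decaying maximum of |g| rather than an average of squared gradients:

    lr, b1, b2, eps = 0.002, 0.9, 0.999, 1e-8
    m, v, w = 0.0, 0.0, 1.0

    for g in [0.5, 4.0, 0.1]:
        m = b1 * m + (1 - b1) * g
        v = max(b2 * v, abs(g))    # jumps to 4.0 at the large gradient, then decays by b2
        w = w - lr * m / (v + eps)
    print(round(v, 3))             # 3.996 = 0.999 * 4.0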