Some fixes in cache / thread safety (#777)

* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* optimizer docs

* fix adafactor

* fix adafactor
Awni Hannun
2024-03-05 13:30:50 -08:00
committed by GitHub
parent 859ae15a54
commit cbcf44a4ca
4 changed files with 60 additions and 41 deletions


@@ -210,14 +210,19 @@ class RMSprop(Optimizer):
w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
alpha (float, optional): The smoothing constant :math:`\alpha`.
Default: ``0.99``
eps (float, optional): The term :math:`\epsilon` added to the denominator
to improve numerical stability. Default: ``1e-8``
"""
-    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+    ):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
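Across these optimizers the change is the same: the annotation widens from a plain `float` to `Union[float, Callable[[mx.array], mx.array]]`, and `self._maybe_schedule("learning_rate", learning_rate)` registers either the constant or the schedule. A minimal usage sketch (the import aliases and the decay function are illustrative, not part of this diff):

```python
import mlx.core as mx
import mlx.optimizers as optim

# A plain float still works exactly as before this change.
opt_const = optim.RMSprop(learning_rate=1e-3)

# Any callable mapping the step (an mx.array) to an mx.array is also
# accepted, e.g. a simple inverse-time decay.
def inverse_time_decay(step: mx.array) -> mx.array:
    return 1e-2 / (1.0 + 1e-3 * step)

opt_sched = optim.RMSprop(learning_rate=inverse_time_decay, alpha=0.99)
```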
@@ -264,12 +269,16 @@ class Adagrad(Optimizer):
w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
eps (float, optional): The term :math:`\epsilon` added to the
denominator to improve numerical stability. Default: ``1e-8``
"""
-    def __init__(self, learning_rate: float, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        eps: float = 1e-8,
+    ):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
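The `Callable[[mx.array], mx.array]` contract means a schedule should stay in `mx` operations so it can be evaluated at each step. A hypothetical linear-warmup schedule for Adagrad, written against that contract:

```python
import mlx.core as mx
import mlx.optimizers as optim

def warmup(step: mx.array) -> mx.array:
    # Linear warmup over the first 100 steps, then a constant 1e-2.
    # mx.minimum keeps the computation in mx.array ops end to end.
    return 1e-2 * mx.minimum(step / 100.0, 1.0)

optimizer = optim.Adagrad(learning_rate=warmup, eps=1e-8)
```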
@@ -311,14 +320,19 @@ class AdaDelta(Optimizer):
w_{t+1} &= w_t - \lambda \Delta w_{t+1}
Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
rho (float, optional): The coefficient :math:`\rho` used for computing a
running average of squared gradients. Default: ``0.9``
eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
numerical stability. Default: ``1e-6``
"""
-    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        rho: float = 0.9,
+        eps: float = 1e-6,
+    ):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
@@ -374,7 +388,7 @@ class Adam(Optimizer):
w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
betas (Tuple[float, float], optional): The coefficients
:math:`(\beta_1, \beta_2)` used for computing running averages of the
gradient and its square. Default: ``(0.9, 0.999)``
@@ -383,7 +397,10 @@ class Adam(Optimizer):
"""
def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
):
super().__init__()
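A schedule helper can be passed wherever a float used to be required. Assuming the schedule helpers in `mlx.optimizers` (e.g. `cosine_decay`) are available in this version, Adam can be constructed as:

```python
import mlx.optimizers as optim

# cosine_decay(init, decay_steps) returns a callable schedule, which the
# widened learning_rate parameter above accepts in place of a float.
lr_schedule = optim.cosine_decay(1e-3, 10_000)
optimizer = optim.Adam(learning_rate=lr_schedule, betas=[0.9, 0.999], eps=1e-8)
```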
@@ -430,7 +447,7 @@ class AdamW(Adam):
w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
Args:
-        learning_rate (float): The learning rate :math:`\alpha`.
+        learning_rate (float or callable): The learning rate :math:`\alpha`.
betas (Tuple[float, float], optional): The coefficients
:math:`(\beta_1, \beta_2)` used for computing running averages of the
gradient and its square. Default: ``(0.9, 0.999)``
@@ -442,7 +459,7 @@ class AdamW(Adam):
def __init__(
self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
betas: List[float] = [0.9, 0.999],
eps: float = 1e-8,
weight_decay: float = 0.01,
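Only `learning_rate` gains the float-or-callable form here; `betas`, `eps`, and `weight_decay` remain plain values. A sketch combining a scheduled learning rate with weight decay (the `step_decay` helper is assumed to exist in `mlx.optimizers`):

```python
import mlx.optimizers as optim

optimizer = optim.AdamW(
    # step_decay(init, decay_rate, step_size): assumed schedule helper.
    learning_rate=optim.step_decay(3e-4, 0.9, 1_000),
    weight_decay=0.01,  # remains a plain float
)
```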
@@ -477,7 +494,7 @@ class Adamax(Adam):
w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
betas (Tuple[float, float], optional): The coefficients
:math:`(\beta_1, \beta_2)` used for computing running averages of the
gradient and its square. Default: ``(0.9, 0.999)``
@@ -486,7 +503,10 @@ class Adamax(Adam):
"""
def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
):
super().__init__(learning_rate, betas, eps)
if not 0.0 <= eps:
@@ -537,7 +557,7 @@ class Lion(Optimizer):
w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)
Args:
-        learning_rate (float): The learning rate :math:`\eta`.
+        learning_rate (float or callable): The learning rate :math:`\eta`.
betas (Tuple[float, float], optional): The coefficients
:math:`(\beta_1, \beta_2)` used for computing the gradient
momentum and update direction. Default: ``(0.9, 0.99)``
@@ -546,7 +566,7 @@ class Lion(Optimizer):
def __init__(
self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
betas: List[float] = [0.9, 0.99],
weight_decay: float = 0.0,
):
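Because the schedule is re-evaluated against the optimizer's step, the current value can be read back for logging. This sketch assumes the optimizer's `learning_rate` property reflects the most recently evaluated schedule value:

```python
import mlx.core as mx
import mlx.optimizers as optim

# A decaying schedule for Lion, which typically wants a smaller learning
# rate than AdamW.
optimizer = optim.Lion(learning_rate=lambda step: 1e-4 * mx.exp(-step / 1000.0))

# Assumption: the property returns the schedule evaluated at the optimizer's
# current step, which is convenient for logging during training.
print(optimizer.learning_rate)
```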
@@ -583,7 +603,8 @@ class Adafactor(Optimizer):
<https://arxiv.org/abs/1804.04235>`_
Args:
-        learning_rate (float, optional): The learning rate. Default: ``None``.
+        learning_rate (float or callable, optional): The learning rate.
+            Default: ``None``.
eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
added to the square of the gradients to improve numerical
stability and the second term :math:`\epsilon_2` is used for
@@ -610,7 +631,7 @@ class Adafactor(Optimizer):
def __init__(
self,
-        learning_rate: Optional[float] = None,
+        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
eps: Tuple[float, float] = (1e-30, 1e-3),
clip_threshold: float = 1.0,
decay_rate: float = -0.8,
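Adafactor is the one optimizer where `None` remains a valid learning rate, hence the three-way `Union[float, Callable[...], None]` annotation. A sketch of both modes (the `relative_step` flag name follows Adafactor's usual API and is an assumption here):

```python
import mlx.optimizers as optim

# With learning_rate=None (the default), Adafactor falls back to its
# internally computed relative step size.
opt_relative = optim.Adafactor()

# An explicit float or callable is also accepted, mirroring the other
# optimizers; relative_step=False makes Adafactor use it directly.
opt_explicit = optim.Adafactor(
    learning_rate=lambda step: 1e-3 / (1.0 + step),
    relative_step=False,
)
```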