Mirror of https://github.com/ml-explore/mlx.git
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* optimizer docs
* fix adafactor
* fix adafactor
@@ -210,14 +210,19 @@ class RMSprop(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         alpha (float, optional): The smoothing constant :math:`\alpha`.
             Default: ``0.99``
         eps (float, optional): The term :math:`\epsilon` added to the denominator
             to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
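The point of the new annotation is that `learning_rate` may be either a float or a callable mapping the optimizer's step counter (an `mx.array`) to an `mx.array`. A minimal usage sketch, assuming only the `mlx.optimizers` constructor shown in this hunk; the `lr_schedule` function and its constants are illustrative, not part of the commit:

import mlx.core as mx
import mlx.optimizers as optim

# Illustrative schedule: maps the step counter (an mx.array) to a learning
# rate, matching the Callable[[mx.array], mx.array] annotation above.
def lr_schedule(step: mx.array) -> mx.array:
    return 1e-2 * mx.exp(-1e-3 * step)  # smooth exponential decay

opt_fixed = optim.RMSprop(learning_rate=1e-3)         # a float still works
opt_sched = optim.RMSprop(learning_rate=lr_schedule)  # a callable schedule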
@@ -264,12 +269,16 @@ class Adagrad(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
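As context, the Adagrad update above accumulates squared gradients into `v` before rescaling the step. A standalone sketch of that recurrence (paper-level, not MLX's internal implementation):

import mlx.core as mx

def adagrad_step(w, g, v, lr=1e-2, eps=1e-8):
    # v_{t+1} = v_t + g_t^2 ;  w_{t+1} = w_t - lr * g_t / (sqrt(v_{t+1}) + eps)
    v = v + mx.square(g)
    w = w - lr * g / (mx.sqrt(v) + eps)
    return w, v

w = mx.zeros((3,))
v = mx.zeros((3,))
g = mx.array([0.1, -0.2, 0.3])
w, v = adagrad_step(w, g, v)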
@@ -311,14 +320,19 @@ class AdaDelta(Optimizer):
         w_{t+1} &= w_t - \lambda \Delta w_{t+1}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         rho (float, optional): The coefficient :math:`\rho` used for computing a
             running average of squared gradients. Default: ``0.9``
         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
             numerical stability. Default: `1e-8`
     """

-    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        rho: float = 0.9,
+        eps: float = 1e-6,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
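The docstring line above shows only the final w_{t+1} = w_t - \lambda \Delta w_{t+1} step; the full AdaDelta recurrence it depends on can be sketched in plain MLX array ops (again a paper-level sketch, not the library's code):

import mlx.core as mx

def adadelta_step(w, g, v, u, lr=1.0, rho=0.9, eps=1e-6):
    # v tracks a running average of g^2, u a running average of dw^2.
    v = rho * v + (1 - rho) * mx.square(g)
    dw = mx.sqrt(u + eps) / mx.sqrt(v + eps) * g
    u = rho * u + (1 - rho) * mx.square(dw)
    return w - lr * dw, v, u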
@@ -374,7 +388,7 @@ class Adam(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
             :math:`(\beta_1, \beta_2)` used for computing running averages of the
             gradient and its square. Default: ``(0.9, 0.999)``
@@ -383,7 +397,10 @@ class Adam(Optimizer):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__()

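Since Adam now advertises the same `Union[float, Callable[[mx.array], mx.array]]` type, a schedule can be dropped into an ordinary training step. A hedged end-to-end sketch; the warmup function, tensor sizes, and the `nn.value_and_grad` / `optimizer.update` usage follow the customary MLX pattern rather than anything introduced by this diff:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

# Illustrative warmup: ramps the learning rate over the first 100 steps.
def warmup(step: mx.array) -> mx.array:
    return 1e-3 * mx.minimum(step / 100, 1.0)

optimizer = optim.Adam(learning_rate=warmup)
loss_and_grad = nn.value_and_grad(model, loss_fn)

x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
loss, grads = loss_and_grad(model, x, y)
optimizer.update(model, grads)  # apply one Adam step with the scheduled lr
mx.eval(model.parameters(), optimizer.state)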
@@ -430,7 +447,7 @@ class AdamW(Adam):
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\alpha`.
+        learning_rate (float or callable): The learning rate :math:`\alpha`.
         betas (Tuple[float, float], optional): The coefficients
             :math:`(\beta_1, \beta_2)` used for computing running averages of the
             gradient and its square. Default: ``(0.9, 0.999)``
@@ -442,7 +459,7 @@ class AdamW(Adam):

     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
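AdamW differs from Adam only in the decoupled \lambda w_t decay term, so a callable changes how \alpha is produced each step and nothing else. A brief construction sketch with illustrative values; the cosine-style schedule is my own, not part of the diff:

import mlx.core as mx
import mlx.optimizers as optim

def cosine_like(step: mx.array) -> mx.array:
    # Illustrative decaying schedule for the base step size alpha.
    return 5e-4 * 0.5 * (1 + mx.cos(3.14159 * mx.minimum(step / 10_000, 1.0)))

optimizer = optim.AdamW(learning_rate=cosine_like, weight_decay=0.01)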
@@ -477,7 +494,7 @@ class Adamax(Adam):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
             :math:`(\beta_1, \beta_2)` used for computing running averages of the
             gradient and its square. Default: ``(0.9, 0.999)``
@@ -486,7 +503,10 @@ class Adamax(Adam):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__(learning_rate, betas, eps)
         if not 0.0 <= eps:
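For reference, Adamax replaces Adam's second-moment estimate with an infinity norm, which is why the denominator above has no square root. A paper-level sketch of one step (not MLX's internal code):

import mlx.core as mx

def adamax_step(w, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # m is the usual first moment; v is the infinity-norm second moment.
    m = b1 * m + (1 - b1) * g
    v = mx.maximum(b2 * v, mx.abs(g))
    return w - lr * m / (v + eps), m, v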
@@ -537,7 +557,7 @@ class Lion(Optimizer):
         w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\eta`.
+        learning_rate (float or callable): The learning rate :math:`\eta`.
         betas (Tuple[float, float], optional): The coefficients
             :math:`(\beta_1, \beta_2)` used for computing the gradient
             momentum and update direction. Default: ``(0.9, 0.99)``
@@ -546,7 +566,7 @@ class Lion(Optimizer):

     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.99],
         weight_decay: float = 0.0,
     ):
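Lion's update applies the sign of an interpolated momentum plus decoupled weight decay, matching the \text{sign}(c_t) term above. A paper-level sketch of one step (not MLX's internal code):

import mlx.core as mx

def lion_step(w, g, m, lr=1e-4, b1=0.9, b2=0.99, wd=0.0):
    # c_t interpolates momentum and gradient; only its sign drives the step.
    c = b1 * m + (1 - b1) * g
    w = w - lr * (mx.sign(c) + wd * w)
    m = b2 * m + (1 - b2) * g
    return w, m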
@@ -583,7 +603,8 @@ class Adafactor(Optimizer):
    <https://arxiv.org/abs/1804.04235>`_

    Args:
-        learning_rate (float, optional): The learning rate. Default: ``None``.
+        learning_rate (float or callable, optional): The learning rate.
+            Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
@@ -610,7 +631,7 @@ class Adafactor(Optimizer):

    def __init__(
        self,
-        learning_rate: Optional[float] = None,
+        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
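Adafactor is the one optimizer here whose learning rate may also stay `None`, in which case it presumably falls back to its own relative step sizing; a float or a step-to-array callable pins it explicitly. A short construction sketch under that assumption, with illustrative values:

import mlx.core as mx
import mlx.optimizers as optim

def sched(step: mx.array) -> mx.array:
    return 1e-3 * mx.exp(-1e-4 * step)  # illustrative decay

opt_default = optim.Adafactor()                    # learning_rate=None
opt_fixed = optim.Adafactor(learning_rate=1e-3)    # explicit float
opt_sched = optim.Adafactor(learning_rate=sched)   # callable schedule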