diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst
index 7f5d3a067..b8e5cfea7 100644
--- a/docs/src/python/optimizers.rst
+++ b/docs/src/python/optimizers.rst
@@ -38,4 +38,9 @@ model's parameters and the **optimizer state**.
    OptimizerState
    Optimizer
    SGD
+   RMSprop
+   Adagrad
+   AdaDelta
    Adam
+   AdamW
+   Adamax
diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index ae981c7f3..161d923f6 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -82,15 +82,15 @@ class SGD(Optimizer):
 
     .. math::
 
-        v_{t+1} &= \mu v_t + g_t \\
+        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
-        learning_rate (float): The learning :math:`\lambda` for the update
-        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
-        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
-        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
-        nesterov (bool, optional): Enables Nesterov momentum (default: False)
+        learning_rate (float): The learning rate :math:`\lambda`.
+        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
+        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
+        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
+        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
     """
 
     def __init__(
@@ -140,20 +140,181 @@ class SGD(Optimizer):
 
         return parameter - self.learning_rate * update
 
 
+class RMSprop(Optimizer):
+    r"""Implementation of the RMSprop optimizer [1].
+
+    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
+
+    .. math::
+
+        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        alpha (float, optional): The smoothing constant :math:`\alpha`.
+          Default: ``0.99``
+        eps (float, optional): The term :math:`\epsilon` added to the denominator
+          to improve numerical stability. Default: ``1e-8``
+    """
+
+    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.alpha = alpha
+        self.eps = eps
+
+        if self.alpha < 0.0:
+            raise ValueError(
+                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"RMSprop epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
+        lr = self.learning_rate
+        alpha = self.alpha
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = alpha * v + (1 - alpha) * mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
+
+
+class Adagrad(Optimizer):
+    r"""Implementation of the Adagrad optimizer [1].
+
+    Our Adagrad implementation follows the original paper. In detail,
+
+    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+    for online learning and stochastic optimization. JMLR 2011.
+
+    .. math::
+
+        v_{t+1} &= v_t + g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        eps (float, optional): The term :math:`\epsilon` added to the
+          denominator to improve numerical stability. Default: ``1e-8``
+    """
+
+    def __init__(self, learning_rate: float, eps: float = 1e-8):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.eps = eps
+
+        if self.eps < 0.0:
+            raise ValueError(
+                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adagrad parameter update and stores :math:`v` in the
+        optimizer state."""
+        lr = self.learning_rate
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        v = v + mx.square(gradient)
+        state["v"] = v
+
+        return parameter - lr * gradient / (mx.sqrt(v) + eps)
+
+
+class AdaDelta(Optimizer):
+    r"""Implementation of the AdaDelta optimizer with a learning rate [1].
+
+    Our AdaDelta implementation follows the original paper. In detail,
+
+    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
+
+    .. math::
+
+        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
+        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
+        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
+        w_{t+1} &= w_t - \lambda \Delta w_{t+1}
+
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        rho (float, optional): The coefficient :math:`\rho` used for computing a
+          running average of squared gradients. Default: ``0.9``
+        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
+          numerical stability. Default: ``1e-6``
+    """
+
+    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+        super().__init__()
+
+        self.learning_rate = learning_rate
+        self.rho = rho
+        self.eps = eps
+        if self.rho < 0.0:
+            raise ValueError(
+                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
+            )
+        if self.eps < 0.0:
+            raise ValueError(
+                f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
+            )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the AdaDelta parameter update and stores :math:`v` and
+        :math:`u` in the optimizer state."""
+        lr = self.learning_rate
+        rho = self.rho
+        eps = self.eps
+
+        v = state.get("v", mx.zeros_like(gradient))
+        u = state.get("u", mx.zeros_like(gradient))
+
+        v = rho * v + (1 - rho) * mx.square(gradient)
+        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
+        u = rho * u + (1 - rho) * mx.square(d)
+
+        state["v"] = v
+        state["u"] = u
+
+        return parameter - lr * d
+
+
 class Adam(Optimizer):
     r"""Implementation of the Adam optimizer [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
 
+    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+    optimization. ICLR 2015.
+
     .. math::
 
         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
 
-    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
-    optimization. ICLR 2015.
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing running averages of the
+          gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+          denominator to improve numerical stability. Default: ``1e-8``
     """
 
     def __init__(
@@ -188,8 +349,11 @@ class AdamW(Adam):
     r"""Implementation of the AdamW optimizer [1].
 
     Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (λ) value:
+    correction in the first and second moments for AdamW. We update the weights
+    with a weight_decay (:math:`\lambda`) value:
+
+    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
+    regularization. ICLR 2019.
 
     .. math::
 
@@ -197,8 +361,15 @@ class AdamW(Adam):
         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
 
-    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
-    regularization. ICLR 2019.
+    Args:
+        learning_rate (float): The learning rate :math:`\alpha`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing running averages of the
+          gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+          denominator to improve numerical stability. Default: ``1e-8``
+        weight_decay (float, optional): The weight decay :math:`\lambda`.
+          Default: ``0``.
     """
 
     def __init__(
@@ -223,45 +394,51 @@ class AdamW(Adam):
         )
 
 
-class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+class Adamax(Adam):
+    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
+    on the infinity norm [1].
 
-    Our Adagrad implementation follows the original paper. In detail,
+    Our Adamax implementation follows the original paper and omits the bias
+    correction in the first and second moment estimates. In detail,
+
+    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+    optimization. ICLR 2015.
 
     .. math::
 
-        v_{t+1} &= v_t + g_t^2 \\
-        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1} + \epsilon}}
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
 
-    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
-    for online learning and stochastic optimization. JMLR 2011.
+    Args:
+        learning_rate (float): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing running averages of the
+          gradient and its square. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+          denominator to improve numerical stability. Default: ``1e-8``
     """
 
-    def __init__(self, learning_rate: float, eps: float = 1e-8):
-        super().__init__()
-
-        self.learning_rate = learning_rate
-        self.eps = eps
-
-        if self.learning_rate < 0.0:
-            raise ValueError(
-                f"Adagrad learning rate should be >=0, {self.learning_rate} was provided instead"
-            )
-        if self.eps < 0.0:
-            raise ValueError(
-                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
-            )
+    def __init__(
+        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+    ):
+        super().__init__(learning_rate, betas, eps)
 
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
-        """Performs the Adagrad parameter update and stores :math:`v` in the
-        optimizer state."""
+        """Performs the Adamax parameter update and stores :math:`v` and
+        :math:`m` in the optimizer state."""
         lr = self.learning_rate
+        b1, b2 = self.betas
        eps = self.eps
 
+        m = state.get("m", mx.zeros_like(gradient))
         v = state.get("v", mx.zeros_like(gradient))
-        v = v + mx.square(gradient)
+
+        m = b1 * m + (1 - b1) * gradient
+        v = mx.maximum(b2 * v, mx.abs(gradient))
+        state["m"] = m
         state["v"] = v
 
-        return parameter - lr * gradient / (mx.sqrt(v) + eps)
+        return parameter - lr * m / (v + eps)
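
As a quick sanity check of the optimizers added in this diff, the sketch below runs one training step. It is illustrative only and not part of the patch: it assumes the existing mlx.nn module, nn.value_and_grad, and the Optimizer.update helper already defined in python/mlx/optimizers.py, and the toy model and loss are made up for the example.

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import AdaDelta  # RMSprop, Adagrad, or Adamax work the same way

model = nn.Linear(4, 2)                  # toy model, purely illustrative
optimizer = AdaDelta(learning_rate=1.0)  # rho=0.9, eps=1e-6 by default

def loss_fn(model, x, y):
    # Simple mean squared error on random data.
    return mx.mean(mx.square(model(x) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))

loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)           # calls apply_single for every parameter
mx.eval(model.parameters(), optimizer.state)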