diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py
index 9a77328d3..fd5d9c732 100644
--- a/python/mlx/optimizers.py
+++ b/python/mlx/optimizers.py
@@ -82,19 +82,36 @@ class SGD(Optimizer):
 
     .. math::
 
-        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
+        v_{t+1} &= \mu v_t + g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
         learning_rate (float): The learning :math:`\lambda` for the update
-        momentum (float): The momentum strength :math:`\mu`
+        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
+        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
+        nesterov (bool, optional): Enables Nesterov momentum (default: False)
     """
 
-    def __init__(self, learning_rate: float, momentum: float = 0.0):
+    def __init__(
+        self,
+        learning_rate: float,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        dampening: float = 0.0,
+        nesterov: bool = False,
+    ):
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening."
+            )
         super().__init__()
 
         self.learning_rate = learning_rate
         self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.dampening = dampening
+        self.nesterov = nesterov
 
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
@@ -105,9 +122,22 @@ class SGD(Optimizer):
             return parameter - self.learning_rate * gradient
 
         v = state.get("v", mx.zeros_like(gradient))
-        v = self.momentum * v + (1 - self.momentum) * gradient
+
+        if self.weight_decay != 0:
+            gradient += self.weight_decay * parameter
+
+        v = self.momentum * v
+        if self.dampening > 0:
+            v += (1 - self.dampening) * gradient
+        else:
+            v += gradient
+
+        if self.nesterov:
+            update = gradient + self.momentum * v
+        else:
+            update = v
         state["v"] = v
-        return parameter - self.learning_rate * v
+        return parameter - self.learning_rate * update
 
 
 class Adam(Optimizer):
@@ -184,7 +214,7 @@ class AdamW(Adam):
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
-        """Performs the AdamW parameter update by modifying the parameters 
+        """Performs the AdamW parameter update by modifying the parameters
         passed into Adam.
         """
 
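
For reference, a minimal usage sketch of the new SGD options, assuming the patch above is applied. The parameter and gradient values are made up, and apply_single is called directly here only to illustrate one update step; in practice it is invoked per parameter by the optimizer's update machinery.

import mlx.core as mx
from mlx.optimizers import SGD

# SGD with classical momentum plus the new weight decay and Nesterov options.
# Nesterov requires a positive momentum and zero dampening, per the check in __init__.
opt = SGD(learning_rate=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)

w = mx.array([1.0, 2.0, 3.0])  # parameter
g = mx.array([0.1, 0.1, 0.1])  # gradient
state = {}                     # per-parameter state; "v" is created on first use

# First step: the gradient picks up the weight decay term, v equals that
# decayed gradient, and the Nesterov update is gradient + momentum * v.
w = opt.apply_single(g, w, state)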