Adds Nesterov momentum to SGD (#87)
commit 430bfb4944
parent 08d51bf232
@@ -82,19 +82,36 @@ class SGD(Optimizer):
 
     .. math::
 
-        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
+        v_{t+1} &= \mu v_t + g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
         learning_rate (float): The learning rate :math:`\lambda` for the update
-        momentum (float): The momentum strength :math:`\mu`
+        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
+        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
+        nesterov (bool, optional): Enables Nesterov momentum (default: False)
     """
 
-    def __init__(self, learning_rate: float, momentum: float = 0.0):
+    def __init__(
+        self,
+        learning_rate: float,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        dampening: float = 0.0,
+        nesterov: bool = False,
+    ):
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening."
+            )
         super().__init__()
 
         self.learning_rate = learning_rate
         self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.dampening = dampening
+        self.nesterov = nesterov
 
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
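A quick illustration of the new constructor guard; a minimal sketch assuming the `mlx.optimizers` module that houses `SGD` in this repository:

    import mlx.optimizers as optim

    # Accepted: Nesterov momentum with a positive momentum and the default
    # (zero) dampening.
    optim.SGD(learning_rate=0.1, momentum=0.9, nesterov=True)

    try:
        # Rejected: nesterov=True with momentum left at 0 (a nonzero dampening
        # is rejected the same way).
        optim.SGD(learning_rate=0.1, nesterov=True)
    except ValueError as err:
        print(err)  # Nesterov momentum requires a momentum and zero dampening.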
@@ -105,9 +122,22 @@ class SGD(Optimizer):
             return parameter - self.learning_rate * gradient
 
         v = state.get("v", mx.zeros_like(gradient))
-        v = self.momentum * v + (1 - self.momentum) * gradient
+
+        if self.weight_decay != 0:
+            gradient += self.weight_decay * parameter
+
+        v = self.momentum * v
+        if self.dampening > 0:
+            v += (1 - self.dampening) * gradient
+        else:
+            v += gradient
+
+        if self.nesterov:
+            update = gradient + self.momentum * v
+        else:
+            update = v
         state["v"] = v
-        return parameter - self.learning_rate * v
+        return parameter - self.learning_rate * update
 
 
 class Adam(Optimizer):
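For reference, the update rule the new `apply_single` implements can be restated in the docstring's notation. This summary is not part of the diff; the symbols :math:`\lambda_{wd}` (weight decay) and :math:`u_{t+1}` (the applied update) are labels introduced here:

    .. math::

        g_t     &\leftarrow g_t + \lambda_{wd} w_t \\
        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        u_{t+1} &= \begin{cases} g_t + \mu v_{t+1} & \text{if Nesterov} \\ v_{t+1} & \text{otherwise} \end{cases} \\
        w_{t+1} &= w_t - \lambda u_{t+1}

With `weight_decay` and `dampening` left at their defaults of 0 and `nesterov=False`, this reduces to the two-line update shown in the docstring.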
@@ -184,7 +214,7 @@ class AdamW(Adam):
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
-        """Performs the AdamW parameter update by modifying the parameters
+        """Performs the AdamW parameter update by modifying the parameters
         passed into Adam.
         """
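To close, a minimal usage sketch of the extended optimizer in a single training step. It assumes MLX's `nn.value_and_grad` and `Optimizer.update` training-loop APIs; the tiny model, loss, and random data are placeholders:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)  # placeholder model


    def loss_fn(model, x, y):
        # Plain mean-squared error keeps the example self-contained.
        return mx.mean((model(x) - y) ** 2)


    # Nesterov SGD with weight decay; momentum must be positive and
    # dampening must stay 0 for nesterov=True to be accepted.
    opt = optim.SGD(
        learning_rate=0.1,
        momentum=0.9,
        weight_decay=1e-4,
        nesterov=True,
    )

    x = mx.random.normal((8, 4))
    y = mx.random.normal((8, 1))

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad(model, x, y)
    opt.update(model, grads)  # runs apply_single on every parameter
    mx.eval(model.parameters(), opt.state)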