Adds Nesterov momentum to SGD (#87)

Abe Leininger 2023-12-09 02:23:36 -05:00 committed by GitHub
parent 08d51bf232
commit 430bfb4944


@@ -82,19 +82,36 @@ class SGD(Optimizer):
     .. math::
 
-        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
+        v_{t+1} &= \mu v_t + g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
         learning_rate (float): The learning rate :math:`\lambda` for the update
-        momentum (float): The momentum strength :math:`\mu`
+        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
+        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
+        nesterov (bool, optional): Enables Nesterov momentum (default: False)
     """
 
-    def __init__(self, learning_rate: float, momentum: float = 0.0):
+    def __init__(
+        self,
+        learning_rate: float,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        dampening: float = 0.0,
+        nesterov: bool = False,
+    ):
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening."
+            )
         super().__init__()
 
         self.learning_rate = learning_rate
         self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.dampening = dampening
+        self.nesterov = nesterov
 
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
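
For reference, a minimal sketch of how the new keyword arguments are intended to be used. The import path is an assumption; the argument names, defaults, and the error message come from the diff above.

from mlx.optimizers import SGD  # assumed import path; adjust to your checkout

# Classical momentum plus L2 weight decay.
opt = SGD(learning_rate=1e-2, momentum=0.9, weight_decay=1e-4)

# Nesterov momentum requires momentum > 0 and dampening == 0,
# otherwise the new constructor check raises:
try:
    SGD(learning_rate=1e-2, nesterov=True)
except ValueError as err:
    print(err)  # "Nesterov momentum requires a momentum and zero dampening."

opt_nesterov = SGD(learning_rate=1e-2, momentum=0.9, nesterov=True)
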
@@ -105,9 +122,22 @@ class SGD(Optimizer):
             return parameter - self.learning_rate * gradient
 
         v = state.get("v", mx.zeros_like(gradient))
-        v = self.momentum * v + (1 - self.momentum) * gradient
+        if self.weight_decay != 0:
+            gradient += self.weight_decay * parameter
+
+        v = self.momentum * v
+        if self.dampening > 0:
+            v += (1 - self.dampening) * gradient
+        else:
+            v += gradient
+
+        if self.nesterov:
+            update = gradient + self.momentum * v
+        else:
+            update = v
+
         state["v"] = v
-        return parameter - self.learning_rate * v
+        return parameter - self.learning_rate * update
class Adam(Optimizer):
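
Taken together, the new apply_single body implements the following update, writing :math:`\omega` for weight_decay and :math:`\tau` for dampening. These symbols and the Nesterov line are read off the code above rather than taken from the docstring, and the constructor forces :math:`\tau = 0` whenever nesterov is enabled.

.. math::
    g_t &\leftarrow g_t + \omega w_t \\
    v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
    w_{t+1} &= w_t - \lambda (g_t + \mu v_{t+1}) \quad \text{(nesterov)} \\
    w_{t+1} &= w_t - \lambda v_{t+1} \quad \text{(otherwise)}
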
@@ -184,7 +214,7 @@ class AdamW(Adam):
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
-        """Performs the AdamW parameter update by modifying the parameters
+        """Performs the AdamW parameter update by modifying the parameters
         passed into Adam.
         """