mirror of https://github.com/ml-explore/mlx.git

Adds Nesterov momentum to SGD (#87)

This commit is contained in:
parent 08d51bf232
commit 430bfb4944
@@ -82,19 +82,36 @@ class SGD(Optimizer):

     .. math::

-        v_{t+1} &= \mu v_t + (1 - \mu) g_t \\
+        v_{t+1} &= \mu v_t + g_t \\
         w_{t+1} &= w_t - \lambda v_{t+1}

     Args:
         learning_rate (float): The learning rate :math:`\lambda` for the update
-        momentum (float): The momentum strength :math:`\mu`
+        momentum (float, optional): The momentum strength :math:`\mu` (default: 0)
+        weight_decay (float, optional): The weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): Dampening for momentum :math:`\tau` (default: 0)
+        nesterov (bool, optional): Enables Nesterov momentum (default: False)
     """

-    def __init__(self, learning_rate: float, momentum: float = 0.0):
+    def __init__(
+        self,
+        learning_rate: float,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        dampening: float = 0.0,
+        nesterov: bool = False,
+    ):
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening."
+            )
         super().__init__()

         self.learning_rate = learning_rate
         self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.dampening = dampening
+        self.nesterov = nesterov

     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
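Taken together, the new arguments change the momentum recursion documented above. As a sketch in the docstring's notation (the weight-decay coefficient is written here as :math:`\gamma`; the diff itself does not assign it a symbol):

.. math::

    g_t &\leftarrow g_t + \gamma w_t \\
    v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
    w_{t+1} &= w_t - \lambda (g_t + \mu v_{t+1}) \quad \text{(Nesterov)} \\
    w_{t+1} &= w_t - \lambda v_{t+1} \quad \text{(otherwise)}

With the defaults (:math:`\gamma = \tau = 0`, nesterov=False) this reduces to the plain heavy-ball rule shown in the updated docstring.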
@@ -105,9 +122,22 @@ class SGD(Optimizer):
             return parameter - self.learning_rate * gradient

         v = state.get("v", mx.zeros_like(gradient))
-        v = self.momentum * v + (1 - self.momentum) * gradient
+
+        if self.weight_decay != 0:
+            gradient += self.weight_decay * parameter
+
+        v = self.momentum * v
+        if self.dampening > 0:
+            v += (1 - self.dampening) * gradient
+        else:
+            v += gradient
+
+        if self.nesterov:
+            update = gradient + self.momentum * v
+        else:
+            update = v
         state["v"] = v
-        return parameter - self.learning_rate * v
+        return parameter - self.learning_rate * update


 class Adam(Optimizer):
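For context, a minimal usage sketch of the extended constructor. It assumes the optimizer is exposed as mlx.optimizers.SGD, matching the module this file lives in; the hyperparameter values are arbitrary:

    import mlx.optimizers as optim

    # Heavy-ball momentum with an L2 penalty on the weights.
    sgd = optim.SGD(learning_rate=0.1, momentum=0.9, weight_decay=1e-4)

    # Nesterov momentum: only valid with momentum > 0 and dampening == 0.
    nsgd = optim.SGD(learning_rate=0.1, momentum=0.9, nesterov=True)

    # The new check in __init__ rejects inconsistent settings.
    try:
        optim.SGD(learning_rate=0.1, nesterov=True)  # momentum defaults to 0
    except ValueError as err:
        print(err)  # "Nesterov momentum requires a momentum and zero dampening."

The per-parameter work still happens in apply_single, which the base Optimizer invokes for each gradient/parameter pair.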
@@ -184,7 +214,7 @@ class AdamW(Adam):
     def apply_single(
         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
     ):
         """Performs the AdamW parameter update by modifying the parameters
         passed into Adam.
         """