Mirror of https://github.com/ml-explore/mlx.git, synced 2025-07-27 11:31:21 +08:00
Fix SGD implementation (#473)

parent 375446453e
commit 143e2690d5
@@ -118,18 +118,21 @@ class SGD(Optimizer):
     ):
         """Performs the SGD parameter update and stores :math:`v` in the
         optimizer state."""
-        if self.momentum <= 0:
-            return parameter - self.learning_rate * gradient
-
-        v = state.get("v", mx.zeros_like(gradient))
-
         if self.weight_decay != 0:
             gradient += self.weight_decay * parameter
 
-        v = self.momentum * v
+        if self.momentum <= 0:
+            return parameter - self.learning_rate * gradient
+
         if self.dampening > 0:
+            v = (
+                state.get("v", (self.dampening / self.momentum) * gradient)
+                * self.momentum
+            )
             v += (1 - self.dampening) * gradient
         else:
+            v = state.get("v", mx.zeros_like(gradient)) * self.momentum
             v += gradient
 
         if self.nesterov:
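
The fix does two things: weight decay is now applied before the zero-momentum early return (previously, setting momentum to 0 silently skipped weight decay), and the velocity buffer is seeded so that the first dampened step reduces to a plain gradient step, since momentum * ((dampening / momentum) * gradient) + (1 - dampening) * gradient == gradient. A minimal standalone sketch of the corrected rule follows; NumPy stands in for mlx.core, the name sgd_update and the bare state dict are illustrative rather than the mlx API, and the nesterov/state-write tail is assumed from the unchanged code below the hunk.

# Minimal sketch of the corrected update, assuming standard SGD semantics.
# NumPy replaces mx; sgd_update and the plain dict state are hypothetical.
import numpy as np

def sgd_update(parameter, gradient, state, lr=0.01, momentum=0.9,
               weight_decay=0.0, dampening=0.0, nesterov=False):
    # Apply weight decay first so it also takes effect when momentum == 0.
    if weight_decay != 0:
        gradient = gradient + weight_decay * parameter

    if momentum <= 0:
        return parameter - lr * gradient

    if dampening > 0:
        # Seeding v with (dampening / momentum) * gradient makes the first
        # step equal a plain gradient step:
        # momentum * v + (1 - dampening) * gradient == gradient.
        v = state.get("v", (dampening / momentum) * gradient) * momentum
        v = v + (1 - dampening) * gradient
    else:
        v = state.get("v", np.zeros_like(gradient)) * momentum
        v = v + gradient

    # Tail assumed from the surrounding, unchanged code.
    state["v"] = v
    update = gradient + momentum * v if nesterov else v
    return parameter - lr * update

# Example: the first dampened step equals lr * gradient.
state = {}
p, g = np.ones(3), np.full(3, 0.5)
p = sgd_update(p, g, state, lr=0.1, momentum=0.9, dampening=0.1)  # -> 0.95

On the first call with dampening > 0 the update equals lr * gradient, matching the undampened case; later calls accumulate the dampened momentum as usual.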