Fix SGD implementation (#473)

This commit is contained in:
Jacket 2024-01-30 17:50:46 -06:00 committed by GitHub
parent 375446453e
commit 143e2690d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -118,18 +118,21 @@ class SGD(Optimizer):
):
"""Performs the SGD parameter update and stores :math:`v` in the
optimizer state."""
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
v = state.get("v", mx.zeros_like(gradient))
if self.weight_decay != 0:
gradient += self.weight_decay * parameter
v = self.momentum * v
if self.momentum <= 0:
return parameter - self.learning_rate * gradient
if self.dampening > 0:
v = (
state.get("v", (self.dampening / self.momentum) * gradient)
* self.momentum
)
v += (1 - self.dampening) * gradient
else:
v = state.get("v", mx.zeros_like(gradient)) * self.momentum
v += gradient
if self.nesterov: