Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-24 10:51:21 +08:00)
Fix SGD implementation (#473)
commit 143e2690d5 (parent 375446453e)
@@ -118,18 +118,21 @@ class SGD(Optimizer):
     ):
         """Performs the SGD parameter update and stores :math:`v` in the
         optimizer state."""
-        if self.momentum <= 0:
-            return parameter - self.learning_rate * gradient
-
-        v = state.get("v", mx.zeros_like(gradient))
-
         if self.weight_decay != 0:
             gradient += self.weight_decay * parameter
 
-        v = self.momentum * v
+        if self.momentum <= 0:
+            return parameter - self.learning_rate * gradient
+
         if self.dampening > 0:
+            v = (
+                state.get("v", (self.dampening / self.momentum) * gradient)
+                * self.momentum
+            )
             v += (1 - self.dampening) * gradient
         else:
+            v = state.get("v", mx.zeros_like(gradient)) * self.momentum
             v += gradient
 
         if self.nesterov:
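Why the dampening branch now reseeds `v`: before the fix, the state was always initialized to zeros, so the first update with dampening enabled was `v = momentum * 0 + (1 - dampening) * g`, scaling the very first step down by `(1 - dampening)`. Seeding the state with `(dampening / momentum) * g` instead makes the first step come out to `v = momentum * (dampening / momentum) * g + (1 - dampening) * g = g`, i.e. dampening is not applied to the first update, matching PyTorch's SGD convention. The `momentum <= 0` early return also moves below the weight-decay step, so plain momentum-free SGD no longer skips weight decay. Below is a minimal standalone sketch of the patched rule for illustration only; the function name `sgd_step` and its keyword arguments are hypothetical, not MLX API, and the tail of the method (nesterov handling and storing `v`) follows the standard SGD-with-momentum recipe that the diff leaves unchanged.

import mlx.core as mx

def sgd_step(parameter, gradient, state, *, learning_rate=1e-2,
             momentum=0.9, dampening=0.5, weight_decay=0.0, nesterov=False):
    # Weight decay now runs before the momentum check, so it also takes
    # effect when momentum is disabled (part of the fix).
    if weight_decay != 0:
        gradient = gradient + weight_decay * parameter

    if momentum <= 0:
        return parameter - learning_rate * gradient

    if dampening > 0:
        # Seeding with (dampening / momentum) * g makes the first step
        # v = momentum * (dampening / momentum) * g + (1 - dampening) * g = g,
        # so dampening is skipped on the very first update.
        v = state.get("v", (dampening / momentum) * gradient) * momentum
        v += (1 - dampening) * gradient
    else:
        v = state.get("v", mx.zeros_like(gradient)) * momentum
        v += gradient

    update = gradient + momentum * v if nesterov else v
    state["v"] = v
    return parameter - learning_rate * update

A quick check of the first-step behavior with fresh (empty) state:

state = {}
p, g = mx.array([1.0]), mx.array([0.5])
p = sgd_step(p, g, state)  # first step subtracts exactly learning_rate * g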