From 143e2690d54dad348a3e6ae0e5bdb0b98547cf75 Mon Sep 17 00:00:00 2001 From: Jacket <44538064+PRESIDENT810@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:50:46 -0600 Subject: [PATCH] Fix SGD implementation (#473) --- python/mlx/optimizers.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py index 8b9965e78..1541b41cc 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers.py @@ -118,18 +118,21 @@ class SGD(Optimizer): ): """Performs the SGD parameter update and stores :math:`v` in the optimizer state.""" - if self.momentum <= 0: - return parameter - self.learning_rate * gradient - - v = state.get("v", mx.zeros_like(gradient)) if self.weight_decay != 0: gradient += self.weight_decay * parameter - v = self.momentum * v + if self.momentum <= 0: + return parameter - self.learning_rate * gradient + if self.dampening > 0: + v = ( + state.get("v", (self.dampening / self.momentum) * gradient) + * self.momentum + ) v += (1 - self.dampening) * gradient else: + v = state.get("v", mx.zeros_like(gradient)) * self.momentum v += gradient if self.nesterov: