diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py index c47234bd5..b659ec5cf 100644 --- a/python/mlx/optimizers.py +++ b/python/mlx/optimizers.py @@ -284,7 +284,7 @@ class AdaDelta(Optimizer): eps = self.eps v = state.get("v", mx.zeros_like(gradient)) - u = state.get("s", mx.zeros_like(gradient)) + u = state.get("u", mx.zeros_like(gradient)) v = rho * v + (1 - rho) * mx.square(gradient) d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient