Fix for AdaDelta (#603)

- The optimizer state was being read from the state key "s",
- but it was being stored under the state key "u".
This commit is contained in:
David Koski 2024-02-01 09:56:27 -08:00 committed by GitHub
parent ba8d6bf365
commit 601c6d6aa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -284,7 +284,7 @@ class AdaDelta(Optimizer):
eps = self.eps
v = state.get("v", mx.zeros_like(gradient))
-u = state.get("s", mx.zeros_like(gradient))
+u = state.get("u", mx.zeros_like(gradient))
v = rho * v + (1 - rho) * mx.square(gradient)
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient