mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-04 08:11:13 +08:00
Fix for AdaDelta (#603)
- state was being read from parameter "s" - but being stored in parameter "u"
This commit is contained in:
parent
ba8d6bf365
commit
601c6d6aa8
@ -284,7 +284,7 @@ class AdaDelta(Optimizer):
|
||||
eps = self.eps
|
||||
|
||||
v = state.get("v", mx.zeros_like(gradient))
|
||||
u = state.get("s", mx.zeros_like(gradient))
|
||||
u = state.get("u", mx.zeros_like(gradient))
|
||||
|
||||
v = rho * v + (1 - rho) * mx.square(gradient)
|
||||
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
|
||||
|
Loading…
Reference in New Issue
Block a user