mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-06 01:11:12 +08:00
Fix for AdaDelta (#603)
- state was being read from parameter "s" - but being stored in parameter "u"
This commit is contained in:
parent
ba8d6bf365
commit
601c6d6aa8
@ -284,7 +284,7 @@ class AdaDelta(Optimizer):
|
|||||||
eps = self.eps
|
eps = self.eps
|
||||||
|
|
||||||
v = state.get("v", mx.zeros_like(gradient))
|
v = state.get("v", mx.zeros_like(gradient))
|
||||||
u = state.get("s", mx.zeros_like(gradient))
|
u = state.get("u", mx.zeros_like(gradient))
|
||||||
|
|
||||||
v = rho * v + (1 - rho) * mx.square(gradient)
|
v = rho * v + (1 - rho) * mx.square(gradient)
|
||||||
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
|
d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
|
||||||
|
Loading…
Reference in New Issue
Block a user