mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
Fix type promotion in Adam with bias correction (#2350)
This commit is contained in:
committed by
GitHub
parent
afb9817599
commit
0eb035b4b1
@@ -526,8 +526,10 @@ class Adam(Optimizer):
|
||||
state["v"] = v
|
||||
|
||||
if bias_correction:
|
||||
numerator = lr / (1 - b1**step) * m
|
||||
denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
|
||||
c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
|
||||
c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
|
||||
numerator = c1 * m
|
||||
denominator = mx.sqrt(v) * c2 + eps
|
||||
return parameter - numerator / denominator
|
||||
else:
|
||||
return parameter - lr * m / (mx.sqrt(v) + eps)
|
||||
|
||||
Reference in New Issue
Block a user