Mirror of https://github.com/ml-explore/mlx.git
Fix type promotion in Adam with bias correction (#2350)
commit 0eb035b4b1
parent afb9817599
@@ -526,8 +526,10 @@ class Adam(Optimizer):
         state["v"] = v
 
         if bias_correction:
-            numerator = lr / (1 - b1**step) * m
-            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
+            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
+            numerator = c1 * m
+            denominator = mx.sqrt(v) * c2 + eps
             return parameter - numerator / denominator
         else:
             return parameter - lr * m / (mx.sqrt(v) + eps)
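Why the cast matters — a minimal sketch, assuming a recent MLX build (the shapes and values below are illustrative, not from the commit): the optimizer keeps `step` as an integer mx.array, so `lr / (1 - b1**step)` evaluates to a float32 array, and multiplying it into a float16 moment promotes the whole update, and hence the returned parameter, to float32. Casting the correction factors to `gradient.dtype`, as the new `c1`/`c2` lines do, keeps the update in the parameter's own precision:

    import mlx.core as mx

    lr, b1 = 1e-2, 0.9
    step = mx.array(1)                        # integer step counter, like the optimizer state
    m = mx.zeros((4,), dtype=mx.float16)      # float16 first moment

    promoted = lr / (1 - b1**step) * m                   # old path: promotes to float32
    kept = (lr / (1 - b1**step)).astype(m.dtype) * m     # new path: stays float16

    print(promoted.dtype, kept.dtype)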
@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
             )
         )
 
+        # Test for correct gradient type propagation
+        params = tree_map(lambda x: x.astype(mx.float16), params)
+        grads = tree_map(lambda x: x.astype(mx.float16), grads)
+        optim = opt.Adam(1e-2, bias_correction=True)
+        new_params = optim.apply_gradients(grads, params)
+        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_adamw_matches_pytorch(self):
         mx.random.seed(0)
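For context, a hedged end-to-end sketch of the behaviour the new test asserts (the single-entry parameter tree below is illustrative; `tree_map` comes from `mlx.utils`): with float16 parameters and gradients, `apply_gradients` on an Adam instance created with `bias_correction=True` should now return float16 parameters instead of silently promoting them to float32:

    import mlx.core as mx
    import mlx.optimizers as opt
    from mlx.utils import tree_map

    params = {"w": mx.ones((4,), dtype=mx.float16)}
    grads = tree_map(lambda x: mx.ones_like(x), params)

    optim = opt.Adam(1e-2, bias_correction=True)
    new_params = optim.apply_gradients(grads, params)
    print(new_params["w"].dtype)              # float16 with this fix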