Fix type promotion in Adam with bias correction (#2350)

Angelos Katharopoulos 2025-07-10 11:14:42 -07:00 committed by GitHub
parent afb9817599
commit 0eb035b4b1
2 changed files with 11 additions and 2 deletions


@@ -526,8 +526,10 @@ class Adam(Optimizer):
         state["v"] = v
 
         if bias_correction:
-            numerator = lr / (1 - b1**step) * m
-            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
+            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
+            numerator = c1 * m
+            denominator = mx.sqrt(v) * c2 + eps
             return parameter - numerator / denominator
         else:
             return parameter - lr * m / (mx.sqrt(v) + eps)
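For context, a minimal sketch (not part of the diff) of the promotion the new casts avoid. It assumes the optimizer holds the learning rate and step as float32/int arrays, so the bias-correction coefficients come out as float32 arrays and promote float16 moments unless they are cast back to gradient.dtype:

import mlx.core as mx

m = mx.zeros((4,), dtype=mx.float16)   # float16 first moment
lr = mx.array(1e-2)                    # learning rate held as a float32 array
b1, step = 0.9, 1

# Before the fix: the float32 coefficient promotes the product to float32
c1 = lr / (1 - b1**step)
print((c1 * m).dtype)                  # float32

# After the fix: cast the coefficient to the moment/gradient dtype first
c1 = (lr / (1 - b1**step)).astype(m.dtype)
print((c1 * m).dtype)                  # float16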


@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
                 )
             )
+
+        # Test for correct gradient type propagation
+        params = tree_map(lambda x: x.astype(mx.float16), params)
+        grads = tree_map(lambda x: x.astype(mx.float16), grads)
+        optim = opt.Adam(1e-2, bias_correction=True)
+        new_params = optim.apply_gradients(grads, params)
+        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
 
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_adamw_matches_pytorch(self):
         mx.random.seed(0)
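As a user-level check of the same behavior (a hedged sketch, not part of the commit; the parameter tree and shapes are made up for illustration), a float16 parameter tree stepped through Adam with bias_correction=True should keep its dtype after this fix:

import mlx.core as mx
import mlx.optimizers as opt
from mlx.utils import tree_map

# Hypothetical float16 parameter/gradient trees for illustration
params = {"w": mx.zeros((8, 8), dtype=mx.float16)}
grads = tree_map(lambda x: mx.ones_like(x), params)

optimizer = opt.Adam(1e-2, bias_correction=True)
new_params = optimizer.apply_gradients(grads, params)
print(new_params["w"].dtype)           # float16 after the fix (float32 before)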