From 0eb035b4b1922a8b3c5f76092a42b83447851a93 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 10 Jul 2025 11:14:42 -0700
Subject: [PATCH] Fix type promotion in Adam with bias correction (#2350)

---
 python/mlx/optimizers/optimizers.py | 6 ++++--
 python/tests/test_optimizers.py     | 7 +++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 09857dd0a..26b732ebd 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -526,8 +526,10 @@ class Adam(Optimizer):
         state["v"] = v
 
         if bias_correction:
-            numerator = lr / (1 - b1**step) * m
-            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
+            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
+            numerator = c1 * m
+            denominator = mx.sqrt(v) * c2 + eps
             return parameter - numerator / denominator
         else:
             return parameter - lr * m / (mx.sqrt(v) + eps)
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index e07fc8456..8f9e33679 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
             )
         )
 
+        # Test for correct gradient type propagation
+        params = tree_map(lambda x: x.astype(mx.float16), params)
+        grads = tree_map(lambda x: x.astype(mx.float16), grads)
+        optim = opt.Adam(1e-2, bias_correction=True)
+        new_params = optim.apply_gradients(grads, params)
+        self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_adamw_matches_pytorch(self):
         mx.random.seed(0)
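
Not part of the patch: a minimal standalone sketch of the promotion issue the change addresses, using placeholder values for lr, b1, step, and m and assuming MLX's usual scalar/array promotion rules. With bias correction, the correction factor 1 - b1**step is a float32 array (because step is an mx.array), so the old expression promoted float16 moments, and hence the updated parameters, to float32; casting the scalar factor to the gradient dtype keeps the update in half precision.

    import mlx.core as mx

    lr, b1 = 1e-2, 0.9
    step = mx.array(1)                   # step is an mx.array, so 1 - b1**step is a float32 array
    m = mx.ones((4,), dtype=mx.float16)  # first-moment estimate in half precision

    # Old expression: the float32 correction factor promotes the float16 moment.
    numerator_old = lr / (1 - b1**step) * m
    print(numerator_old.dtype)           # float32

    # Patched approach: cast the scalar factor to the gradient dtype first.
    c1 = (lr / (1 - b1**step)).astype(m.dtype)
    print((c1 * m).dtype)                # float16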