Fix type promotion in Adam with bias correction

This commit is contained in:
Angelos Katharopoulos
2025-07-09 18:08:36 -07:00
parent fb4e8b896b
commit 067950ce00
2 changed files with 11 additions and 2 deletions

View File

@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
)
)
# Test for correct gradient type propagation
params = tree_map(lambda x: x.astype(mx.float16), params)
grads = tree_map(lambda x: x.astype(mx.float16), grads)
optim = opt.Adam(1e-2, bias_correction=True)
new_params = optim.apply_gradients(grads, params)
self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
@unittest.skipIf(not has_torch, "requires Torch")
def test_adamw_matches_pytorch(self):
mx.random.seed(0)