mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Fix type promotion in Adam with bias correction (#2350)
This commit is contained in:

committed by
GitHub

parent
afb9817599
commit
0eb035b4b1
@@ -196,6 +196,13 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
# Test for correct gradient type propagation
|
||||
params = tree_map(lambda x: x.astype(mx.float16), params)
|
||||
grads = tree_map(lambda x: x.astype(mx.float16), grads)
|
||||
optim = opt.Adam(1e-2, bias_correction=True)
|
||||
new_params = optim.apply_gradients(grads, params)
|
||||
self.assertTrue(tree_equal(lambda p: p.dtype == mx.float16, new_params))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_adamw_matches_pytorch(self):
|
||||
mx.random.seed(0)
|
||||
|
Reference in New Issue
Block a user