mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Power VJP fix for 0 (#505)
This commit is contained in:
@@ -380,6 +380,19 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
|
||||
self.assertEqual(out.dtype, t)
|
||||
|
||||
def test_power_grad(self):
|
||||
x = mx.array(0.0)
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertEqual(g.item(), 0.0)
|
||||
|
||||
x = mx.array(0.0)
|
||||
g = mx.grad(lambda x: x**1.5)(x)
|
||||
self.assertEqual(g.item(), 0.0)
|
||||
|
||||
x = mx.array(2.0)
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertAlmostEqual(g.item(), 4.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user