diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index cc91d7147..3b96a576d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1852,7 +1852,12 @@ std::vector Power::vjp( for (auto arg : argnums) { if (arg == 0) { vjps.push_back(multiply( - outputs[0], divide(primals[1], primals[0], stream()), stream())); + power( + primals[0], + subtract(primals[1], array(1, primals[0].dtype()), stream()), + stream()), + primals[1], + stream())); } else { vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream())); } diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index c7edc8b76..78f7346a8 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -380,6 +380,19 @@ class TestAutograd(mlx_tests.MLXTestCase): out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t)) self.assertEqual(out.dtype, t) + def test_power_grad(self): + x = mx.array(0.0) + g = mx.grad(lambda x: x**2)(x) + self.assertEqual(g.item(), 0.0) + + x = mx.array(0.0) + g = mx.grad(lambda x: x**1.5)(x) + self.assertEqual(g.item(), 0.0) + + x = mx.array(2.0) + g = mx.grad(lambda x: x**2)(x) + self.assertAlmostEqual(g.item(), 4.0) + if __name__ == "__main__": unittest.main()