Power VJP fix for 0 (#505)

This commit is contained in:
Awni Hannun 2024-01-20 01:17:40 -08:00 committed by GitHub
parent 6bf779e72b
commit b207c2c86b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 1 deletions

View File

@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(multiply(
outputs[0], divide(primals[1], primals[0], stream()), stream()));
power(
primals[0],
subtract(primals[1], array(1, primals[0].dtype()), stream()),
stream()),
primals[1],
stream()));
} else {
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
}

View File

@ -380,6 +380,19 @@ class TestAutograd(mlx_tests.MLXTestCase):
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
self.assertEqual(out.dtype, t)
def test_power_grad(self):
x = mx.array(0.0)
g = mx.grad(lambda x: x**2)(x)
self.assertEqual(g.item(), 0.0)
x = mx.array(0.0)
g = mx.grad(lambda x: x**1.5)(x)
self.assertEqual(g.item(), 0.0)
x = mx.array(2.0)
g = mx.grad(lambda x: x**2)(x)
self.assertAlmostEqual(g.item(), 4.0)
if __name__ == "__main__":
unittest.main()