mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Power VJP fix for 0 (#505)
This commit is contained in:
parent
6bf779e72b
commit
b207c2c86b
@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
vjps.push_back(multiply(
|
||||
outputs[0], divide(primals[1], primals[0], stream()), stream()));
|
||||
power(
|
||||
primals[0],
|
||||
subtract(primals[1], array(1, primals[0].dtype()), stream()),
|
||||
stream()),
|
||||
primals[1],
|
||||
stream()));
|
||||
} else {
|
||||
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
|
||||
}
|
||||
|
@ -380,6 +380,19 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
|
||||
self.assertEqual(out.dtype, t)
|
||||
|
||||
def test_power_grad(self):
|
||||
x = mx.array(0.0)
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertEqual(g.item(), 0.0)
|
||||
|
||||
x = mx.array(0.0)
|
||||
g = mx.grad(lambda x: x**1.5)(x)
|
||||
self.assertEqual(g.item(), 0.0)
|
||||
|
||||
x = mx.array(2.0)
|
||||
g = mx.grad(lambda x: x**2)(x)
|
||||
self.assertAlmostEqual(g.item(), 4.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user