mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Power VJP fix for 0 (#505)
This commit is contained in:
parent
6bf779e72b
commit
b207c2c86b
@ -1852,7 +1852,12 @@ std::vector<array> Power::vjp(
|
|||||||
for (auto arg : argnums) {
|
for (auto arg : argnums) {
|
||||||
if (arg == 0) {
|
if (arg == 0) {
|
||||||
vjps.push_back(multiply(
|
vjps.push_back(multiply(
|
||||||
outputs[0], divide(primals[1], primals[0], stream()), stream()));
|
power(
|
||||||
|
primals[0],
|
||||||
|
subtract(primals[1], array(1, primals[0].dtype()), stream()),
|
||||||
|
stream()),
|
||||||
|
primals[1],
|
||||||
|
stream()));
|
||||||
} else {
|
} else {
|
||||||
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
|
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
|
||||||
}
|
}
|
||||||
|
@ -380,6 +380,19 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
|
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
|
||||||
self.assertEqual(out.dtype, t)
|
self.assertEqual(out.dtype, t)
|
||||||
|
|
||||||
|
def test_power_grad(self):
|
||||||
|
x = mx.array(0.0)
|
||||||
|
g = mx.grad(lambda x: x**2)(x)
|
||||||
|
self.assertEqual(g.item(), 0.0)
|
||||||
|
|
||||||
|
x = mx.array(0.0)
|
||||||
|
g = mx.grad(lambda x: x**1.5)(x)
|
||||||
|
self.assertEqual(g.item(), 0.0)
|
||||||
|
|
||||||
|
x = mx.array(2.0)
|
||||||
|
g = mx.grad(lambda x: x**2)(x)
|
||||||
|
self.assertAlmostEqual(g.item(), 4.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user