stable cumprod grad at 0 (#1167)

This commit is contained in:
Awni Hannun
2024-05-31 12:28:42 -07:00
committed by GitHub
parent 76b6cece46
commit fd1c08137b
2 changed files with 119 additions and 6 deletions

View File

@@ -423,6 +423,79 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
self.assertEqual(grad.item(), 1.0)
def test_cumprod_grad(self):
def fun(y):
return mx.cumprod(y).sum()
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
self.assertTrue(mx.allclose(out, expected))
def fun(y):
return mx.cumprod(y, inclusive=False).sum()
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
self.assertTrue(mx.allclose(out, expected))
def fun(y):
return mx.cumprod(y, inclusive=False, reverse=True).sum()
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))
def fun(y):
return mx.cumprod(y, reverse=True).sum()
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
self.assertTrue(mx.allclose(out, expected))
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
out = mx.grad(fun)(y)
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))
if __name__ == "__main__":
unittest.main()