mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
stable cumprod grad at 0 (#1167)
This commit is contained in:
@@ -423,6 +423,79 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
|
||||
self.assertEqual(grad.item(), 1.0)
|
||||
|
||||
def test_cumprod_grad(self):
|
||||
def fun(y):
|
||||
return mx.cumprod(y).sum()
|
||||
|
||||
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def fun(y):
|
||||
return mx.cumprod(y, inclusive=False).sum()
|
||||
|
||||
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def fun(y):
|
||||
return mx.cumprod(y, inclusive=False, reverse=True).sum()
|
||||
|
||||
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def fun(y):
|
||||
return mx.cumprod(y, reverse=True).sum()
|
||||
|
||||
y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
|
||||
out = mx.grad(fun)(y)
|
||||
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user