From fd1c08137bd2a87db747196211786b1f2fb197f0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 31 May 2024 12:28:42 -0700
Subject: [PATCH] stable cumprod grad at 0 (#1167)

---
 mlx/primitives.cpp            | 52 ++++++++++++++++++++++---
 python/tests/test_autograd.py | 73 +++++++++++++++++++++++++++++++++++
 2 files changed, 119 insertions(+), 6 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 1bbc963c6..21cdcb4e9 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2748,12 +2748,52 @@ std::vector<array> Scan::vjp(
   if (reduce_type_ == Scan::Sum) {
     return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
   } else if (reduce_type_ == Scan::Prod) {
-    // TODO: Make it numerically stable when we introduce where()
-    auto prod = outputs[0];
-    auto partial_grads = multiply(prod, cotangents[0], stream());
-    auto accum_grads =
-        cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
-    return {divide(accum_grads, primals[0], stream())};
+    auto in = primals[0];
+    // Find the location of the first 0 and set it to 1:
+    // - A: Exclusive cumprod
+    // - B: Inclusive cumprod
+    // - Find the location that is 0 in A and not zero in B
+    // Compute the gradient by:
+    // - Compute the regular gradient for everything before the first zero
+    // - Set the first zero to 1 and redo the computation, use this for the
+    //   gradient of the first zero
+    // - Everything after the first zero has a gradient of 0
+
+    // Get inclusive and exclusive cum prods
+    auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());
+    auto cprod_inclusive = outputs[0];
+    if (!inclusive_) {
+      std::swap(cprod_exclusive, cprod_inclusive);
+    }
+
+    // Make the mask for the first zero
+    auto z = array(0, in.dtype());
+    auto eq_zero = equal(cprod_inclusive, z, stream());
+    auto first_zero =
+        logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());
+
+    auto to_partial_grad = [this, &cotangents](const array& arr) {
+      return cumsum(
+          multiply(arr, cotangents[0], stream()),
+          axis_,
+          !reverse_,
+          inclusive_,
+          stream());
+    };
+
+    auto cprod_with_one = cumprod(
+        where(first_zero, array(1, in.dtype()), in, stream()),
+        axis_,
+        reverse_,
+        inclusive_,
+        stream());
+    auto grad_with_one = to_partial_grad(cprod_with_one);
+    auto grad = divide(to_partial_grad(outputs[0]), in, stream());
+    return {where(
+        first_zero,
+        grad_with_one,
+        where(eq_zero, z, grad, stream()),
+        stream())};
   } else {
     // Can probably be implemented by equals and then cummax to make the mask
     throw std::runtime_error("VJP is not implemented for cumulative min/max");
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 28054fb9b..f5e49f402 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -423,6 +423,79 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
         self.assertEqual(grad.item(), 1.0)
 
+    def test_cumprod_grad(self):
+        def fun(y):
+            return mx.cumprod(y).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
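For reference, the idea behind the change: the old VJP divided the accumulated sum of cumprod(x) * cotangent by x, which blows up wherever x is zero. The patch instead locates the first zero along the scan (zero in the inclusive cumprod but nonzero in the exclusive one), recomputes the product with that entry replaced by one to obtain its gradient, and forces the gradient after the first zero to zero. Below is a minimal NumPy sketch of the same zero-stabilization idea, covering only the 1-D, forward, inclusive case; the name cumprod_vjp and the use of NumPy are illustrative assumptions, not MLX API. It reproduces the expected values in test_cumprod_grad for mx.cumprod(y).sum().

import numpy as np

def cumprod_vjp(x, cot):
    # Sketch of a zero-stable cumprod VJP for a 1-D forward inclusive scan,
    # using NumPy as a stand-in for the MLX ops in the patch.
    x = np.asarray(x, dtype=float)
    cot = np.asarray(cot, dtype=float)

    incl = np.cumprod(x)                       # inclusive cumprod (outputs[0])
    excl = np.concatenate(([1.0], incl[:-1]))  # exclusive cumprod

    eq_zero = incl == 0                   # at or after the first zero
    first_zero = eq_zero & (excl != 0)    # exactly the first zero in the scan

    def partial_grad(prod):
        # reverse cumulative sum of prod * cotangent
        return np.cumsum((prod * cot)[::-1])[::-1]

    # Regular gradient; only valid before the first zero (divides by x).
    with np.errstate(divide="ignore", invalid="ignore"):
        grad = partial_grad(incl) / x

    # Redo the product with the first zero replaced by one to get its gradient.
    grad_with_one = partial_grad(np.cumprod(np.where(first_zero, 1.0, x)))

    # After the first zero the gradient is exactly zero.
    return np.where(first_zero, grad_with_one, np.where(eq_zero, 0.0, grad))

# Matches the expected values for mx.cumprod(y).sum() in test_cumprod_grad:
print(cumprod_vjp([2.0, 1.0, 2.0, 2.0, 3.0], np.ones(5)))  # [20. 38. 18. 16.  8.]
print(cumprod_vjp([2.0, 0.0, 2.0, 2.0, 3.0], np.ones(5)))  # [ 1. 38.  0.  0.  0.]
print(cumprod_vjp([2.0, 0.0, 2.0, 0.0, 3.0], np.ones(5)))  # [1. 6. 0. 0. 0.]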