Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-21 16:51:15 +08:00)
stable cumprod grad at 0 (#1167)
commit fd1c08137b
parent 76b6cece46
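Background for the diff below: the previous VJP for the cumulative product was cumsum(cumprod(x) * cotangent) / x, which performs a 0/0 division at every zero of the input, so the gradient comes out as nan there. The new branch in Scan::vjp instead finds the first zero along the scan axis, handles it separately, and forces the gradient of everything after it to zero. A minimal Python sketch of the symptom and of the behaviour the new tests pin down (variable names are illustrative only, and the "stable" line assumes an mlx build that contains this commit):

# Minimal sketch; not part of the commit.
import mlx.core as mx

y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
cotan = mx.ones_like(y)  # cotangent of cumprod(y).sum()

# Old rule: accumulate partial gradients, then divide by the input.
# The division is 0/0 at every zero of y, giving nan there.
naive = mx.cumsum(mx.cumprod(y) * cotan, reverse=True) / y

# New rule, as exercised by mx.grad after this commit.
stable = mx.grad(lambda x: mx.cumprod(x).sum())(y)

print(naive)   # [1, nan, 0, 0, 0]
print(stable)  # [1, 38, 0, 0, 0]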
@@ -2748,12 +2748,52 @@ std::vector<array> Scan::vjp(
   if (reduce_type_ == Scan::Sum) {
     return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
   } else if (reduce_type_ == Scan::Prod) {
-    // TODO: Make it numerically stable when we introduce where()
-    auto prod = outputs[0];
-    auto partial_grads = multiply(prod, cotangents[0], stream());
-    auto accum_grads =
-        cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
-    return {divide(accum_grads, primals[0], stream())};
+    auto in = primals[0];
+    // Find the location of the first 0 and set it to 1:
+    // - A: Exclusive cumprod
+    // - B: Inclusive cumprod
+    // - Find the location that is 0 in A and not zero in B
+    // Compute the gradient by:
+    // - Computing the regular gradient for everything before the first zero
+    // - Setting the first zero to 1 and redoing the computation, using this
+    //   for the gradient of the first zero
+    // - Everything after the first zero has a gradient of 0
+
+    // Get the inclusive and exclusive cumulative products
+    auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());
+    auto cprod_inclusive = outputs[0];
+    if (!inclusive_) {
+      std::swap(cprod_exclusive, cprod_inclusive);
+    }
+
+    // Make the mask for the first zero
+    auto z = array(0, in.dtype());
+    auto eq_zero = equal(cprod_inclusive, z, stream());
+    auto first_zero =
+        logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());
+
+    auto to_partial_grad = [this, &cotangents](const array& arr) {
+      return cumsum(
+          multiply(arr, cotangents[0], stream()),
+          axis_,
+          !reverse_,
+          inclusive_,
+          stream());
+    };
+
+    auto cprod_with_one = cumprod(
+        where(first_zero, array(1, in.dtype()), in, stream()),
+        axis_,
+        reverse_,
+        inclusive_,
+        stream());
+    auto grad_with_one = to_partial_grad(cprod_with_one);
+    auto grad = divide(to_partial_grad(outputs[0]), in, stream());
+    return {where(
+        first_zero,
+        grad_with_one,
+        where(eq_zero, z, grad, stream()),
+        stream())};
   } else {
     // Can probably be implemented by equals and then cummax to make the mask
     throw std::runtime_error("VJP is not implemented for cumulative min/max");
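To make the control flow above easier to follow, here is a rough Python translation of the new branch, restricted to the default forward, inclusive scan (reverse_ = false, inclusive_ = true). The helper name stable_cumprod_vjp is invented for illustration and is not part of mlx:

import mlx.core as mx

def stable_cumprod_vjp(y, cotan):
    # Inclusive and exclusive cumulative products (B and A in the comments above).
    cprod_incl = mx.cumprod(y)
    cprod_excl = mx.cumprod(y, inclusive=False)

    # The first zero is where the inclusive product is zero but the exclusive one is not.
    eq_zero = cprod_incl == 0
    first_zero = mx.logical_and(eq_zero, cprod_excl != 0)

    # The transpose of a forward cumsum is a reverse cumsum, hence reverse=True.
    def to_partial_grad(arr):
        return mx.cumsum(arr * cotan, reverse=True)

    # Redo the product with the first zero replaced by 1; its partial gradient
    # is the correct gradient for that entry.
    cprod_with_one = mx.cumprod(mx.where(first_zero, mx.array(1.0), y))
    grad_with_one = to_partial_grad(cprod_with_one)

    # Regular gradient; the division gives nan at the zeros, but those entries
    # are masked out below.
    grad = to_partial_grad(cprod_incl) / y

    return mx.where(first_zero, grad_with_one,
                    mx.where(eq_zero, mx.array(0.0), grad))

y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
print(stable_cumprod_vjp(y, mx.ones_like(y)))  # [1, 38, 0, 0, 0]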
@@ -423,6 +423,79 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
         self.assertEqual(grad.item(), 1.0)
 
+    def test_cumprod_grad(self):
+        def fun(y):
+            return mx.cumprod(y).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
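For reference, the expected vectors for the first (forward, inclusive) case follow from d/dy[j] of sum_i cumprod(y)[i], which is the sum over i >= j of the products of all y[k] with k <= i and k != j. For y = [2, 0, 2, 2, 3], the entry at the zero is 2 + 2*2 + 2*2*2 + 2*2*2*3 = 38. A throwaway brute-force check of that case (illustrative only, not part of the test suite):

# Brute-force cross-check of one expected gradient vector.
def manual_cumprod_grad(y):
    n = len(y)
    grad = []
    for j in range(n):
        # d/dy[j] of sum_i cumprod(y)[i]: sum over i >= j of prod_{k <= i, k != j} y[k]
        total = 0.0
        for i in range(j, n):
            p = 1.0
            for k in range(i + 1):
                if k != j:
                    p *= y[k]
            total += p
        grad.append(total)
    return grad

print(manual_cumprod_grad([2.0, 0.0, 2.0, 2.0, 3.0]))
# [1.0, 38.0, 0.0, 0.0, 0.0] -- matches `expected` in test_cumprod_grad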