stable cumprod grad at 0 (#1167)

This commit is contained in:
Awni Hannun
2024-05-31 12:28:42 -07:00
committed by GitHub
parent 76b6cece46
commit fd1c08137b
2 changed files with 119 additions and 6 deletions

View File

@@ -2748,12 +2748,52 @@ std::vector<array> Scan::vjp(
if (reduce_type_ == Scan::Sum) {
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::Prod) {
// TODO: Make it numerically stable when we introduce where()
auto prod = outputs[0];
auto partial_grads = multiply(prod, cotangents[0], stream());
auto accum_grads =
cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
return {divide(accum_grads, primals[0], stream())};
auto in = primals[0];
// Find the location of the first 0 and set it to 1:
// - A: Exclusive cumprod
// - B: Inclusive cumprod
// - Find the location that is 0 in A and not zero B
// Compute the gradient by:
// - Compute the regular gradient for everything before the first zero
// - Set the first zero to 1 and redo the computation, use this for the
// gradient of the first zero
// - Everything after the first zero has a gradient of 0
// Get inclusive and exclusive cum prods
auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());
auto cprod_inclusive = outputs[0];
if (!inclusive_) {
std::swap(cprod_exclusive, cprod_inclusive);
}
// Make the mask for the first zero
auto z = array(0, in.dtype());
auto eq_zero = equal(cprod_inclusive, z, stream());
auto first_zero =
logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());
auto to_partial_grad = [this, &cotangents](const array& arr) {
return cumsum(
multiply(arr, cotangents[0], stream()),
axis_,
!reverse_,
inclusive_,
stream());
};
auto cprod_with_one = cumprod(
where(first_zero, array(1, in.dtype()), in, stream()),
axis_,
reverse_,
inclusive_,
stream());
auto grad_with_one = to_partial_grad(cprod_with_one);
auto grad = divide(to_partial_grad(outputs[0]), in, stream());
return {where(
first_zero,
grad_with_one,
where(eq_zero, z, grad, stream()),
stream())};
} else {
// Can probably be implemented by equals and then cummax to make the mask
throw std::runtime_error("VJP is not implemented for cumulative min/max");