LogCumSumExp (#2069)

Author: Yury Popov
Date: 2025-04-13 11:27:29 +03:00 (committed by GitHub)
Parent: 7275ac7523
Commit: e9e268336b
15 changed files with 209 additions and 3 deletions

@@ -3478,6 +3478,45 @@ std::vector<array> Scan::vjp(
  if (reduce_type_ == Scan::Sum) {
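    // The adjoint of a cumulative sum is a cumulative sum of the cotangent
    // in the opposite direction, with the same inclusivity.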
    return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
  } else if (reduce_type_ == Scan::LogAddExp) {
    // Ref:
    // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
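    // For an inclusive scan y_j = log(sum_{i <= j} exp(x_i)), the Jacobian is
    // dy_j/dx_i = exp(x_i - y_j) for i <= j, so the VJP is
    //   vjp_i = sum_{j >= i} g_j * exp(x_i - y_j)
    //         = exp(x_i + logcumsumexp(log(g) - y, reversed)_i),
    // which is evaluated entirely in log-space below for stability.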
    auto x = primals[0];
    auto grad = cotangents[0];
    auto results = outputs[0];
    auto zero = zeros({1}, grad.dtype(), stream());
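    // finfo(dtype).min is a finite stand-in for log(0) = -inf, so masked
    // entries vanish once exponentiated.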
    auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
    // Split the incoming gradient into positive and negative parts
    // in order to take logs. This is required for stable results.
    auto log_abs_grad = log(abs(grad, stream()), stream());
    auto log_grad_positive =
        where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
    auto log_grad_negative =
        where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());
    auto output_pos = exp(
        add(logcumsumexp(
                subtract(log_grad_positive, results, stream()),
                axis_,
                !reverse_,
                inclusive_,
                stream()),
            x,
            stream()));
    auto output_neg = exp(
        add(logcumsumexp(
                subtract(log_grad_negative, results, stream()),
                axis_,
                !reverse_,
                inclusive_,
                stream()),
            x,
            stream()));
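    // The two accumulations carry the sign information of the original
    // cotangent; their difference recovers the full VJP.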
    return {subtract(output_pos, output_neg, stream())};
  } else if (reduce_type_ == Scan::Prod) {
    auto in = primals[0];
    // Find the location of the first 0 and set it to 1:
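
As a sanity check on the new VJP, the following minimal sketch (not part of the commit) compares it against the direct, non-log-space formula. It assumes MLX's single-array vjp overload from mlx/transforms.h and the logcumsumexp signature introduced by this change; the input values are arbitrary:

#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // A cotangent with mixed signs exercises both log-space paths above.
  array x = array({0.5f, -1.0f, 2.0f, 0.0f});
  array g = array({1.0f, -2.0f, 0.5f, -0.25f});

  auto fn = [](const array& a) { return logcumsumexp(a, 0); };
  auto [out, grad_x] = vjp(fn, x, g);

  // Direct (numerically less stable) reference:
  //   vjp_i = exp(x_i) * rev_cumsum(g_j * exp(-y_j))_i
  array y = logcumsumexp(x, 0);
  array ref = multiply(
      exp(x), cumsum(multiply(g, exp(negative(y))), 0, /* reverse */ true));

  std::cout << grad_x << std::endl << ref << std::endl;  // should agree closely
  return 0;
}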