LogCumSumExp (#2069)
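This hunk adds the reverse-mode gradient (VJP) rule for the new logcumsumexp scan to Scan::vjp: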
@@ -3478,6 +3478,45 @@ std::vector<array> Scan::vjp(
  if (reduce_type_ == Scan::Sum) {
    return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
  } else if (reduce_type_ == Scan::LogAddExp) {
    // Ref:
    // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863

    auto x = primals[0];
    auto grad = cotangents[0];
    auto results = outputs[0];

    auto zero = zeros({1}, grad.dtype(), stream());
    auto grad_min = array(finfo(grad.dtype()).min, grad.dtype());
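    // grad_min is the most negative finite value of the dtype and stands in
    // for log(0) = -inf below: exp(grad_min) underflows to zero, so entries
    // masked out by where() drop out of the cumulative sums.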

    // Split the incoming gradient into positive and negative part
    // in order to take logs. This is required for stable results.
    auto log_abs_grad = log(abs(grad, stream()), stream());
    auto log_grad_positive =
        where(greater(grad, zero, stream()), log_abs_grad, grad_min, stream());
    auto log_grad_negative =
        where(less(grad, zero, stream()), log_abs_grad, grad_min, stream());

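    // For y = logcumsumexp(x), dL/dx_j = sum_{i >= j} grad_i * exp(x_j - y_i),
    // computed in log space as exp(x_j + logcumsumexp(log(grad) - y)_j) with
    // the scan direction flipped (!reverse_), since cotangents flow backwards
    // through the scan.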
    auto output_pos = exp(
        add(logcumsumexp(
                subtract(log_grad_positive, results, stream()),
                axis_,
                !reverse_,
                inclusive_,
                stream()),
            x,
            stream()));
    auto output_neg = exp(
        add(logcumsumexp(
                subtract(log_grad_negative, results, stream()),
                axis_,
                !reverse_,
                inclusive_,
                stream()),
            x,
            stream()));
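
    // output_pos and output_neg accumulate the contributions of the positive
    // and negative cotangent entries separately; their difference is the
    // full gradient.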
    return {subtract(output_pos, output_neg, stream())};
  } else if (reduce_type_ == Scan::Prod) {
    auto in = primals[0];
    // Find the location of the first 0 and set it to 1:
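For context, logcumsumexp computes a numerically stable log(cumsum(exp(x))) along an axis, staying in log space so large inputs do not overflow. A minimal sketch of exercising the op and the VJP above through the C++ API; it assumes the public signature mirrors the internal call in this diff (logcumsumexp(a, axis, reverse, inclusive) with a default stream), plus the usual mlx/mlx.h umbrella header and the vjp transform:

#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // exp(1000.0f) overflows float32, so a naive log(cumsum(exp(x)))
  // would print inf for the last entry; logcumsumexp does not.
  auto x = array({0.0f, 1.0f, 2.0f, 1000.0f});
  auto y = logcumsumexp(x, /* axis */ 0, /* reverse */ false, /* inclusive */ true);
  std::cout << y << std::endl; // roughly [0, 1.3133, 2.4076, 1000]

  // Pull a cotangent of ones back through the scan; this runs the
  // LogAddExp branch of Scan::vjp from the diff above.
  auto [outs, grads] = vjp(
      [](const std::vector<array>& inputs) {
        return std::vector<array>{logcumsumexp(inputs[0], 0, false, true)};
      },
      {x},
      {ones({4})});
  std::cout << grads[0] << std::endl;
  return 0;
}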