This commit is contained in:
Angelos Katharopoulos
2025-05-14 21:53:01 -07:00
parent f93cda7a1c
commit 2acf2e003e

View File

@@ -1488,16 +1488,16 @@ std::vector<array> Divide::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
array denominator_bar = conjugate(primals[1], stream());
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(
divide(cotangents[0], conjugate(primals[1], stream()), stream()));
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
} else {
vjps.push_back(negative(
divide(
multiply(
cotangents[0], conjugate(primals[0], stream()), stream()),
square(conjugate(primals[1], stream()), stream()),
square(denominator_bar, stream()),
stream()),
stream()));
}
@@ -1972,8 +1972,14 @@ std::vector<array> FFT::vjp(
array last(N - 1 + odd, indices.dtype());
array one(1 / n_elements, in.dtype());
array two(2 / n_elements, in.dtype());
array mask =
where((first < indices) & (indices < last), two, one, stream());
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
two,
one,
stream());
return {
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
} else if (real_) {
@@ -1992,8 +1998,14 @@ std::vector<array> FFT::vjp(
array last(N - 1 + odd, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask =
where((first < indices) & (indices < last), half, one, stream());
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
half,
one,
stream());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),