Fix some complex vjps (#2178)

This commit is contained in:
Angelos Katharopoulos
2025-05-14 23:37:12 -07:00
committed by GitHub
parent 130df35e1b
commit cf6c939e86
4 changed files with 166 additions and 42 deletions

View File

@@ -1488,14 +1488,16 @@ std::vector<array> Divide::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
array denominator_bar = conjugate(primals[1], stream());
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(divide(cotangents[0], primals[1], stream()));
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
} else {
vjps.push_back(negative(
divide(
multiply(cotangents[0], primals[0], stream()),
square(primals[1], stream()),
multiply(
cotangents[0], conjugate(primals[0], stream()), stream()),
square(denominator_bar, stream()),
stream()),
stream()));
}
@@ -1950,30 +1952,74 @@ std::vector<array> FFT::vjp(
assert(argnums.size() == 1);
auto& in = primals[0];
std::vector<int> axes(axes_.begin(), axes_.end());
// TODO: Add an option for an unnormalized or scaled fft so that this
// scaling isn't part of the graph.
double n_elements = 1;
for (auto ax : axes) {
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
}
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
mask_shape[axes_.back()] -= 2;
auto mask = full(mask_shape, 2.0f, stream());
auto pad_shape = out.shape();
pad_shape[axes_.back()] = 1;
auto pad = full(pad_shape, 1.0f, stream());
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
// Make a mask to account for the double use in the forward pass.
// Every bin except DC and, for even-length real signals, Nyquist gets doubled.
int N = in.shape(axes_.back());
bool odd = cotangents[0].shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1 / n_elements, in.dtype());
array two(2 / n_elements, in.dtype());
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
two,
one,
stream());
return {
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
} else if (real_) {
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape()[ax]);
n.push_back(in.shape(ax));
}
return {astype(
fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())};
// Make a mask to account for the double use in the forward pass.
// Every bin except DC and, for even-length real signals, Nyquist gets halved.
int N = cotangents[0].shape(axes_.back());
bool odd = in.shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
half,
one,
stream());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
} else if (inverse_) {
return {fft::ifftn(cotangents[0], axes, stream())};
return {multiply(
fft::fftn(cotangents[0], axes, stream()),
array(1 / n_elements, complex64),
stream())};
} else {
return {fft::fftn(cotangents[0], axes, stream())};
return {multiply(
fft::ifftn(cotangents[0], axes, stream()),
array(n_elements, complex64),
stream())};
}
}
@@ -2776,7 +2822,8 @@ std::vector<array> Multiply::vjp(
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
vjps.push_back(multiply(
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
}
return vjps;
}