Fix the last fft vjp

This commit is contained in:
Angelos Katharopoulos 2025-05-12 17:13:12 -07:00
parent 043496bc9b
commit e1c65e1381

View File

@ -1957,47 +1957,46 @@ std::vector<array> FFT::vjp(
// isn't part of the graph.
double n_elements = 1;
for (auto ax : axes) {
n_elements *= primals[0].shape(ax);
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
}
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
mask_shape[axes_.back()] -= 2;
auto mask = full(mask_shape, 2.0f, stream());
auto pad_shape = out.shape();
pad_shape[axes_.back()] = 1;
auto pad = full(pad_shape, 1.0f, stream());
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
} else if (real_) {
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape(ax));
}
if (real_) {
// Make a mask to account for the double use in the forward pass.
// Everything except the DC and nyquist frequencies gets halved.
// Everything except the DC and nyquist frequencies gets halved or doubled.
int N =
inverse_ ? in.shape(axes_.back()) : cotangents[0].shape(axes_.back());
Shape c(in.ndim(), 1);
c[axes_.back()] = cotangents[0].shape(axes_.back());
array indices = reshape(
arange(cotangents[0].shape(axes_.back()), stream()),
std::move(c),
stream());
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(cotangents[0].shape(axes_.back()) - 1, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask =
where((first < indices) & (indices < last), half, one, stream());
array last(N - 1, indices.dtype());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
if (inverse_) {
auto starts = Shape(in.ndim(), 0);
auto stops = in.shape();
array one(1 / n_elements, in.dtype());
array two(2 / n_elements, in.dtype());
array mask =
where((first < indices) & (indices < last), two, one, stream());
return {
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
} else {
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape(ax));
}
array one(1, complex64);
array half(0.5, complex64);
array mask =
where((first < indices) & (indices < last), half, one, stream());
return {multiply(
fft::irfftn(
multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
}
} else if (inverse_) {
return {multiply(
fft::fftn(cotangents[0], axes, stream()),