Three ffts done one to go

This commit is contained in:
Angelos Katharopoulos 2025-05-12 16:19:12 -07:00
parent 23417cee8e
commit 043496bc9b

View File

@ -1952,6 +1952,14 @@ std::vector<array> FFT::vjp(
assert(argnums.size() == 1); assert(argnums.size() == 1);
auto& in = primals[0]; auto& in = primals[0];
std::vector<int> axes(axes_.begin(), axes_.end()); std::vector<int> axes(axes_.begin(), axes_.end());
// TODO: Add it as an option to do an unnormalized or scaled fft so that this
// isn't part of the graph.
double n_elements = 1;
for (auto ax : axes) {
n_elements *= primals[0].shape(ax);
}
if (real_ && inverse_) { if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream()); auto out = fft::fftn(cotangents[0], axes, stream());
auto start = Shape(out.ndim(), 0); auto start = Shape(out.ndim(), 0);
@ -1968,14 +1976,38 @@ std::vector<array> FFT::vjp(
} else if (real_) { } else if (real_) {
Shape n; Shape n;
for (auto ax : axes_) { for (auto ax : axes_) {
n.push_back(in.shape()[ax]); n.push_back(in.shape(ax));
} }
return {astype(
fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; // Make a mask to account for the double use in the forward pass.
// Everything except the DC and nyquist frequencies gets halved.
Shape c(in.ndim(), 1);
c[axes_.back()] = cotangents[0].shape(axes_.back());
array indices = reshape(
arange(cotangents[0].shape(axes_.back()), stream()),
std::move(c),
stream());
array first(0, indices.dtype());
array last(cotangents[0].shape(axes_.back()) - 1, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask =
where((first < indices) & (indices < last), half, one, stream());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
} else if (inverse_) { } else if (inverse_) {
return {fft::ifftn(cotangents[0], axes, stream())}; return {multiply(
fft::fftn(cotangents[0], axes, stream()),
array(1 / n_elements, complex64),
stream())};
} else { } else {
return {fft::fftn(cotangents[0], axes, stream())}; return {multiply(
fft::ifftn(cotangents[0], axes, stream()),
array(n_elements, complex64),
stream())};
} }
} }