mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix the last fft vjp
This commit is contained in:
parent
043496bc9b
commit
e1c65e1381
@ -1957,47 +1957,46 @@ std::vector<array> FFT::vjp(
|
|||||||
// isn't part of the graph.
|
// isn't part of the graph.
|
||||||
double n_elements = 1;
|
double n_elements = 1;
|
||||||
for (auto ax : axes) {
|
for (auto ax : axes) {
|
||||||
n_elements *= primals[0].shape(ax);
|
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (real_ && inverse_) {
|
if (real_) {
|
||||||
auto out = fft::fftn(cotangents[0], axes, stream());
|
|
||||||
auto start = Shape(out.ndim(), 0);
|
|
||||||
auto stop = in.shape();
|
|
||||||
out = slice(out, start, stop, stream());
|
|
||||||
auto mask_shape = out.shape();
|
|
||||||
mask_shape[axes_.back()] -= 2;
|
|
||||||
auto mask = full(mask_shape, 2.0f, stream());
|
|
||||||
auto pad_shape = out.shape();
|
|
||||||
pad_shape[axes_.back()] = 1;
|
|
||||||
auto pad = full(pad_shape, 1.0f, stream());
|
|
||||||
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
|
|
||||||
return {multiply(mask, out, stream())};
|
|
||||||
} else if (real_) {
|
|
||||||
Shape n;
|
|
||||||
for (auto ax : axes_) {
|
|
||||||
n.push_back(in.shape(ax));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make a mask to account for the double use in the forward pass.
|
// Make a mask to account for the double use in the forward pass.
|
||||||
// Everything except the DC and nyquist frequencies gets halved.
|
// Everything except the DC and nyquist frequencies gets halved or doubled.
|
||||||
|
int N =
|
||||||
|
inverse_ ? in.shape(axes_.back()) : cotangents[0].shape(axes_.back());
|
||||||
Shape c(in.ndim(), 1);
|
Shape c(in.ndim(), 1);
|
||||||
c[axes_.back()] = cotangents[0].shape(axes_.back());
|
c[axes_.back()] = N;
|
||||||
array indices = reshape(
|
array indices = reshape(arange(N, stream()), std::move(c), stream());
|
||||||
arange(cotangents[0].shape(axes_.back()), stream()),
|
|
||||||
std::move(c),
|
|
||||||
stream());
|
|
||||||
array first(0, indices.dtype());
|
array first(0, indices.dtype());
|
||||||
array last(cotangents[0].shape(axes_.back()) - 1, indices.dtype());
|
array last(N - 1, indices.dtype());
|
||||||
array one(1, complex64);
|
|
||||||
array half(0.5, complex64);
|
|
||||||
array mask =
|
|
||||||
where((first < indices) & (indices < last), half, one, stream());
|
|
||||||
|
|
||||||
return {multiply(
|
if (inverse_) {
|
||||||
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
auto starts = Shape(in.ndim(), 0);
|
||||||
array(n_elements, in.dtype()),
|
auto stops = in.shape();
|
||||||
stream())};
|
|
||||||
|
array one(1 / n_elements, in.dtype());
|
||||||
|
array two(2 / n_elements, in.dtype());
|
||||||
|
array mask =
|
||||||
|
where((first < indices) & (indices < last), two, one, stream());
|
||||||
|
|
||||||
|
return {
|
||||||
|
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
|
||||||
|
} else {
|
||||||
|
Shape n;
|
||||||
|
for (auto ax : axes_) {
|
||||||
|
n.push_back(in.shape(ax));
|
||||||
|
}
|
||||||
|
array one(1, complex64);
|
||||||
|
array half(0.5, complex64);
|
||||||
|
array mask =
|
||||||
|
where((first < indices) & (indices < last), half, one, stream());
|
||||||
|
return {multiply(
|
||||||
|
fft::irfftn(
|
||||||
|
multiply(cotangents[0], mask, stream()), n, axes, stream()),
|
||||||
|
array(n_elements, in.dtype()),
|
||||||
|
stream())};
|
||||||
|
}
|
||||||
} else if (inverse_) {
|
} else if (inverse_) {
|
||||||
return {multiply(
|
return {multiply(
|
||||||
fft::fftn(cotangents[0], axes, stream()),
|
fft::fftn(cotangents[0], axes, stream()),
|
||||||
|
Loading…
Reference in New Issue
Block a user