diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index dbfb5d382..6a006d99f 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1957,47 +1957,46 @@ std::vector FFT::vjp( // isn't part of the graph. double n_elements = 1; for (auto ax : axes) { - n_elements *= primals[0].shape(ax); + n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax); } - if (real_ && inverse_) { - auto out = fft::fftn(cotangents[0], axes, stream()); - auto start = Shape(out.ndim(), 0); - auto stop = in.shape(); - out = slice(out, start, stop, stream()); - auto mask_shape = out.shape(); - mask_shape[axes_.back()] -= 2; - auto mask = full(mask_shape, 2.0f, stream()); - auto pad_shape = out.shape(); - pad_shape[axes_.back()] = 1; - auto pad = full(pad_shape, 1.0f, stream()); - mask = concatenate({pad, mask, pad}, axes_.back(), stream()); - return {multiply(mask, out, stream())}; - } else if (real_) { - Shape n; - for (auto ax : axes_) { - n.push_back(in.shape(ax)); - } - + if (real_) { // Make a mask to account for the double use in the forward pass. - // Everything except the DC and nyquist frequencies gets halved. + // Everything except the DC and nyquist frequencies gets halved or doubled. + int N = + inverse_ ? in.shape(axes_.back()) : cotangents[0].shape(axes_.back()); Shape c(in.ndim(), 1); - c[axes_.back()] = cotangents[0].shape(axes_.back()); - array indices = reshape( - arange(cotangents[0].shape(axes_.back()), stream()), - std::move(c), - stream()); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); array first(0, indices.dtype()); - array last(cotangents[0].shape(axes_.back()) - 1, indices.dtype()); - array one(1, complex64); - array half(0.5, complex64); - array mask = - where((first < indices) & (indices < last), half, one, stream()); + array last(N - 1, indices.dtype()); - return {multiply( - fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), - array(n_elements, in.dtype()), - stream())}; + if (inverse_) { + auto starts = Shape(in.ndim(), 0); + auto stops = in.shape(); + + array one(1 / n_elements, in.dtype()); + array two(2 / n_elements, in.dtype()); + array mask = + where((first < indices) & (indices < last), two, one, stream()); + + return { + multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; + } else { + Shape n; + for (auto ax : axes_) { + n.push_back(in.shape(ax)); + } + array one(1, complex64); + array half(0.5, complex64); + array mask = + where((first < indices) & (indices < last), half, one, stream()); + return {multiply( + fft::irfftn( + multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; + } } else if (inverse_) { return {multiply( fft::fftn(cotangents[0], axes, stream()),