diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 906767c30..dbfb5d382 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1952,6 +1952,14 @@ std::vector FFT::vjp( assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); + + // TODO: Add it as an option to do an unnormalized or scaled fft so that this + // isn't part of the graph. + double n_elements = 1; + for (auto ax : axes) { + n_elements *= primals[0].shape(ax); + } + if (real_ && inverse_) { auto out = fft::fftn(cotangents[0], axes, stream()); auto start = Shape(out.ndim(), 0); @@ -1968,14 +1976,38 @@ std::vector FFT::vjp( } else if (real_) { Shape n; for (auto ax : axes_) { - n.push_back(in.shape()[ax]); + n.push_back(in.shape(ax)); } - return {astype( - fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; + + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets halved. + Shape c(in.ndim(), 1); + c[axes_.back()] = cotangents[0].shape(axes_.back()); + array indices = reshape( + arange(cotangents[0].shape(axes_.back()), stream()), + std::move(c), + stream()); + array first(0, indices.dtype()); + array last(cotangents[0].shape(axes_.back()) - 1, indices.dtype()); + array one(1, complex64); + array half(0.5, complex64); + array mask = + where((first < indices) & (indices < last), half, one, stream()); + + return {multiply( + fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; } else if (inverse_) { - return {fft::ifftn(cotangents[0], axes, stream())}; + return {multiply( + fft::fftn(cotangents[0], axes, stream()), + array(1 / n_elements, complex64), + stream())}; } else { - return {fft::fftn(cotangents[0], axes, stream())}; + return {multiply( + fft::ifftn(cotangents[0], axes, stream()), + array(n_elements, complex64), + stream())}; } }