From 194f1adbd878d3d45e25bb6add5c595cbfa2455a Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 12 May 2025 22:27:12 -0700
Subject: [PATCH] Fix rfft odd grad and add tests

---
 mlx/primitives.cpp       | 63 ++++++++++++++++++++--------------------
 python/tests/test_fft.py | 57 ++++++++++++++++++++++++++++++++++++
 tests/autograd_tests.cpp |  2 +-
 3 files changed, 90 insertions(+), 32 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 6a006d99f..c36a61686 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1960,43 +1960,44 @@ std::vector<array> FFT::vjp(
     n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
   }
 
-  if (real_) {
+  if (real_ && inverse_) {
     // Make a mask to account for the double use in the forward pass.
-    // Everything except the DC and Nyquist frequencies gets halved or doubled.
-    int N =
-        inverse_ ? in.shape(axes_.back()) : cotangents[0].shape(axes_.back());
+    // Everything except the DC and Nyquist frequencies gets doubled.
+    int N = in.shape(axes_.back());
+    bool odd = cotangents[0].shape(axes_.back()) % 2;
     Shape c(in.ndim(), 1);
     c[axes_.back()] = N;
     array indices = reshape(arange(N, stream()), std::move(c), stream());
     array first(0, indices.dtype());
-    array last(N - 1, indices.dtype());
-
-    if (inverse_) {
-      auto starts = Shape(in.ndim(), 0);
-      auto stops = in.shape();
-
-      array one(1 / n_elements, in.dtype());
-      array two(2 / n_elements, in.dtype());
-      array mask =
-          where((first < indices) & (indices < last), two, one, stream());
-
-      return {
-          multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
-    } else {
-      Shape n;
-      for (auto ax : axes_) {
-        n.push_back(in.shape(ax));
-      }
-      array one(1, complex64);
-      array half(0.5, complex64);
-      array mask =
-          where((first < indices) & (indices < last), half, one, stream());
-      return {multiply(
-          fft::irfftn(
-              multiply(cotangents[0], mask, stream()), n, axes, stream()),
-          array(n_elements, in.dtype()),
-          stream())};
+    array last(N - 1 + odd, indices.dtype());
+    array one(1 / n_elements, in.dtype());
+    array two(2 / n_elements, in.dtype());
+    array mask =
+        where((first < indices) & (indices < last), two, one, stream());
+    return {
+        multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
+  } else if (real_) {
+    Shape n;
+    for (auto ax : axes_) {
+      n.push_back(in.shape(ax));
     }
+    // Make a mask to account for the double use in the forward pass.
+    // Everything except the DC and Nyquist frequencies gets halved.
+    int N = cotangents[0].shape(axes_.back());
+    bool odd = in.shape(axes_.back()) % 2;
+    Shape c(in.ndim(), 1);
+    c[axes_.back()] = N;
+    array indices = reshape(arange(N, stream()), std::move(c), stream());
+    array first(0, indices.dtype());
+    array last(N - 1 + odd, indices.dtype());
+    array one(1, complex64);
+    array half(0.5, complex64);
+    array mask =
+        where((first < indices) & (indices < last), half, one, stream());
+    return {multiply(
+        fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
+        array(n_elements, in.dtype()),
+        stream())};
   } else if (inverse_) {
     return {multiply(
         fft::fftn(cotangents[0], axes, stream()),
diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py
index f644944c7..df9d25edc 100644
--- a/python/tests/test_fft.py
+++ b/python/tests/test_fft.py
@@ -7,6 +7,13 @@
 import mlx.core as mx
 import mlx_tests
 import numpy as np
 
+try:
+    import torch
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 class TestFFT(mlx_tests.MLXTestCase):
     def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
@@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase):
         x = mx.array([])
         self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
 
+    @unittest.skipIf(not has_torch, "requires PyTorch")
+    def test_fft_grads(self):
+        real = [True, False]
+        inverse = [True, False]
+        axes = [
+            (-1,),
+            (-2, -1),
+        ]
+        shapes = [
+            (4, 4),
+            (2, 4),
+            (2, 7),
+            (7, 7),
+        ]
+
+        mxffts = {
+            (True, True): mx.fft.irfftn,
+            (True, False): mx.fft.rfftn,
+            (False, True): mx.fft.ifftn,
+            (False, False): mx.fft.fftn,
+        }
+        tffts = {
+            (True, True): torch.fft.irfftn,
+            (True, False): torch.fft.rfftn,
+            (False, True): torch.fft.ifftn,
+            (False, False): torch.fft.fftn,
+        }
+
+        for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):
+
+            def f(x):
+                y = mxffts[r, i](x, axes=ax)
+                return (mx.abs(y) ** 2).sum()
+
+            def g(x):
+                y = tffts[r, i](x, dim=ax)
+                return (torch.abs(y) ** 2).sum()
+
+            if r and not i:
+                x = mx.random.normal(sh)
+            else:
+                x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()
+            fx = f(x)
+            gx = g(torch.tensor(x))
+            self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)
+
+            dfdx = mx.grad(f)(x)
+            dgdx = torch.func.grad(g)(torch.tensor(x))
+            self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index c992c3c6d..8b7126714 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -1149,7 +1149,7 @@ TEST_CASE("test complex gradients") {
   auto cotan = array(complex64_t{2.0, 3.0});
   out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
   CHECK_EQ(out.dtype(), float32);
-  CHECK_EQ(out.item<float>(), -8.0);
+  CHECK_EQ(out.item<float>(), 16.0);
 
   out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
   CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
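
Note on the fix: for a real input of odd length n, rfft produces n/2 + 1 bins
and none of them is a Nyquist bin, so every bin except DC stands for a
conjugate pair in the full spectrum and must be doubled (or halved) in the
gradient; the `+ odd` term extends the masked region by one bin in exactly
that case, while for even n it vanishes and the old behavior is kept. The
following minimal numpy sketch (not MLX code; grad_f and grad_fd are
illustrative helper names, not part of any API) checks the masked-irfft
gradient of f(x) = sum(|rfft(x)|^2) against finite differences for both
parities:

    import numpy as np

    def f(x):
        # Forward pass mirroring the new test: sum of squared rfft magnitudes.
        return np.sum(np.abs(np.fft.rfft(x)) ** 2)

    def grad_f(x):
        # Gradient via the masked-irfft trick from the patch (illustrative
        # helper, not MLX API). The cotangent of sum |X|^2 is 2 * X; interior
        # bins are halved because each stands for two conjugate frequencies,
        # while DC (and Nyquist, which only exists for even n) appears once.
        n = x.shape[-1]
        X = np.fft.rfft(x)
        mask = np.full(X.shape[-1], 0.5)
        mask[0] = 1.0
        if n % 2 == 0:
            mask[-1] = 1.0  # even length: the last bin is the Nyquist bin
        return n * np.fft.irfft(mask * 2 * X, n)

    def grad_fd(x, eps=1e-6):
        # Central finite differences as a reference.
        g = np.zeros_like(x)
        for j in range(x.size):
            e = np.zeros_like(x)
            e[j] = eps
            g[j] = (f(x + e) - f(x - e)) / (2 * eps)
        return g

    rng = np.random.default_rng(0)
    for n in (6, 7):  # even (Nyquist bin exists) and odd (it does not)
        x = rng.normal(size=n)
        assert np.allclose(grad_f(x), grad_fd(x), atol=1e-4)

Without the odd-length correction, grad_f would pin the last bin's mask to 1
for n = 7 as well, which is exactly the mismatch the (2, 7) and (7, 7) shapes
in the new test exercise.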