From cf6c939e868f6db3421396fda3fde31708e6f1eb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 14 May 2025 23:37:12 -0700 Subject: [PATCH] Fix some complex vjps (#2178) --- mlx/primitives.cpp | 89 ++++++++++++++++++++++++++++++---------- python/tests/test_fft.py | 57 +++++++++++++++++++++++++ tests/autograd_tests.cpp | 46 +++++++++++++++------ tests/fft_tests.cpp | 16 ++++---- 4 files changed, 166 insertions(+), 42 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 03ca06bdd..e1924e66c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1488,14 +1488,16 @@ std::vector Divide::vjp( const std::vector& argnums, const std::vector&) { std::vector vjps; + array denominator_bar = conjugate(primals[1], stream()); for (auto arg : argnums) { if (arg == 0) { - vjps.push_back(divide(cotangents[0], primals[1], stream())); + vjps.push_back(divide(cotangents[0], denominator_bar, stream())); } else { vjps.push_back(negative( divide( - multiply(cotangents[0], primals[0], stream()), - square(primals[1], stream()), + multiply( + cotangents[0], conjugate(primals[0], stream()), stream()), + square(denominator_bar, stream()), stream()), stream())); } @@ -1950,30 +1952,74 @@ std::vector FFT::vjp( assert(argnums.size() == 1); auto& in = primals[0]; std::vector axes(axes_.begin(), axes_.end()); + + // TODO: Add it as an option to do an unnormalized or scaled fft so that this + // isn't part of the graph. + double n_elements = 1; + for (auto ax : axes) { + n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax); + } + if (real_ && inverse_) { - auto out = fft::fftn(cotangents[0], axes, stream()); - auto start = Shape(out.ndim(), 0); - auto stop = in.shape(); - out = slice(out, start, stop, stream()); - auto mask_shape = out.shape(); - mask_shape[axes_.back()] -= 2; - auto mask = full(mask_shape, 2.0f, stream()); - auto pad_shape = out.shape(); - pad_shape[axes_.back()] = 1; - auto pad = full(pad_shape, 1.0f, stream()); - mask = concatenate({pad, mask, pad}, axes_.back(), stream()); - return {multiply(mask, out, stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets doubled. + int N = in.shape(axes_.back()); + bool odd = cotangents[0].shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1 / n_elements, in.dtype()); + array two(2 / n_elements, in.dtype()); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + two, + one, + stream()); + return { + multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())}; } else if (real_) { Shape n; for (auto ax : axes_) { - n.push_back(in.shape()[ax]); + n.push_back(in.shape(ax)); } - return {astype( - fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())}; + // Make a mask to account for the double use in the forward pass. + // Everything except the DC and nyquist frequencies gets halved. + int N = cotangents[0].shape(axes_.back()); + bool odd = in.shape(axes_.back()) % 2; + Shape c(in.ndim(), 1); + c[axes_.back()] = N; + array indices = reshape(arange(N, stream()), std::move(c), stream()); + array first(0, indices.dtype()); + array last(N - 1 + odd, indices.dtype()); + array one(1, complex64); + array half(0.5, complex64); + array mask = where( + logical_and( + greater(indices, first, stream()), + less(indices, last, stream()), + stream()), + half, + one, + stream()); + return {multiply( + fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()), + array(n_elements, in.dtype()), + stream())}; } else if (inverse_) { - return {fft::ifftn(cotangents[0], axes, stream())}; + return {multiply( + fft::fftn(cotangents[0], axes, stream()), + array(1 / n_elements, complex64), + stream())}; } else { - return {fft::fftn(cotangents[0], axes, stream())}; + return {multiply( + fft::ifftn(cotangents[0], axes, stream()), + array(n_elements, complex64), + stream())}; } } @@ -2776,7 +2822,8 @@ std::vector Multiply::vjp( const std::vector&) { std::vector vjps; for (auto arg : argnums) { - vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream())); + vjps.push_back(multiply( + conjugate(primals[1 - arg], stream()), cotangents[0], stream())); } return vjps; } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index f644944c7..df9d25edc 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -7,6 +7,13 @@ import mlx.core as mx import mlx_tests import numpy as np +try: + import torch + + has_torch = True +except ImportError as e: + has_torch = False + class TestFFT(mlx_tests.MLXTestCase): def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs): @@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase): x = mx.array([]) self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + @unittest.skipIf(not has_torch, "requires PyTorch") + def test_fft_grads(self): + real = [True, False] + inverse = [True, False] + axes = [ + (-1,), + (-2, -1), + ] + shapes = [ + (4, 4), + (2, 4), + (2, 7), + (7, 7), + ] + + mxffts = { + (True, True): mx.fft.irfftn, + (True, False): mx.fft.rfftn, + (False, True): mx.fft.ifftn, + (False, False): mx.fft.fftn, + } + tffts = { + (True, True): torch.fft.irfftn, + (True, False): torch.fft.rfftn, + (False, True): torch.fft.ifftn, + (False, False): torch.fft.fftn, + } + + for r, i, ax, sh in itertools.product(real, inverse, axes, shapes): + + def f(x): + y = mxffts[r, i](x) + return (mx.abs(y) ** 2).sum() + + def g(x): + y = tffts[r, i](x) + return (torch.abs(y) ** 2).sum() + + if r and not i: + x = mx.random.normal(sh) + else: + x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze() + fx = f(x) + gx = g(torch.tensor(x)) + self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4) + + dfdx = mx.grad(f)(x) + dgdx = torch.func.grad(g)(torch.tensor(x)) + self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index c992c3c6d..5b3454bfc 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") { } { + auto multiply_fn = + [](const std::vector& inputs) -> std::vector { + return {multiply(inputs[0], inputs[1])}; + }; + // Compute jvp auto x = array(complex64_t{2.0, 4.0}); auto y = array(3.0f); - auto x_tan = array(complex64_t{1.0, 2.0}); auto y_tan = array(2.0f); + auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ(jvp_out[0].item(), complex64_t{7.0, 14.0}); - auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second; - CHECK_EQ(out.item(), complex64_t{4.0, 8.0}); - - out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second; - CHECK_EQ(out.item(), complex64_t{3.0, 6.0}); - + // Compute vjp auto cotan = array(complex64_t{2.0, 3.0}); - out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second; - CHECK_EQ(out.dtype(), float32); - CHECK_EQ(out.item(), -8.0); + auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].dtype(), complex64); + CHECK_EQ(vjp_out[0].item(), complex64_t{6.0, 9.0}); + CHECK_EQ(vjp_out[1].dtype(), float32); + CHECK_EQ(vjp_out[1].item(), 16); + } - out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second; - CHECK_EQ(out.item(), complex64_t{6.0, 9.0}); + { + auto divide_fn = + [](const std::vector& inputs) -> std::vector { + return {divide(inputs[0], inputs[1])}; + }; + + // Compute jvp + auto x = array(complex64_t{2.0, 3.0}); + auto y = array(complex64_t{1.0, 2.0}); + auto x_tan = array(complex64_t{3.0, 4.0}); + auto y_tan = array(complex64_t{4.0, -2.0}); + auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second; + CHECK_EQ( + jvp_out[0].item(), doctest::Approx(complex64_t{2.6, 2.8})); + + // Compute vjp + auto cotan = array(complex64_t{2.0, -4.0}); + auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second; + CHECK_EQ(vjp_out[0].item(), complex64_t{2.0, 0.0}); + CHECK_EQ(vjp_out[1].item(), complex64_t{-3.2, -0.4}); } } diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 0db3999c8..b9e2d1bcc 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -243,7 +243,7 @@ TEST_CASE("test fft grads") { auto fft_fn = [](array x) { return fft::fft(x); }; auto cotangent = astype(arange(10), complex64); auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::fft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item()); auto tangent = astype(arange(10), complex64); auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second; @@ -252,7 +252,7 @@ TEST_CASE("test fft grads") { // Inverse auto ifft_fn = [](array x) { return fft::ifft(x); }; vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second; - CHECK(array_equal(fft::ifft(cotangent), vjp_out).item()); + CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item()); jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second; CHECK(array_equal(fft::ifft(tangent), jvp_out).item()); @@ -261,7 +261,8 @@ TEST_CASE("test fft grads") { auto rfft_fn = [](array x) { return fft::rfft(x); }; cotangent = astype(arange(6), complex64); vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second; - auto expected = astype(fft::fft(cotangent, 10, 0), float32); + array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64); + auto expected = fft::irfft(cotangent * mask, 10, 0) * 10; CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), float32); @@ -272,12 +273,9 @@ TEST_CASE("test fft grads") { auto irfft_fn = [](array x) { return fft::irfft(x); }; cotangent = astype(arange(10), float32); vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second; - expected = fft::fft(cotangent, 10, 0); - auto o_splits = split(vjp_out, {1, 5}); - auto e_splits = split(expected, {1, 5, 6}); - CHECK_EQ(e_splits[0].item(), o_splits[0].item()); - CHECK(array_equal(2 * e_splits[1], o_splits[1]).item()); - CHECK_EQ(e_splits[2].item(), o_splits[2].item()); + mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32); + expected = fft::rfft(cotangent) * mask; + CHECK(array_equal(expected, vjp_out).item()); tangent = astype(arange(10), complex64); jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;