Fix some complex vjps (#2178)

This commit is contained in:
Angelos Katharopoulos 2025-05-14 23:37:12 -07:00 committed by GitHub
parent 130df35e1b
commit cf6c939e86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 166 additions and 42 deletions

View File

@ -1488,14 +1488,16 @@ std::vector<array> Divide::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
array denominator_bar = conjugate(primals[1], stream());
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(divide(cotangents[0], primals[1], stream()));
vjps.push_back(divide(cotangents[0], denominator_bar, stream()));
} else {
vjps.push_back(negative(
divide(
multiply(cotangents[0], primals[0], stream()),
square(primals[1], stream()),
multiply(
cotangents[0], conjugate(primals[0], stream()), stream()),
square(denominator_bar, stream()),
stream()),
stream()));
}
@ -1950,30 +1952,74 @@ std::vector<array> FFT::vjp(
assert(argnums.size() == 1);
auto& in = primals[0];
std::vector<int> axes(axes_.begin(), axes_.end());
// TODO: Add an option for an unnormalized or scaled FFT so that this scaling
// isn't part of the graph.
double n_elements = 1;
for (auto ax : axes) {
n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
}
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
mask_shape[axes_.back()] -= 2;
auto mask = full(mask_shape, 2.0f, stream());
auto pad_shape = out.shape();
pad_shape[axes_.back()] = 1;
auto pad = full(pad_shape, 1.0f, stream());
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
// Make a mask to account for frequencies that are used twice in the forward
// pass. Everything except the DC and Nyquist frequencies gets doubled.
int N = in.shape(axes_.back());
bool odd = cotangents[0].shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1 / n_elements, in.dtype());
array two(2 / n_elements, in.dtype());
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
two,
one,
stream());
return {
multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
} else if (real_) {
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape()[ax]);
n.push_back(in.shape(ax));
}
return {astype(
fft::fftn(cotangents[0], n, axes, stream()), in.dtype(), stream())};
// Make a mask to account for frequencies that are used twice in the forward
// pass. Everything except the DC and Nyquist frequencies gets halved.
int N = cotangents[0].shape(axes_.back());
bool odd = in.shape(axes_.back()) % 2;
Shape c(in.ndim(), 1);
c[axes_.back()] = N;
array indices = reshape(arange(N, stream()), std::move(c), stream());
array first(0, indices.dtype());
array last(N - 1 + odd, indices.dtype());
array one(1, complex64);
array half(0.5, complex64);
array mask = where(
logical_and(
greater(indices, first, stream()),
less(indices, last, stream()),
stream()),
half,
one,
stream());
return {multiply(
fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
array(n_elements, in.dtype()),
stream())};
} else if (inverse_) {
return {fft::ifftn(cotangents[0], axes, stream())};
return {multiply(
fft::fftn(cotangents[0], axes, stream()),
array(1 / n_elements, complex64),
stream())};
} else {
return {fft::fftn(cotangents[0], axes, stream())};
return {multiply(
fft::ifftn(cotangents[0], axes, stream()),
array(n_elements, complex64),
stream())};
}
}
@ -2776,7 +2822,8 @@ std::vector<array> Multiply::vjp(
const std::vector<array>&) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
vjps.push_back(multiply(
conjugate(primals[1 - arg], stream()), cotangents[0], stream()));
}
return vjps;
}

View File

@ -7,6 +7,13 @@ import mlx.core as mx
import mlx_tests
import numpy as np
# PyTorch is an optional dependency of this test module. Record whether it
# could be imported so that tests needing it can be skipped instead of failing.
try:
    import torch

    has_torch = True
except ImportError:
    has_torch = False
class TestFFT(mlx_tests.MLXTestCase):
def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase):
x = mx.array([])
self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
@unittest.skipIf(not has_torch, "requires PyTorch")
def test_fft_grads(self):
    """Compare FFT values and gradients between MLX and PyTorch.

    For every combination of real/complex transform, forward/inverse
    direction, transform axes and input shape, evaluate
    ``sum(abs(fft(x)) ** 2)`` with both libraries and check that the value
    and its gradient with respect to ``x`` agree to a relative tolerance of
    1e-4.
    """
    real = [True, False]
    inverse = [True, False]
    axes = [
        (-1,),
        (-2, -1),
    ]
    shapes = [
        (4, 4),
        (2, 4),
        (2, 7),
        (7, 7),
    ]

    # Both tables are keyed by (real, inverse).
    mxffts = {
        (True, True): mx.fft.irfftn,
        (True, False): mx.fft.rfftn,
        (False, True): mx.fft.ifftn,
        (False, False): mx.fft.fftn,
    }
    tffts = {
        (True, True): torch.fft.irfftn,
        (True, False): torch.fft.rfftn,
        (False, True): torch.fft.ifftn,
        (False, False): torch.fft.fftn,
    }

    for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):

        # Pass the axes under test to both libraries; without this the
        # ``axes`` list would have no effect on what is being checked.
        def f(x):
            y = mxffts[r, i](x, axes=ax)
            return (mx.abs(y) ** 2).sum()

        def g(x):
            y = tffts[r, i](x, dim=ax)
            return (torch.abs(y) ** 2).sum()

        if r and not i:
            # Only the forward real transform takes a real-valued input.
            x = mx.random.normal(sh)
        else:
            # Build a complex input by viewing trailing (re, im) pairs as
            # complex64 and dropping the leftover singleton axis.
            x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()

        fx = f(x)
        gx = g(torch.tensor(x))
        self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)

        dfdx = mx.grad(f)(x)
        dgdx = torch.func.grad(g)(torch.tensor(x))
        self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
unittest.main()

View File

@ -1133,26 +1133,48 @@ TEST_CASE("test complex gradients") {
}
{
auto multiply_fn =
[](const std::vector<array>& inputs) -> std::vector<array> {
return {multiply(inputs[0], inputs[1])};
};
// Compute jvp
auto x = array(complex64_t{2.0, 4.0});
auto y = array(3.0f);
auto x_tan = array(complex64_t{1.0, 2.0});
auto y_tan = array(2.0f);
auto jvp_out = jvp(multiply_fn, {x, y}, {x_tan, y_tan}).second;
CHECK_EQ(jvp_out[0].item<complex64_t>(), complex64_t{7.0, 14.0});
auto out = jvp([x](array a) { return multiply(a, x); }, y, y_tan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{4.0, 8.0});
out = jvp([y](array a) { return multiply(a, y); }, x, x_tan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{3.0, 6.0});
// Compute vjp
auto cotan = array(complex64_t{2.0, 3.0});
out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
CHECK_EQ(out.dtype(), float32);
CHECK_EQ(out.item<float>(), -8.0);
auto vjp_out = vjp(multiply_fn, {x, y}, {cotan}).second;
CHECK_EQ(vjp_out[0].dtype(), complex64);
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{6.0, 9.0});
CHECK_EQ(vjp_out[1].dtype(), float32);
CHECK_EQ(vjp_out[1].item<float>(), 16);
}
out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
{
auto divide_fn =
[](const std::vector<array>& inputs) -> std::vector<array> {
return {divide(inputs[0], inputs[1])};
};
// Compute jvp
auto x = array(complex64_t{2.0, 3.0});
auto y = array(complex64_t{1.0, 2.0});
auto x_tan = array(complex64_t{3.0, 4.0});
auto y_tan = array(complex64_t{4.0, -2.0});
auto jvp_out = jvp(divide_fn, {x, y}, {x_tan, y_tan}).second;
CHECK_EQ(
jvp_out[0].item<complex64_t>(), doctest::Approx(complex64_t{2.6, 2.8}));
// Compute vjp
auto cotan = array(complex64_t{2.0, -4.0});
auto vjp_out = vjp(divide_fn, {x, y}, {cotan}).second;
CHECK_EQ(vjp_out[0].item<complex64_t>(), complex64_t{2.0, 0.0});
CHECK_EQ(vjp_out[1].item<complex64_t>(), complex64_t{-3.2, -0.4});
}
}

View File

@ -243,7 +243,7 @@ TEST_CASE("test fft grads") {
auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64);
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
CHECK(array_equal(fft::ifft(cotangent) * 10, vjp_out).item<bool>());
auto tangent = astype(arange(10), complex64);
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
@ -252,7 +252,7 @@ TEST_CASE("test fft grads") {
// Inverse
auto ifft_fn = [](array x) { return fft::ifft(x); };
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
CHECK(array_equal(fft::fft(cotangent) * 0.1, vjp_out).item<bool>());
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
@ -261,7 +261,8 @@ TEST_CASE("test fft grads") {
auto rfft_fn = [](array x) { return fft::rfft(x); };
cotangent = astype(arange(6), complex64);
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
array mask({1.0, 0.5, 0.5, 0.5, 0.5, 1.0}, complex64);
auto expected = fft::irfft(cotangent * mask, 10, 0) * 10;
CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), float32);
@ -272,12 +273,9 @@ TEST_CASE("test fft grads") {
auto irfft_fn = [](array x) { return fft::irfft(x); };
cotangent = astype(arange(10), float32);
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
expected = fft::fft(cotangent, 10, 0);
auto o_splits = split(vjp_out, {1, 5});
auto e_splits = split(expected, {1, 5, 6});
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
mask = array({0.1, 0.2, 0.2, 0.2, 0.2, 0.1}, float32);
expected = fft::rfft(cotangent) * mask;
CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), complex64);
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;