Fix some complex vjps (#2178)

This commit is contained in:
Angelos Katharopoulos
2025-05-14 23:37:12 -07:00
committed by GitHub
parent 130df35e1b
commit cf6c939e86
4 changed files with 166 additions and 42 deletions

View File

@@ -7,6 +7,13 @@ import mlx.core as mx
import mlx_tests
import numpy as np
try:
import torch
has_torch = True
except ImportError as e:
has_torch = False
class TestFFT(mlx_tests.MLXTestCase):
def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
@@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase):
x = mx.array([])
self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
@unittest.skipIf(not has_torch, "requires PyTorch")
def test_fft_grads(self):
real = [True, False]
inverse = [True, False]
axes = [
(-1,),
(-2, -1),
]
shapes = [
(4, 4),
(2, 4),
(2, 7),
(7, 7),
]
mxffts = {
(True, True): mx.fft.irfftn,
(True, False): mx.fft.rfftn,
(False, True): mx.fft.ifftn,
(False, False): mx.fft.fftn,
}
tffts = {
(True, True): torch.fft.irfftn,
(True, False): torch.fft.rfftn,
(False, True): torch.fft.ifftn,
(False, False): torch.fft.fftn,
}
for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):
def f(x):
y = mxffts[r, i](x)
return (mx.abs(y) ** 2).sum()
def g(x):
y = tffts[r, i](x)
return (torch.abs(y) ** 2).sum()
if r and not i:
x = mx.random.normal(sh)
else:
x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()
fx = f(x)
gx = g(torch.tensor(x))
self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)
dfdx = mx.grad(f)(x)
dgdx = torch.func.grad(g)(torch.tensor(x))
self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
if __name__ == "__main__":
unittest.main()