Mirror of https://github.com/ml-explore/mlx.git
Fix rfft odd grad and add tests
This commit is contained in:
parent e1c65e1381
commit 194f1adbd8
@@ -1960,43 +1960,44 @@ std::vector<array> FFT::vjp(
     n_elements *= inverse_ ? cotangents[0].shape(ax) : primals[0].shape(ax);
   }
 
-  if (real_) {
+  if (real_ && inverse_) {
     // Make a mask to account for the double use in the forward pass.
-    // Everything except the DC and nyquist frequencies gets halved or doubled.
-    int N =
-        inverse_ ? in.shape(axes_.back()) : cotangents[0].shape(axes_.back());
+    // Everything except the DC and nyquist frequencies gets doubled.
+    int N = in.shape(axes_.back());
+    bool odd = cotangents[0].shape(axes_.back()) % 2;
     Shape c(in.ndim(), 1);
     c[axes_.back()] = N;
     array indices = reshape(arange(N, stream()), std::move(c), stream());
     array first(0, indices.dtype());
-    array last(N - 1, indices.dtype());
-    if (inverse_) {
-      auto starts = Shape(in.ndim(), 0);
-      auto stops = in.shape();
-
-      array one(1 / n_elements, in.dtype());
-      array two(2 / n_elements, in.dtype());
-      array mask =
-          where((first < indices) & (indices < last), two, one, stream());
-
-      return {
-          multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
-    } else {
-      Shape n;
-      for (auto ax : axes_) {
-        n.push_back(in.shape(ax));
-      }
-      array one(1, complex64);
-      array half(0.5, complex64);
-      array mask =
-          where((first < indices) & (indices < last), half, one, stream());
-      return {multiply(
-          fft::irfftn(
-              multiply(cotangents[0], mask, stream()), n, axes, stream()),
-          array(n_elements, in.dtype()),
-          stream())};
-    }
+    array last(N - 1 + odd, indices.dtype());
+    array one(1 / n_elements, in.dtype());
+    array two(2 / n_elements, in.dtype());
+    array mask =
+        where((first < indices) & (indices < last), two, one, stream());
+    return {
+        multiply(fft::rfftn(cotangents[0], axes, stream()), mask, stream())};
+  } else if (real_) {
+    Shape n;
+    for (auto ax : axes_) {
+      n.push_back(in.shape(ax));
+    }
+
+    // Make a mask to account for the double use in the forward pass.
+    // Everything except the DC and nyquist frequencies gets halved.
+    int N = cotangents[0].shape(axes_.back());
+    bool odd = in.shape(axes_.back()) % 2;
+    Shape c(in.ndim(), 1);
+    c[axes_.back()] = N;
+    array indices = reshape(arange(N, stream()), std::move(c), stream());
+    array first(0, indices.dtype());
+    array last(N - 1 + odd, indices.dtype());
+    array one(1, complex64);
+    array half(0.5, complex64);
+    array mask =
+        where((first < indices) & (indices < last), half, one, stream());
+    return {multiply(
+        fft::irfftn(multiply(cotangents[0], mask, stream()), n, axes, stream()),
+        array(n_elements, in.dtype()),
+        stream())};
   } else if (inverse_) {
     return {multiply(
         fft::fftn(cotangents[0], axes, stream()),
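Why the mask and the new `+ odd` term: in the rfft of a real length-n signal, every output bin except DC, and except Nyquist when n is even, stands in for a conjugate pair of bins in the full FFT, so the VJP scales those bins by two (irfft case) or one half (rfft case). For odd n there is no Nyquist bin, which is what `last = N - 1 + odd` accounts for; the old `last = N - 1` wrongly exempted the final bin for odd lengths. A short NumPy sketch of the pairing rule (illustrative only, not part of the commit):

    import numpy as np

    for n in (8, 7):  # even length (has a Nyquist bin) and odd length (does not)
        x = np.random.randn(n)
        full = np.abs(np.fft.fft(x)) ** 2
        half = np.abs(np.fft.rfft(x)) ** 2
        w = np.full(half.shape, 2.0)
        w[0] = 1.0  # DC bin is unpaired
        if n % 2 == 0:
            w[-1] = 1.0  # Nyquist bin exists, and is unpaired, only for even n
        assert np.allclose(full.sum(), (w * half).sum())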
@@ -7,6 +7,13 @@ import mlx.core as mx
 import mlx_tests
 import numpy as np
 
+try:
+    import torch
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 class TestFFT(mlx_tests.MLXTestCase):
     def check_mx_np(self, op_mx, op_np, a_np, atol=1e-5, rtol=1e-6, **kwargs):
@@ -261,6 +268,56 @@ class TestFFT(mlx_tests.MLXTestCase):
         x = mx.array([])
         self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
 
+    @unittest.skipIf(not has_torch, "requires PyTorch")
+    def test_fft_grads(self):
+        real = [True, False]
+        inverse = [True, False]
+        axes = [
+            (-1,),
+            (-2, -1),
+        ]
+        shapes = [
+            (4, 4),
+            (2, 4),
+            (2, 7),
+            (7, 7),
+        ]
+
+        mxffts = {
+            (True, True): mx.fft.irfftn,
+            (True, False): mx.fft.rfftn,
+            (False, True): mx.fft.ifftn,
+            (False, False): mx.fft.fftn,
+        }
+        tffts = {
+            (True, True): torch.fft.irfftn,
+            (True, False): torch.fft.rfftn,
+            (False, True): torch.fft.ifftn,
+            (False, False): torch.fft.fftn,
+        }
+
+        for r, i, ax, sh in itertools.product(real, inverse, axes, shapes):
+
+            def f(x):
+                y = mxffts[r, i](x)
+                return (mx.abs(y) ** 2).sum()
+
+            def g(x):
+                y = tffts[r, i](x)
+                return (torch.abs(y) ** 2).sum()
+
+            if r and not i:
+                x = mx.random.normal(sh)
+            else:
+                x = mx.random.normal((*sh, 2)).view(mx.complex64).squeeze()
+            fx = f(x)
+            gx = g(torch.tensor(x))
+            self.assertLess((fx - gx).abs().max() / gx.abs().mean(), 1e-4)
+
+            dfdx = mx.grad(f)(x)
+            dgdx = torch.func.grad(g)(torch.tensor(x))
+            self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
+
 
 if __name__ == "__main__":
     unittest.main()
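As a usage sketch of the case the new test exercises (a minimal example, assuming a build that includes this fix), the previously broken path was the gradient through an odd-length rfft:

    import mlx.core as mx

    x = mx.random.normal((7,))  # odd length, so there is no Nyquist bin

    def f(x):
        return (mx.abs(mx.fft.rfft(x)) ** 2).sum()

    print(mx.grad(f)(x))  # the highest-frequency bin was mis-weighted before this fix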
@@ -1149,7 +1149,7 @@ TEST_CASE("test complex gradients") {
   auto cotan = array(complex64_t{2.0, 3.0});
   out = vjp([x](array a) { return multiply(a, x); }, y, cotan).second;
   CHECK_EQ(out.dtype(), float32);
-  CHECK_EQ(out.item<float>(), -8.0);
+  CHECK_EQ(out.item<float>(), 16.0);
 
   out = vjp([y](array a) { return multiply(a, y); }, x, cotan).second;
   CHECK_EQ(out.item<complex64_t>(), complex64_t{6.0, 9.0});
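The updated expectation is consistent with the real-primal VJP of a complex multiply taking Re(cotan * conj(x)) rather than Re(cotan * x). With cotan = 2 + 3i, both the old and the new expected values are matched by x = 2 + 4i, a hypothetical value; the excerpt does not show how x is constructed:

    # x = 2 + 4j is an assumption chosen to be consistent with both expectations;
    # the test's actual construction of x is not shown in this excerpt.
    cotan, x = 2 + 3j, 2 + 4j
    print((cotan * x).real)              # -8.0, the old expected value
    print((cotan * x.conjugate()).real)  # 16.0, the new expected value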