fix fft bug (#2062)

This commit is contained in:
Awni Hannun
2025-04-10 19:41:27 -07:00
committed by GitHub
parent ddaa4b7dcb
commit ef7ece9851
3 changed files with 52 additions and 29 deletions

View File

@@ -182,6 +182,18 @@ class TestFFT(mlx_tests.MLXTestCase):
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
def test_fft_into_ifft(self):
n_fft = 8193
mx.random.seed(0)
segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal(
shape=(1, n_fft)
)
segment = mx.fft.fft(segment, n=n_fft)
r = mx.fft.ifft(segment, n=n_fft)
r_np = np.fft.ifft(segment, n=n_fft)
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
if __name__ == "__main__":
unittest.main()