mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix fft bug (#2062)
This commit is contained in:
@@ -182,6 +182,18 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
out_np = np.abs(np.fft.fft(np.tile(np.reshape(np.array(b_np), (1, 4)), (4, 1))))
|
||||
np.testing.assert_allclose(out_mx, out_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_fft_into_ifft(self):
|
||||
n_fft = 8193
|
||||
mx.random.seed(0)
|
||||
|
||||
segment = mx.random.normal(shape=[1, n_fft]) + 1j * mx.random.normal(
|
||||
shape=(1, n_fft)
|
||||
)
|
||||
segment = mx.fft.fft(segment, n=n_fft)
|
||||
r = mx.fft.ifft(segment, n=n_fft)
|
||||
r_np = np.fft.ifft(segment, n=n_fft)
|
||||
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user