add fftshift and ifftshift fft helpers

This commit is contained in:
Aashiq Dheeraj
2025-04-29 00:28:01 -04:00
committed by Aashiq Dheeraj
parent 99b9868859
commit 00e43d18ed
6 changed files with 264 additions and 0 deletions

View File

@@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.fft.irfftn(x)
def test_fftshift(self):
# Test 1D arrays
r = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1])
# Test with negative axes
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1])
# Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
# Test with complex input
r = np.random.rand(8, 8).astype(np.float32)
i = np.random.rand(8, 8).astype(np.float32)
c = r + 1j * i
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c)
def test_ifftshift(self):
# Test 1D arrays
r = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1))
# Test with negative axes
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1)
# Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,))
# Test with complex input
r = np.random.rand(8, 8).astype(np.float32)
i = np.random.rand(8, 8).astype(np.float32)
c = r + 1j * i
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c)
def test_fftshift_errors(self):
# Test invalid axes
x = mx.array(np.random.rand(4, 4).astype(np.float32))
with self.assertRaises(ValueError):
mx.fft.fftshift(x, axes=[2])
with self.assertRaises(ValueError):
mx.fft.fftshift(x, axes=[-3])
# Test empty array
x = mx.array([])
self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
if __name__ == "__main__":
unittest.main()