axes have to be iterable

This commit is contained in:
Aashiq Dheeraj 2025-04-29 20:59:05 -04:00
parent 7c9746a0bc
commit 9cfe0b1533

View File

@ -231,17 +231,17 @@ class TestFFT(mlx_tests.MLXTestCase):
# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1))
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])
# Test with negative axes
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])
# Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,))
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
# Test with complex input
r = np.random.rand(8, 8).astype(np.float32)