axes have to be iterable

This commit is contained in:
Aashiq Dheeraj 2025-04-29 20:59:05 -04:00
parent 7c9746a0bc
commit 9cfe0b1533

View File

@ -231,17 +231,17 @@ class TestFFT(mlx_tests.MLXTestCase):
# Test with specific axis # Test with specific axis
r = np.random.rand(4, 6).astype(np.float32) r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1)) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])
# Test with negative axes # Test with negative axes
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])
# Test with odd lengths # Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32) r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,)) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
# Test with complex input # Test with complex input
r = np.random.rand(8, 8).astype(np.float32) r = np.random.rand(8, 8).astype(np.float32)