From 9cfe0b1533daaa4857a92cf5491dafc06ed463fa Mon Sep 17 00:00:00 2001 From: Aashiq Dheeraj Date: Tue, 29 Apr 2025 20:59:05 -0400 Subject: [PATCH] axes have to be iterable --- python/tests/test_fft.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 6eaf663c8..f644944c7 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -231,17 +231,17 @@ class TestFFT(mlx_tests.MLXTestCase): # Test with specific axis r = np.random.rand(4, 6).astype(np.float32) - self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0) - self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1) - self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1)) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) # Test with negative axes - self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) # Test with odd lengths r = np.random.rand(5, 7).astype(np.float32) self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) - self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,)) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) # Test with complex input r = np.random.rand(8, 8).astype(np.float32)