mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
axes have to be iterable
This commit is contained in:
parent
7c9746a0bc
commit
9cfe0b1533
@ -231,17 +231,17 @@ class TestFFT(mlx_tests.MLXTestCase):
|
||||
|
||||
# Test with specific axis
|
||||
r = np.random.rand(4, 6).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1))
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])
|
||||
|
||||
# Test with negative axes
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])
|
||||
|
||||
# Test with odd lengths
|
||||
r = np.random.rand(5, 7).astype(np.float32)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,))
|
||||
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
|
||||
|
||||
# Test with complex input
|
||||
r = np.random.rand(8, 8).astype(np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user