irfft throws instead of segfaults on scalars (#2109)

This commit is contained in:
Awni Hannun
2025-04-22 10:25:55 -07:00
committed by GitHub
parent fdadc4f22c
commit e8ac6bd2f5
2 changed files with 6 additions and 2 deletions

View File

@@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase):
r_np = np.fft.ifft(segment, n=n_fft)
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
def test_fft_throws(self):
x = mx.array(3.0)
with self.assertRaises(ValueError):
mx.fft.irfftn(x)
if __name__ == "__main__":
unittest.main()