From e8ac6bd2f53ea2fa6df1221d2f4c595ac2c7069c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Apr 2025 10:25:55 -0700 Subject: [PATCH] irfft throws instead of segfaults on scalars (#2109) --- mlx/fft.cpp | 3 +-- python/tests/test_fft.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index f0d41bf0f..02878af9c 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include @@ -109,7 +108,7 @@ array fft_impl( for (auto ax : axes) { n.push_back(a.shape(ax)); } - if (real && inverse) { + if (real && inverse && a.ndim() > 0) { n.back() = (n.back() - 1) * 2; } return fft_impl(a, n, axes, real, inverse, s); diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index ec9a48f00..c887cd968 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase): r_np = np.fft.ifft(segment, n=n_fft) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) + def test_fft_throws(self): + x = mx.array(3.0) + with self.assertRaises(ValueError): + mx.fft.irfftn(x) + if __name__ == "__main__": unittest.main()