irfft throws instead of segfaults on scalars (#2109)

This commit is contained in:
Awni Hannun 2025-04-22 10:25:55 -07:00 committed by GitHub
parent fdadc4f22c
commit e8ac6bd2f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 2 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <numeric> #include <numeric>
#include <set> #include <set>
@ -109,7 +108,7 @@ array fft_impl(
for (auto ax : axes) { for (auto ax : axes) {
n.push_back(a.shape(ax)); n.push_back(a.shape(ax));
} }
if (real && inverse) { if (real && inverse && a.ndim() > 0) {
n.back() = (n.back() - 1) * 2; n.back() = (n.back() - 1) * 2;
} }
return fft_impl(a, n, axes, real, inverse, s); return fft_impl(a, n, axes, real, inverse, s);

View File

@ -194,6 +194,11 @@ class TestFFT(mlx_tests.MLXTestCase):
r_np = np.fft.ifft(segment, n=n_fft) r_np = np.fft.ifft(segment, n=n_fft)
self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5)) self.assertTrue(np.allclose(r, r_np, atol=1e-5, rtol=1e-5))
def test_fft_throws(self):
x = mx.array(3.0)
with self.assertRaises(ValueError):
mx.fft.irfftn(x)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()