irfft throws instead of segfaults on scalars (#2109)

This commit is contained in:
Awni Hannun
2025-04-22 10:25:55 -07:00
committed by GitHub
parent fdadc4f22c
commit e8ac6bd2f5
2 changed files with 6 additions and 2 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <set>
@@ -109,7 +108,7 @@ array fft_impl(
for (auto ax : axes) {
n.push_back(a.shape(ax));
}
if (real && inverse) {
if (real && inverse && a.ndim() > 0) {
n.back() = (n.back() - 1) * 2;
}
return fft_impl(a, n, axes, real, inverse, s);