fix FFT (PocketFFT requires size_t for axis)

This commit is contained in:
Ronan Collobert
2025-10-29 17:05:48 -07:00
parent 310e501e6a
commit 63d91557e0
3 changed files with 5 additions and 4 deletions

View File

@@ -27,7 +27,7 @@ array fft_impl(
return a;
}
std::vector<int> valid_axes;
std::vector<size_t> valid_axes;
for (int ax : axes) {
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
}

View File

@@ -2156,7 +2156,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
auto out_shape = in.shape();
if (ax >= 0) {
for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) {
if (static_cast<int>(fft_ax) >= ax) {
fft_ax++;
}
if (real_) {

View File

@@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive {
public:
explicit FFT(
Stream stream,
const std::vector<int>& axes,
// Note: PocketFFT requires size_t
const std::vector<size_t>& axes,
bool inverse,
bool real)
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
@@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive {
}
private:
std::vector<int> axes_;
std::vector<size_t> axes_;
bool inverse_;
bool real_;
};