fix FFT (PocketFFT requires size_t for axis)

This commit is contained in:
Ronan Collobert
2025-10-29 17:05:48 -07:00
parent 310e501e6a
commit 63d91557e0
3 changed files with 5 additions and 4 deletions

View File

@@ -27,7 +27,7 @@ array fft_impl(
return a; return a;
} }
std::vector<int> valid_axes; std::vector<size_t> valid_axes;
for (int ax : axes) { for (int ax : axes) {
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax); valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
} }

View File

@@ -2156,7 +2156,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
auto out_shape = in.shape(); auto out_shape = in.shape();
if (ax >= 0) { if (ax >= 0) {
for (auto& fft_ax : fft_axes) { for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) { if (static_cast<int>(fft_ax) >= ax) {
fft_ax++; fft_ax++;
} }
if (real_) { if (real_) {

View File

@@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive {
public: public:
explicit FFT( explicit FFT(
Stream stream, Stream stream,
const std::vector<int>& axes, // Note: PocketFFT requires size_t
const std::vector<size_t>& axes,
bool inverse, bool inverse,
bool real) bool real)
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
@@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive {
} }
private: private:
std::vector<int> axes_; std::vector<size_t> axes_;
bool inverse_; bool inverse_;
bool real_; bool real_;
}; };