mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix FFT (PocketFFT requires size_t for axis)
This commit is contained in:
@@ -27,7 +27,7 @@ array fft_impl(
|
||||
return a;
|
||||
}
|
||||
|
||||
std::vector<int> valid_axes;
|
||||
std::vector<size_t> valid_axes;
|
||||
for (int ax : axes) {
|
||||
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
|
||||
}
|
||||
|
||||
@@ -2156,7 +2156,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
||||
auto out_shape = in.shape();
|
||||
if (ax >= 0) {
|
||||
for (auto& fft_ax : fft_axes) {
|
||||
if (fft_ax >= ax) {
|
||||
if (static_cast<int>(fft_ax) >= ax) {
|
||||
fft_ax++;
|
||||
}
|
||||
if (real_) {
|
||||
|
||||
@@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive {
|
||||
public:
|
||||
explicit FFT(
|
||||
Stream stream,
|
||||
const std::vector<int>& axes,
|
||||
// Note: PocketFFT requires size_t
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real)
|
||||
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
|
||||
@@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> axes_;
|
||||
std::vector<size_t> axes_;
|
||||
bool inverse_;
|
||||
bool real_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user