mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix FFT (PocketFFT requires size_t for axis)
This commit is contained in:
@@ -27,7 +27,7 @@ array fft_impl(
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> valid_axes;
|
std::vector<size_t> valid_axes;
|
||||||
for (int ax : axes) {
|
for (int ax : axes) {
|
||||||
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
|
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2156,7 +2156,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
|
|||||||
auto out_shape = in.shape();
|
auto out_shape = in.shape();
|
||||||
if (ax >= 0) {
|
if (ax >= 0) {
|
||||||
for (auto& fft_ax : fft_axes) {
|
for (auto& fft_ax : fft_axes) {
|
||||||
if (fft_ax >= ax) {
|
if (static_cast<int>(fft_ax) >= ax) {
|
||||||
fft_ax++;
|
fft_ax++;
|
||||||
}
|
}
|
||||||
if (real_) {
|
if (real_) {
|
||||||
|
|||||||
@@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive {
|
|||||||
public:
|
public:
|
||||||
explicit FFT(
|
explicit FFT(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
const std::vector<int>& axes,
|
// Note: PocketFFT requires size_t
|
||||||
|
const std::vector<size_t>& axes,
|
||||||
bool inverse,
|
bool inverse,
|
||||||
bool real)
|
bool real)
|
||||||
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
|
: UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
|
||||||
@@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> axes_;
|
std::vector<size_t> axes_;
|
||||||
bool inverse_;
|
bool inverse_;
|
||||||
bool real_;
|
bool real_;
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user