diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 33f5e763a..69fcf6479 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -27,7 +27,7 @@ array fft_impl( return a; } - std::vector valid_axes; + std::vector valid_axes; for (int ax : axes) { valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax); } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 152487de6..0faa9f407 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2156,7 +2156,7 @@ std::pair, std::vector> FFT::vmap( auto out_shape = in.shape(); if (ax >= 0) { for (auto& fft_ax : fft_axes) { - if (fft_ax >= ax) { + if (static_cast(fft_ax) >= ax) { fft_ax++; } if (real_) { diff --git a/mlx/primitives.h b/mlx/primitives.h index a37124db9..9ff82ca8a 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1071,7 +1071,8 @@ class FFT : public UnaryPrimitive { public: explicit FFT( Stream stream, - const std::vector& axes, + // Note: PocketFFT requires size_t + const std::vector& axes, bool inverse, bool real) : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} @@ -1089,7 +1090,7 @@ class FFT : public UnaryPrimitive { } private: - std::vector axes_; + std::vector axes_; bool inverse_; bool real_; };