From 7c9746a0bc8d443916f290787c6e9d271c41f65c Mon Sep 17 00:00:00 2001 From: Aashiq Dheeraj Date: Tue, 29 Apr 2025 20:42:27 -0400 Subject: [PATCH] address comments --- mlx/fft.cpp | 6 +++--- python/src/fft.cpp | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 355f3acd6..6510faec1 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -207,7 +207,7 @@ array fftshift( << " dimensions."; throw std::invalid_argument(msg.str()); } - // Match PyTorch's implementation + // Match NumPy's implementation shifts.push_back(a.shape(axis) / 2); } @@ -232,9 +232,9 @@ array ifftshift( << " dimensions."; throw std::invalid_argument(msg.str()); } - // Match PyTorch's implementation + // Match NumPy's implementation int size = a.shape(axis); - shifts.push_back((size + 1) / 2); + shifts.push_back(-(size / 2)); } return roll(a, shifts, axes, s); diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 2d3c47584..026f8139d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -479,8 +479,7 @@ void init_fft(nb::module_& parent_module) { Args: a (array): The input array. axes (list(int), optional): Axes over which to perform the shift. - If None, shift all axes. Each axis can be negative and will be - converted to a positive axis using the same rules as NumPy. + If ``None``, shift all axes. Returns: array: The shifted array with the same shape as the input. @@ -500,14 +499,13 @@ void init_fft(nb::module_& parent_module) { "axes"_a = nb::none(), "stream"_a = nb::none(), R"pbdoc( - The inverse of fftshift. While identical to fftshift for even-length axes, + The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, the behavior differs for odd-length axes. Args: a (array): The input array. axes (list(int), optional): Axes over which to perform the inverse shift. - If None, shift all axes. Each axis can be negative and will be - converted to a positive axis using the same rules as NumPy. + If ``None``, shift all axes. Returns: array: The inverse-shifted array with the same shape as the input.