address comments

This commit is contained in:
Aashiq Dheeraj 2025-04-29 20:42:27 -04:00
parent 00e43d18ed
commit 7c9746a0bc
2 changed files with 6 additions and 8 deletions

View File

@ -207,7 +207,7 @@ array fftshift(
<< " dimensions."; << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Match PyTorch's implementation // Match NumPy's implementation
shifts.push_back(a.shape(axis) / 2); shifts.push_back(a.shape(axis) / 2);
} }
@ -232,9 +232,9 @@ array ifftshift(
<< " dimensions."; << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Match PyTorch's implementation // Match NumPy's implementation
int size = a.shape(axis); int size = a.shape(axis);
shifts.push_back((size + 1) / 2); shifts.push_back(-(size / 2));
} }
return roll(a, shifts, axes, s); return roll(a, shifts, axes, s);

View File

@ -479,8 +479,7 @@ void init_fft(nb::module_& parent_module) {
Args: Args:
a (array): The input array. a (array): The input array.
axes (list(int), optional): Axes over which to perform the shift. axes (list(int), optional): Axes over which to perform the shift.
If None, shift all axes. Each axis can be negative and will be If ``None``, shift all axes.
converted to a positive axis using the same rules as NumPy.
Returns: Returns:
array: The shifted array with the same shape as the input. array: The shifted array with the same shape as the input.
@ -500,14 +499,13 @@ void init_fft(nb::module_& parent_module) {
"axes"_a = nb::none(), "axes"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
R"pbdoc( R"pbdoc(
The inverse of fftshift. While identical to fftshift for even-length axes, The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,
the behavior differs for odd-length axes. the behavior differs for odd-length axes.
Args: Args:
a (array): The input array. a (array): The input array.
axes (list(int), optional): Axes over which to perform the inverse shift. axes (list(int), optional): Axes over which to perform the inverse shift.
If None, shift all axes. Each axis can be negative and will be If ``None``, shift all axes.
converted to a positive axis using the same rules as NumPy.
Returns: Returns:
array: The inverse-shifted array with the same shape as the input. array: The inverse-shifted array with the same shape as the input.