add fftshift and ifftshift fft helpers (#2135)

* add fftshift and ifftshift fft helpers

* address comments

* axes have to be iterable

* fix fp error in roll + add test

---------

Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
This commit is contained in:
Aashiq Dheeraj
2025-04-30 01:13:45 -04:00
committed by GitHub
parent 7bb063bcb3
commit bb6565ef14
9 changed files with 275 additions and 2 deletions

View File

@@ -459,4 +459,55 @@ void init_fft(nb::module_& parent_module) {
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
m.def(
"fftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::fftshift(a, axes.value(), s);
} else {
return mx::fft::fftshift(a, s);
}
},
"a"_a,
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Shift the zero-frequency component to the center of the spectrum.
Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the shift.
If ``None``, shift all axes.
Returns:
array: The shifted array with the same shape as the input.
)pbdoc");
m.def(
"ifftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::ifftshift(a, axes.value(), s);
} else {
return mx::fft::ifftshift(a, s);
}
},
"a"_a,
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,
the behavior differs for odd-length axes.
Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the inverse shift.
If ``None``, shift all axes.
Returns:
array: The inverse-shifted array with the same shape as the input.
)pbdoc");
}