add fftshift and ifftshift fft helpers

This commit is contained in:
Aashiq Dheeraj 2025-04-29 00:28:01 -04:00 committed by Aashiq Dheeraj
parent 99b9868859
commit 00e43d18ed
6 changed files with 264 additions and 0 deletions

View File

@ -20,3 +20,5 @@ FFT
irfft2 irfft2
rfftn rfftn
irfftn irfftn
fftshift
ifftshift

View File

@ -184,8 +184,79 @@ array irfftn(
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, true, true, s); return fft_impl(a, axes, true, true, s);
} }
array irfftn(const array& a, StreamOrDevice s /* = {} */) { array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s); return fft_impl(a, true, true, s);
} }
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match PyTorch's implementation
shifts.push_back(a.shape(axis) / 2);
}
return roll(a, shifts, axes, s);
}
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
if (axes.empty()) {
return a;
}
Shape shifts;
for (int ax : axes) {
// Convert negative axes to positive
int axis = ax < 0 ? ax + a.ndim() : ax;
if (axis < 0 || axis >= a.ndim()) {
std::ostringstream msg;
msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Match PyTorch's implementation
int size = a.shape(axis);
shifts.push_back((size + 1) / 2);
}
return roll(a, shifts, axes, s);
}
// Default versions that operate on all axes
array fftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return fftshift(a, axes, s);
}
array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() < 1) {
return a;
}
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return ifftshift(a, axes, s);
}
} // namespace mlx::core::fft } // namespace mlx::core::fft

View File

@ -145,5 +145,23 @@ inline array irfft2(
StreamOrDevice s = {}) { StreamOrDevice s = {}) {
return irfftn(a, axes, s); return irfftn(a, axes, s);
} }
/** Shift the zero-frequency component to the center of the spectrum. */
array fftshift(const array& a, StreamOrDevice s = {});
/** Shift the zero-frequency component to the center of the spectrum along
* specified axes. */
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** The inverse of fftshift. */
array ifftshift(const array& a, StreamOrDevice s = {});
/** The inverse of fftshift along specified axes. */
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
} // namespace mlx::core::fft } // namespace mlx::core::fft

View File

@ -459,4 +459,57 @@ void init_fft(nb::module_& parent_module) {
Returns: Returns:
array: The real array containing the inverse of :func:`rfftn`. array: The real array containing the inverse of :func:`rfftn`.
)pbdoc"); )pbdoc");
m.def(
"fftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::fftshift(a, axes.value(), s);
} else {
return mx::fft::fftshift(a, s);
}
},
"a"_a,
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
Shift the zero-frequency component to the center of the spectrum.
Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the shift.
If None, shift all axes. Each axis can be negative and will be
converted to a positive axis using the same rules as NumPy.
Returns:
array: The shifted array with the same shape as the input.
)pbdoc");
m.def(
"ifftshift",
[](const mx::array& a,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value()) {
return mx::fft::ifftshift(a, axes.value(), s);
} else {
return mx::fft::ifftshift(a, s);
}
},
"a"_a,
"axes"_a = nb::none(),
"stream"_a = nb::none(),
R"pbdoc(
The inverse of fftshift. While identical to fftshift for even-length axes,
the behavior differs for odd-length axes.
Args:
a (array): The input array.
axes (list(int), optional): Axes over which to perform the inverse shift.
If None, shift all axes. Each axis can be negative and will be
converted to a positive axis using the same rules as NumPy.
Returns:
array: The inverse-shifted array with the same shape as the input.
)pbdoc");
} }

View File

@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.fft.irfftn(x) mx.fft.irfftn(x)
def test_fftshift(self):
# Test 1D arrays
r = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1])
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1])
# Test with negative axes
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1])
# Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
# Test with complex input
r = np.random.rand(8, 8).astype(np.float32)
i = np.random.rand(8, 8).astype(np.float32)
c = r + 1j * i
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c)
def test_ifftshift(self):
# Test 1D arrays
r = np.random.rand(100).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
# Test with specific axis
r = np.random.rand(4, 6).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=0)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=1)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0, 1))
# Test with negative axes
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=-1)
# Test with odd lengths
r = np.random.rand(5, 7).astype(np.float32)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=(0,))
# Test with complex input
r = np.random.rand(8, 8).astype(np.float32)
i = np.random.rand(8, 8).astype(np.float32)
c = r + 1j * i
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c)
def test_fftshift_errors(self):
# Test invalid axes
x = mx.array(np.random.rand(4, 4).astype(np.float32))
with self.assertRaises(ValueError):
mx.fft.fftshift(x, axes=[2])
with self.assertRaises(ValueError):
mx.fft.fftshift(x, axes=[-3])
# Test empty array
x = mx.array([])
self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -308,3 +308,61 @@ TEST_CASE("test fft grads") {
.second; .second;
CHECK_EQ(vjp_out.shape(), Shape{5, 5}); CHECK_EQ(vjp_out.shape(), Shape{5, 5});
} }
TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);
auto y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
// print y
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test 1D array with odd length
x = arange(7);
y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());
// Test 2D array
x = reshape(arange(16), {4, 4});
y = fft::fftshift(x);
auto expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test with specific axes
y = fft::fftshift(x, {0});
expected =
array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
y = fft::fftshift(x, {1});
expected =
array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test ifftshift (inverse operation)
x = arange(8);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test ifftshift with odd length (different from fftshift)
x = arange(7);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());
// Test 2D ifftshift
x = reshape(arange(16), {4, 4});
y = fft::ifftshift(x);
expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test error cases
CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);
}