diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 9e4be084b..36d9d7838 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,3 +20,5 @@ FFT irfft2 rfftn irfftn + fftshift + ifftshift diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 02878af9c..6510faec1 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -184,8 +184,79 @@ array irfftn( StreamOrDevice s /* = {} */) { return fft_impl(a, axes, true, true, s); } + array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + shifts.push_back(a.shape(axis) / 2); + } + + return roll(a, shifts, axes, s); +} + +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s /* = {} */) { + if (axes.empty()) { + return a; + } + + Shape shifts; + for (int ax : axes) { + // Convert negative axes to positive + int axis = ax < 0 ? ax + a.ndim() : ax; + if (axis < 0 || axis >= a.ndim()) { + std::ostringstream msg; + msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + // Match NumPy's implementation + int size = a.shape(axis); + shifts.push_back(-(size / 2)); + } + + return roll(a, shifts, axes, s); +} + +// Default versions that operate on all axes +array fftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return fftshift(a, axes, s); +} + +array ifftshift(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() < 1) { + return a; + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return ifftshift(a, axes, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index 2f02da73b..163e06b80 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -145,5 +145,23 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, s); } +/** Shift the zero-frequency component to the center of the spectrum. */ +array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. */ +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); } // namespace mlx::core::fft diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f8308c2d5..e7abe12db 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -5025,8 +5025,11 @@ array roll( } auto sh = shift[i]; - auto split_index = - (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax); + auto size = a.shape(ax); + if (size == 0) { + continue; // skip rolling this axis if it has size 0 + } + auto split_index = (sh < 0) ? (-sh) % size : size - sh % size; auto parts = split(result, Shape{split_index}, ax, s); std::swap(parts[0], parts[1]); diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 5ad4702e2..026f8139d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -459,4 +459,55 @@ void init_fft(nb::module_& parent_module) { Returns: array: The real array containing the inverse of :func:`rfftn`. )pbdoc"); + m.def( + "fftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::fftshift(a, axes.value(), s); + } else { + return mx::fft::fftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Shift the zero-frequency component to the center of the spectrum. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the shift. + If ``None``, shift all axes. + + Returns: + array: The shifted array with the same shape as the input. + )pbdoc"); + m.def( + "ifftshift", + [](const mx::array& a, + const std::optional>& axes, + mx::StreamOrDevice s) { + if (axes.has_value()) { + return mx::fft::ifftshift(a, axes.value(), s); + } else { + return mx::fft::ifftshift(a, s); + } + }, + "a"_a, + "axes"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, + the behavior differs for odd-length axes. + + Args: + a (array): The input array. + axes (list(int), optional): Axes over which to perform the inverse shift. + If ``None``, shift all axes. + + Returns: + array: The inverse-shifted array with the same shape as the input. + )pbdoc"); } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index c887cd968..f644944c7 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.fft.irfftn(x) + def test_fftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1]) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r) + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c) + + def test_ifftshift(self): + # Test 1D arrays + r = np.random.rand(100).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + + # Test with specific axis + r = np.random.rand(4, 6).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1]) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1]) + + # Test with negative axes + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1]) + + # Test with odd lengths + r = np.random.rand(5, 7).astype(np.float32) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r) + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0]) + + # Test with complex input + r = np.random.rand(8, 8).astype(np.float32) + i = np.random.rand(8, 8).astype(np.float32) + c = r + 1j * i + self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c) + + def test_fftshift_errors(self): + # Test invalid axes + x = mx.array(np.random.rand(4, 4).astype(np.float32)) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[2]) + with self.assertRaises(ValueError): + mx.fft.fftshift(x, axes=[-3]) + + # Test empty array + x = mx.array([]) + self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 47fec3167..d840eac7d 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2961,6 +2961,11 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) self.assertTrue(mx.array_equal(y1, y2).item()) + def test_roll_errors(self): + x = mx.array([]) + result = mx.roll(x, [0], [0]) + self.assertTrue(mx.array_equal(result, x)) + def test_real_imag(self): x = mx.random.uniform(shape=(4, 4)) out = mx.real(x) diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index c04dda1d5..0db3999c8 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -308,3 +308,61 @@ TEST_CASE("test fft grads") { .second; CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } + +TEST_CASE("test fftshift and ifftshift") { + // Test 1D array with even length + auto x = arange(8); + auto y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + // print y + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test 1D array with odd length + x = arange(7); + y = fft::fftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item()); + + // Test 2D array + x = reshape(arange(16), {4, 4}); + y = fft::fftshift(x); + auto expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test with specific axes + y = fft::fftshift(x, {0}); + expected = + array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + y = fft::fftshift(x, {1}); + expected = + array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test ifftshift (inverse operation) + x = arange(8); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item()); + + // Test ifftshift with odd length (different from fftshift) + x = arange(7); + y = fft::ifftshift(x); + CHECK_EQ(y.shape(), x.shape()); + CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item()); + + // Test 2D ifftshift + x = reshape(arange(16), {4, 4}); + y = fft::ifftshift(x); + expected = + array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4}); + CHECK(array_equal(y, expected).item()); + + // Test error cases + CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); + CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); +} diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c4f319d46..5e2bae5a0 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3859,6 +3859,9 @@ TEST_CASE("test roll") { y = roll(x, {1, 2}, {0, 1}); CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) .item()); + + y = roll(array({}), 0, 0); + CHECK(array_equal(y, array({})).item()); } TEST_CASE("test contiguous") {