mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
add fftshift and ifftshift fft helpers (#2135)
* add fftshift and ifftshift fft helpers * address comments * axes have to be iterable * fix fp error in roll + add test --------- Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
This commit is contained in:
parent
7bb063bcb3
commit
bb6565ef14
@ -20,3 +20,5 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
|
fftshift
|
||||||
|
ifftshift
|
||||||
|
71
mlx/fft.cpp
71
mlx/fft.cpp
@ -184,8 +184,79 @@ array irfftn(
|
|||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return fft_impl(a, axes, true, true, s);
|
return fft_impl(a, axes, true, true, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
return fft_impl(a, true, true, s);
|
return fft_impl(a, true, true, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array fftshift(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (axes.empty()) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape shifts;
|
||||||
|
for (int ax : axes) {
|
||||||
|
// Convert negative axes to positive
|
||||||
|
int axis = ax < 0 ? ax + a.ndim() : ax;
|
||||||
|
if (axis < 0 || axis >= a.ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
// Match NumPy's implementation
|
||||||
|
shifts.push_back(a.shape(axis) / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return roll(a, shifts, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array ifftshift(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (axes.empty()) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape shifts;
|
||||||
|
for (int ax : axes) {
|
||||||
|
// Convert negative axes to positive
|
||||||
|
int axis = ax < 0 ? ax + a.ndim() : ax;
|
||||||
|
if (axis < 0 || axis >= a.ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim()
|
||||||
|
<< " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
// Match NumPy's implementation
|
||||||
|
int size = a.shape(axis);
|
||||||
|
shifts.push_back(-(size / 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
return roll(a, shifts, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default versions that operate on all axes
|
||||||
|
array fftshift(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
if (a.ndim() < 1) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
std::vector<int> axes(a.ndim());
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
return fftshift(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
if (a.ndim() < 1) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
std::vector<int> axes(a.ndim());
|
||||||
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
|
return ifftshift(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fft
|
} // namespace mlx::core::fft
|
||||||
|
18
mlx/fft.h
18
mlx/fft.h
@ -145,5 +145,23 @@ inline array irfft2(
|
|||||||
StreamOrDevice s = {}) {
|
StreamOrDevice s = {}) {
|
||||||
return irfftn(a, axes, s);
|
return irfftn(a, axes, s);
|
||||||
}
|
}
|
||||||
|
/** Shift the zero-frequency component to the center of the spectrum. */
|
||||||
|
array fftshift(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Shift the zero-frequency component to the center of the spectrum along
|
||||||
|
* specified axes. */
|
||||||
|
array fftshift(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The inverse of fftshift. */
|
||||||
|
array ifftshift(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** The inverse of fftshift along specified axes. */
|
||||||
|
array ifftshift(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core::fft
|
} // namespace mlx::core::fft
|
||||||
|
@ -5025,8 +5025,11 @@ array roll(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto sh = shift[i];
|
auto sh = shift[i];
|
||||||
auto split_index =
|
auto size = a.shape(ax);
|
||||||
(sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
|
if (size == 0) {
|
||||||
|
continue; // skip rolling this axis if it has size 0
|
||||||
|
}
|
||||||
|
auto split_index = (sh < 0) ? (-sh) % size : size - sh % size;
|
||||||
|
|
||||||
auto parts = split(result, Shape{split_index}, ax, s);
|
auto parts = split(result, Shape{split_index}, ax, s);
|
||||||
std::swap(parts[0], parts[1]);
|
std::swap(parts[0], parts[1]);
|
||||||
|
@ -459,4 +459,55 @@ void init_fft(nb::module_& parent_module) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The real array containing the inverse of :func:`rfftn`.
|
array: The real array containing the inverse of :func:`rfftn`.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"fftshift",
|
||||||
|
[](const mx::array& a,
|
||||||
|
const std::optional<std::vector<int>>& axes,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
if (axes.has_value()) {
|
||||||
|
return mx::fft::fftshift(a, axes.value(), s);
|
||||||
|
} else {
|
||||||
|
return mx::fft::fftshift(a, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"axes"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
R"pbdoc(
|
||||||
|
Shift the zero-frequency component to the center of the spectrum.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): The input array.
|
||||||
|
axes (list(int), optional): Axes over which to perform the shift.
|
||||||
|
If ``None``, shift all axes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The shifted array with the same shape as the input.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"ifftshift",
|
||||||
|
[](const mx::array& a,
|
||||||
|
const std::optional<std::vector<int>>& axes,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
if (axes.has_value()) {
|
||||||
|
return mx::fft::ifftshift(a, axes.value(), s);
|
||||||
|
} else {
|
||||||
|
return mx::fft::ifftshift(a, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"axes"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
R"pbdoc(
|
||||||
|
The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,
|
||||||
|
the behavior differs for odd-length axes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): The input array.
|
||||||
|
axes (list(int), optional): Axes over which to perform the inverse shift.
|
||||||
|
If ``None``, shift all axes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The inverse-shifted array with the same shape as the input.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -199,6 +199,68 @@ class TestFFT(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.fft.irfftn(x)
|
mx.fft.irfftn(x)
|
||||||
|
|
||||||
|
def test_fftshift(self):
|
||||||
|
# Test 1D arrays
|
||||||
|
r = np.random.rand(100).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
|
||||||
|
|
||||||
|
# Test with specific axis
|
||||||
|
r = np.random.rand(4, 6).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[1])
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0, 1])
|
||||||
|
|
||||||
|
# Test with negative axes
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[-1])
|
||||||
|
|
||||||
|
# Test with odd lengths
|
||||||
|
r = np.random.rand(5, 7).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r)
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, r, axes=[0])
|
||||||
|
|
||||||
|
# Test with complex input
|
||||||
|
r = np.random.rand(8, 8).astype(np.float32)
|
||||||
|
i = np.random.rand(8, 8).astype(np.float32)
|
||||||
|
c = r + 1j * i
|
||||||
|
self.check_mx_np(mx.fft.fftshift, np.fft.fftshift, c)
|
||||||
|
|
||||||
|
def test_ifftshift(self):
|
||||||
|
# Test 1D arrays
|
||||||
|
r = np.random.rand(100).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
|
||||||
|
|
||||||
|
# Test with specific axis
|
||||||
|
r = np.random.rand(4, 6).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[1])
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0, 1])
|
||||||
|
|
||||||
|
# Test with negative axes
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[-1])
|
||||||
|
|
||||||
|
# Test with odd lengths
|
||||||
|
r = np.random.rand(5, 7).astype(np.float32)
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r)
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, r, axes=[0])
|
||||||
|
|
||||||
|
# Test with complex input
|
||||||
|
r = np.random.rand(8, 8).astype(np.float32)
|
||||||
|
i = np.random.rand(8, 8).astype(np.float32)
|
||||||
|
c = r + 1j * i
|
||||||
|
self.check_mx_np(mx.fft.ifftshift, np.fft.ifftshift, c)
|
||||||
|
|
||||||
|
def test_fftshift_errors(self):
|
||||||
|
# Test invalid axes
|
||||||
|
x = mx.array(np.random.rand(4, 4).astype(np.float32))
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.fft.fftshift(x, axes=[2])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.fft.fftshift(x, axes=[-3])
|
||||||
|
|
||||||
|
# Test empty array
|
||||||
|
x = mx.array([])
|
||||||
|
self.assertTrue(mx.array_equal(mx.fft.fftshift(x), x))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2961,6 +2961,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
y2 = mx.roll(x, s, a)
|
y2 = mx.roll(x, s, a)
|
||||||
self.assertTrue(mx.array_equal(y1, y2).item())
|
self.assertTrue(mx.array_equal(y1, y2).item())
|
||||||
|
|
||||||
|
def test_roll_errors(self):
|
||||||
|
x = mx.array([])
|
||||||
|
result = mx.roll(x, [0], [0])
|
||||||
|
self.assertTrue(mx.array_equal(result, x))
|
||||||
|
|
||||||
def test_real_imag(self):
|
def test_real_imag(self):
|
||||||
x = mx.random.uniform(shape=(4, 4))
|
x = mx.random.uniform(shape=(4, 4))
|
||||||
out = mx.real(x)
|
out = mx.real(x)
|
||||||
|
@ -308,3 +308,61 @@ TEST_CASE("test fft grads") {
|
|||||||
.second;
|
.second;
|
||||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test fftshift and ifftshift") {
|
||||||
|
// Test 1D array with even length
|
||||||
|
auto x = arange(8);
|
||||||
|
auto y = fft::fftshift(x);
|
||||||
|
CHECK_EQ(y.shape(), x.shape());
|
||||||
|
// print y
|
||||||
|
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
|
||||||
|
|
||||||
|
// Test 1D array with odd length
|
||||||
|
x = arange(7);
|
||||||
|
y = fft::fftshift(x);
|
||||||
|
CHECK_EQ(y.shape(), x.shape());
|
||||||
|
CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());
|
||||||
|
|
||||||
|
// Test 2D array
|
||||||
|
x = reshape(arange(16), {4, 4});
|
||||||
|
y = fft::fftshift(x);
|
||||||
|
auto expected =
|
||||||
|
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
|
||||||
|
CHECK(array_equal(y, expected).item<bool>());
|
||||||
|
|
||||||
|
// Test with specific axes
|
||||||
|
y = fft::fftshift(x, {0});
|
||||||
|
expected =
|
||||||
|
array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});
|
||||||
|
CHECK(array_equal(y, expected).item<bool>());
|
||||||
|
|
||||||
|
y = fft::fftshift(x, {1});
|
||||||
|
expected =
|
||||||
|
array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});
|
||||||
|
CHECK(array_equal(y, expected).item<bool>());
|
||||||
|
|
||||||
|
// Test ifftshift (inverse operation)
|
||||||
|
x = arange(8);
|
||||||
|
y = fft::ifftshift(x);
|
||||||
|
CHECK_EQ(y.shape(), x.shape());
|
||||||
|
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
|
||||||
|
|
||||||
|
// Test ifftshift with odd length (different from fftshift)
|
||||||
|
x = arange(7);
|
||||||
|
y = fft::ifftshift(x);
|
||||||
|
CHECK_EQ(y.shape(), x.shape());
|
||||||
|
CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());
|
||||||
|
|
||||||
|
// Test 2D ifftshift
|
||||||
|
x = reshape(arange(16), {4, 4});
|
||||||
|
y = fft::ifftshift(x);
|
||||||
|
expected =
|
||||||
|
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
|
||||||
|
CHECK(array_equal(y, expected).item<bool>());
|
||||||
|
|
||||||
|
// Test error cases
|
||||||
|
CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);
|
||||||
|
}
|
||||||
|
@ -3859,6 +3859,9 @@ TEST_CASE("test roll") {
|
|||||||
y = roll(x, {1, 2}, {0, 1});
|
y = roll(x, {1, 2}, {0, 1});
|
||||||
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
|
|
||||||
|
y = roll(array({}), 0, 0);
|
||||||
|
CHECK(array_equal(y, array({})).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test contiguous") {
|
TEST_CASE("test contiguous") {
|
||||||
|
Loading…
Reference in New Issue
Block a user