mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:11:44 +08:00
add fftshift and ifftshift fft helpers (#2135)
* add fftshift and ifftshift fft helpers * address comments * axes have to be iterable * fix fp error in roll + add test --------- Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
This commit is contained in:
@@ -308,3 +308,61 @@ TEST_CASE("test fft grads") {
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
}
|
||||
|
||||
TEST_CASE("test fftshift and ifftshift") {
|
||||
// Test 1D array with even length
|
||||
auto x = arange(8);
|
||||
auto y = fft::fftshift(x);
|
||||
CHECK_EQ(y.shape(), x.shape());
|
||||
// print y
|
||||
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
|
||||
|
||||
// Test 1D array with odd length
|
||||
x = arange(7);
|
||||
y = fft::fftshift(x);
|
||||
CHECK_EQ(y.shape(), x.shape());
|
||||
CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());
|
||||
|
||||
// Test 2D array
|
||||
x = reshape(arange(16), {4, 4});
|
||||
y = fft::fftshift(x);
|
||||
auto expected =
|
||||
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
|
||||
// Test with specific axes
|
||||
y = fft::fftshift(x, {0});
|
||||
expected =
|
||||
array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
|
||||
y = fft::fftshift(x, {1});
|
||||
expected =
|
||||
array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
|
||||
// Test ifftshift (inverse operation)
|
||||
x = arange(8);
|
||||
y = fft::ifftshift(x);
|
||||
CHECK_EQ(y.shape(), x.shape());
|
||||
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
|
||||
|
||||
// Test ifftshift with odd length (different from fftshift)
|
||||
x = arange(7);
|
||||
y = fft::ifftshift(x);
|
||||
CHECK_EQ(y.shape(), x.shape());
|
||||
CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());
|
||||
|
||||
// Test 2D ifftshift
|
||||
x = reshape(arange(16), {4, 4});
|
||||
y = fft::ifftshift(x);
|
||||
expected =
|
||||
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
|
||||
CHECK(array_equal(y, expected).item<bool>());
|
||||
|
||||
// Test error cases
|
||||
CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);
|
||||
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);
|
||||
}
|
||||
|
Reference in New Issue
Block a user