This commit is contained in:
paramthakkar123
2025-05-06 09:53:10 +05:30
128 changed files with 2291 additions and 895 deletions

View File

@@ -309,6 +309,7 @@ TEST_CASE("test fft grads") {
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}
<<<<<<< HEAD
TEST_CASE("test stft and istft") {
int n_fft = 4;
int hop_length = 2;
@@ -381,4 +382,62 @@ TEST_CASE("test stft and istft") {
CHECK_EQ(stft_result.shape(1), n_fft);
}
}
}
== == == = TEST_CASE("test fftshift and ifftshift") {
// Test 1D array with even length
auto x = arange(8);
auto y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
// print y
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test 1D array with odd length
x = arange(7);
y = fft::fftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 0, 1, 2, 3})).item<bool>());
// Test 2D array
x = reshape(arange(16), {4, 4});
y = fft::fftshift(x);
auto expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test with specific axes
y = fft::fftshift(x, {0});
expected =
array({8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
y = fft::fftshift(x, {1});
expected =
array({2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test ifftshift (inverse operation)
x = arange(8);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({4, 5, 6, 7, 0, 1, 2, 3})).item<bool>());
// Test ifftshift with odd length (different from fftshift)
x = arange(7);
y = fft::ifftshift(x);
CHECK_EQ(y.shape(), x.shape());
CHECK(array_equal(y, array({3, 4, 5, 6, 0, 1, 2})).item<bool>());
// Test 2D ifftshift
x = reshape(arange(16), {4, 4});
y = fft::ifftshift(x);
expected =
array({10, 11, 8, 9, 14, 15, 12, 13, 2, 3, 0, 1, 6, 7, 4, 5}, {4, 4});
CHECK(array_equal(y, expected).item<bool>());
// Test error cases
CHECK_THROWS_AS(fft::fftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftshift(x, {-5}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument);
CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument);
}
>>>>>>> 5a1a5d5ed16f69af7c3ce56dd94e4502661e1565