Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -162,35 +162,35 @@ TEST_CASE("test fftn") {
x = reshape(arange(20, float32), {5, 4});
y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
y = fft::rfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{3, 4});
CHECK_EQ(y.shape(), Shape{3, 4});
x = reshape(arange(20, float32), {5, 4});
y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 6});
CHECK_EQ(y.shape(), Shape{5, 6});
y = fft::irfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{8, 4});
CHECK_EQ(y.shape(), Shape{8, 4});
}
// Check the types of real ffts
{
x = zeros({5, 5}, float32);
auto y = fft::rfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
CHECK_EQ(y.dtype(), complex64);
y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
CHECK_EQ(y.dtype(), complex64);
x = zeros({5, 5}, complex64);
y = fft::irfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.shape(), Shape{5, 8});
CHECK_EQ(y.dtype(), float32);
y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.shape(), Shape{5, 8});
CHECK_EQ(y.dtype(), float32);
}
}
@@ -199,25 +199,25 @@ TEST_CASE("test fft with provided shape") {
auto x = ones({5, 5});
auto y = fft::fft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{7, 5});
CHECK_EQ(y.shape(), Shape{7, 5});
y = fft::fft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{3, 5});
CHECK_EQ(y.shape(), Shape{3, 5});
y = fft::fft(x, 7, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 7});
CHECK_EQ(y.shape(), Shape{5, 7});
y = fft::fft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.shape(), Shape{5, 3});
y = fft::rfft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{4, 5});
CHECK_EQ(y.shape(), Shape{4, 5});
y = fft::rfft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
CHECK_EQ(y.shape(), Shape{2, 5});
y = fft::rfft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 2});
CHECK_EQ(y.shape(), Shape{5, 2});
}
TEST_CASE("test fft vmap") {
@@ -288,23 +288,23 @@ TEST_CASE("test fft grads") {
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
vjp_out = vjp([](array x) { return fft::ifftn(x); },
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
vjp_out = vjp([](array x) { return fft::rfftn(x); },
zeros({5, 9}),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});
CHECK_EQ(vjp_out.shape(), Shape{5, 9});
vjp_out = vjp([](array x) { return fft::irfftn(x); },
astype(zeros({5, 5}), complex64),
zeros({5, 8}))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
}