mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -162,35 +162,35 @@ TEST_CASE("test fftn") {
|
||||
|
||||
x = reshape(arange(20, float32), {5, 4});
|
||||
y = fft::rfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.shape(), Shape{5, 3});
|
||||
y = fft::rfftn(x, {1, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{3, 4});
|
||||
CHECK_EQ(y.shape(), Shape{3, 4});
|
||||
|
||||
x = reshape(arange(20, float32), {5, 4});
|
||||
y = fft::irfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 6});
|
||||
CHECK_EQ(y.shape(), Shape{5, 6});
|
||||
y = fft::irfftn(x, {1, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{8, 4});
|
||||
CHECK_EQ(y.shape(), Shape{8, 4});
|
||||
}
|
||||
|
||||
// Check the types of real ffts
|
||||
{
|
||||
x = zeros({5, 5}, float32);
|
||||
auto y = fft::rfft2(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.shape(), Shape{5, 3});
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
|
||||
y = fft::rfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.shape(), Shape{5, 3});
|
||||
CHECK_EQ(y.dtype(), complex64);
|
||||
|
||||
x = zeros({5, 5}, complex64);
|
||||
y = fft::irfft2(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.shape(), Shape{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
|
||||
y = fft::irfftn(x);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
|
||||
CHECK_EQ(y.shape(), Shape{5, 8});
|
||||
CHECK_EQ(y.dtype(), float32);
|
||||
}
|
||||
}
|
||||
@@ -199,25 +199,25 @@ TEST_CASE("test fft with provided shape") {
|
||||
auto x = ones({5, 5});
|
||||
|
||||
auto y = fft::fft(x, 7, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{7, 5});
|
||||
CHECK_EQ(y.shape(), Shape{7, 5});
|
||||
|
||||
y = fft::fft(x, 3, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{3, 5});
|
||||
CHECK_EQ(y.shape(), Shape{3, 5});
|
||||
|
||||
y = fft::fft(x, 7, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 7});
|
||||
CHECK_EQ(y.shape(), Shape{5, 7});
|
||||
|
||||
y = fft::fft(x, 3, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
|
||||
CHECK_EQ(y.shape(), Shape{5, 3});
|
||||
|
||||
y = fft::rfft(x, 7, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{4, 5});
|
||||
CHECK_EQ(y.shape(), Shape{4, 5});
|
||||
|
||||
y = fft::rfft(x, 3, 0);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
|
||||
CHECK_EQ(y.shape(), Shape{2, 5});
|
||||
|
||||
y = fft::rfft(x, 3, 1);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{5, 2});
|
||||
CHECK_EQ(y.shape(), Shape{5, 2});
|
||||
}
|
||||
|
||||
TEST_CASE("test fft vmap") {
|
||||
@@ -288,23 +288,23 @@ TEST_CASE("test fft grads") {
|
||||
astype(zeros({5, 5}), complex64),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::ifftn(x); },
|
||||
astype(zeros({5, 5}), complex64),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::rfftn(x); },
|
||||
zeros({5, 9}),
|
||||
astype(zeros({5, 5}), complex64))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 9});
|
||||
|
||||
vjp_out = vjp([](array x) { return fft::irfftn(x); },
|
||||
astype(zeros({5, 5}), complex64),
|
||||
zeros({5, 8}))
|
||||
.second;
|
||||
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(vjp_out.shape(), Shape{5, 5});
|
||||
}
|
||||
|
Reference in New Issue
Block a user