Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -13,10 +13,10 @@ TEST_CASE("test array basics") {
array x(1.0);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 0);
CHECK_EQ(x.shape(), std::vector<int>{});
CHECK_EQ(x.shape(), Shape{});
CHECK_THROWS_AS(x.shape(0), std::out_of_range);
CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{});
CHECK_EQ(x.strides(), Strides{});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), sizeof(float));
CHECK_EQ(x.dtype(), float32);
@@ -39,12 +39,12 @@ TEST_CASE("test array basics") {
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 1);
CHECK_EQ(x.shape(), std::vector<int>{1});
CHECK_EQ(x.shape(), Shape{1});
CHECK_EQ(x.shape(0), 1);
CHECK_EQ(x.shape(-1), 1);
CHECK_THROWS_AS(x.shape(1), std::out_of_range);
CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{1});
CHECK_EQ(x.strides(), Strides{1});
CHECK_EQ(x.item<float>(), 1.0);
// Check empty array
@@ -57,7 +57,7 @@ TEST_CASE("test array basics") {
x = array({1.0, 1.0});
CHECK_EQ(x.size(), 2);
CHECK_EQ(x.shape(), std::vector<int>{2});
CHECK_EQ(x.shape(), Shape{2});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), x.itemsize() * x.size());
@@ -65,9 +65,9 @@ TEST_CASE("test array basics") {
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
x = array({1.0, 1.0, 1.0}, {1, 3});
CHECK(x.size() == 3);
CHECK(x.shape() == std::vector<int>{1, 3});
CHECK(x.strides() == std::vector<size_t>{3, 1});
CHECK_EQ(x.size(), 3);
CHECK_EQ(x.shape(), Shape{1, 3});
CHECK_EQ(x.strides(), Strides{3, 1});
// Test wrong size/shapes throw:
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
@@ -472,7 +472,7 @@ TEST_CASE("test array metadata") {
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 2}, {2, 3});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
CHECK_EQ(y.shape(), Shape{1, 1});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
@@ -481,7 +481,7 @@ TEST_CASE("test array metadata") {
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
y = slice(x, {0, 0}, {1, 4}, {1, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 2});
CHECK_EQ(y.shape(), Shape{1, 2});
CHECK_EQ(y.flags().contiguous, false);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
@@ -489,7 +489,7 @@ TEST_CASE("test array metadata") {
x = broadcast_to(array(1.0f), {4, 10});
y = slice(x, {0, 0}, {4, 10}, {2, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
CHECK_EQ(y.shape(), Shape{2, 5});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
@@ -566,8 +566,8 @@ TEST_CASE("test array iteration") {
}
TEST_CASE("test array shared buffer") {
std::vector<int> shape = {2, 2};
int n_elem = shape[0] * shape[1];
Shape shape = {2, 2};
auto n_elem = shape[0] * shape[1];
allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
void* buf_b_ptr = buf_b.raw_ptr();