mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 13:07:29 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -13,10 +13,10 @@ TEST_CASE("test array basics") {
|
||||
array x(1.0);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.ndim(), 0);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{});
|
||||
CHECK_EQ(x.shape(), Shape{});
|
||||
CHECK_THROWS_AS(x.shape(0), std::out_of_range);
|
||||
CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{});
|
||||
CHECK_EQ(x.strides(), Strides{});
|
||||
CHECK_EQ(x.itemsize(), sizeof(float));
|
||||
CHECK_EQ(x.nbytes(), sizeof(float));
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
@@ -39,12 +39,12 @@ TEST_CASE("test array basics") {
|
||||
CHECK_EQ(x.dtype(), float32);
|
||||
CHECK_EQ(x.size(), 1);
|
||||
CHECK_EQ(x.ndim(), 1);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{1});
|
||||
CHECK_EQ(x.shape(), Shape{1});
|
||||
CHECK_EQ(x.shape(0), 1);
|
||||
CHECK_EQ(x.shape(-1), 1);
|
||||
CHECK_THROWS_AS(x.shape(1), std::out_of_range);
|
||||
CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
|
||||
CHECK_EQ(x.strides(), std::vector<size_t>{1});
|
||||
CHECK_EQ(x.strides(), Strides{1});
|
||||
CHECK_EQ(x.item<float>(), 1.0);
|
||||
|
||||
// Check empty array
|
||||
@@ -57,7 +57,7 @@ TEST_CASE("test array basics") {
|
||||
|
||||
x = array({1.0, 1.0});
|
||||
CHECK_EQ(x.size(), 2);
|
||||
CHECK_EQ(x.shape(), std::vector<int>{2});
|
||||
CHECK_EQ(x.shape(), Shape{2});
|
||||
CHECK_EQ(x.itemsize(), sizeof(float));
|
||||
CHECK_EQ(x.nbytes(), x.itemsize() * x.size());
|
||||
|
||||
@@ -65,9 +65,9 @@ TEST_CASE("test array basics") {
|
||||
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
|
||||
|
||||
x = array({1.0, 1.0, 1.0}, {1, 3});
|
||||
CHECK(x.size() == 3);
|
||||
CHECK(x.shape() == std::vector<int>{1, 3});
|
||||
CHECK(x.strides() == std::vector<size_t>{3, 1});
|
||||
CHECK_EQ(x.size(), 3);
|
||||
CHECK_EQ(x.shape(), Shape{1, 3});
|
||||
CHECK_EQ(x.strides(), Strides{3, 1});
|
||||
|
||||
// Test wrong size/shapes throw:
|
||||
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
|
||||
@@ -472,7 +472,7 @@ TEST_CASE("test array metadata") {
|
||||
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||
y = slice(x, {0, 0}, {1, 2}, {2, 3});
|
||||
eval(y);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
|
||||
CHECK_EQ(y.shape(), Shape{1, 1});
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
@@ -481,7 +481,7 @@ TEST_CASE("test array metadata") {
|
||||
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
|
||||
y = slice(x, {0, 0}, {1, 4}, {1, 2});
|
||||
eval(y);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 2});
|
||||
CHECK_EQ(y.shape(), Shape{1, 2});
|
||||
CHECK_EQ(y.flags().contiguous, false);
|
||||
CHECK_EQ(y.flags().row_contiguous, false);
|
||||
CHECK_EQ(y.flags().col_contiguous, false);
|
||||
@@ -489,7 +489,7 @@ TEST_CASE("test array metadata") {
|
||||
x = broadcast_to(array(1.0f), {4, 10});
|
||||
y = slice(x, {0, 0}, {4, 10}, {2, 2});
|
||||
eval(y);
|
||||
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
|
||||
CHECK_EQ(y.shape(), Shape{2, 5});
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, false);
|
||||
@@ -566,8 +566,8 @@ TEST_CASE("test array iteration") {
|
||||
}
|
||||
|
||||
TEST_CASE("test array shared buffer") {
|
||||
std::vector<int> shape = {2, 2};
|
||||
int n_elem = shape[0] * shape[1];
|
||||
Shape shape = {2, 2};
|
||||
auto n_elem = shape[0] * shape[1];
|
||||
|
||||
allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
|
||||
void* buf_b_ptr = buf_b.raw_ptr();
|
||||
|
Reference in New Issue
Block a user