Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -15,13 +15,13 @@ using namespace mlx::core;
TEST_CASE("test copy") {
array x(1.0);
auto y = copy(x);
CHECK_EQ(y.shape(), std::vector<int>{});
CHECK_EQ(y.shape(), Shape{});
CHECK_NE(y.id(), x.id());
CHECK_EQ(y.item<float>(), 1.0f);
x = array({1, 2}, {2, 1});
y = copy(x);
CHECK_EQ(y.shape(), std::vector<int>{2, 1});
CHECK_EQ(y.shape(), Shape{2, 1});
CHECK_EQ(y.dtype(), int32);
CHECK_NE(y.id(), x.id());
CHECK(array_equal(y, x).item<bool>());
@@ -29,37 +29,37 @@ TEST_CASE("test copy") {
TEST_CASE("test reshape") {
array x(1.0);
CHECK_EQ(reshape(x, {}).shape(), std::vector<int>{});
CHECK_EQ(reshape(x, {}).shape(), Shape{});
CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument);
auto y = reshape(x, {1, 1, 1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
y = reshape(x, {-1, 1, 1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
y = reshape(x, {1, 1, -1});
CHECK_EQ(y.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(y.shape(), Shape{1, 1, 1});
CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument);
CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument);
x = zeros({2, 2, 2});
y = reshape(x, {8});
CHECK_EQ(y.shape(), std::vector<int>{8});
CHECK_EQ(y.shape(), Shape{8});
CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument);
y = reshape(x, {-1});
CHECK_EQ(y.shape(), std::vector<int>{8});
CHECK_EQ(y.shape(), Shape{8});
y = reshape(x, {-1, 2});
CHECK_EQ(y.shape(), std::vector<int>{4, 2});
CHECK_EQ(y.shape(), Shape{4, 2});
CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument);
// Works with empty array
x = array({});
y = reshape(x, {0, 0, 0});
CHECK_EQ(y.shape(), std::vector<int>{0, 0, 0});
CHECK_EQ(y.shape(), Shape{0, 0, 0});
y.eval();
CHECK_EQ(y.size(), 0);
CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument);
CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
y = reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
CHECK_EQ(y.shape(), Shape{1, 5, 0});
// Check that reshaping a transposed array doesn't result in a copy
x = reshape(arange(64), {2, 4, 8});
@@ -138,15 +138,15 @@ TEST_CASE("test reshape") {
TEST_CASE("test flatten") {
array x = zeros({2, 3, 4});
CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x).shape(), Shape({2 * 3 * 4}));
CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
CHECK_EQ(flatten(x, 1, 1).shape(), Shape({2, 3, 4}));
CHECK_EQ(flatten(x, 1, 2).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, 3).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, 1, -1).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, -2, -1).shape(), Shape({2, 3 * 4}));
CHECK_EQ(flatten(x, -3, -1).shape(), Shape({2 * 3 * 4}));
CHECK_EQ(flatten(x, -4, -1).shape(), Shape({2 * 3 * 4}));
// Check start > end throws
CHECK_THROWS(flatten(x, 2, 1));
@@ -159,17 +159,17 @@ TEST_CASE("test flatten") {
// Check scalar flattens to 1D
x = array(1);
CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
CHECK_EQ(flatten(x, -3, -1).shape(), Shape({1}));
CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
}
TEST_CASE("test squeeze and expand") {
array x = zeros({2, 1, 2, 1, 2, 1});
CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(squeeze(x, 1).shape(), std::vector<int>{2, 2, 1, 2, 1});
CHECK_EQ(squeeze(x, -1).shape(), std::vector<int>{2, 1, 2, 1, 2});
CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), Shape{2, 2, 2});
CHECK_EQ(squeeze(x, 1).shape(), Shape{2, 2, 1, 2, 1});
CHECK_EQ(squeeze(x, -1).shape(), Shape{2, 1, 2, 1, 2});
CHECK_THROWS(squeeze(x, 0));
CHECK_THROWS(squeeze(x, 2));
@@ -177,13 +177,13 @@ TEST_CASE("test squeeze and expand") {
CHECK_THROWS(squeeze(x, {1, 3, -3}));
x = zeros({2, 2});
CHECK_EQ(expand_dims(x, 0).shape(), std::vector<int>{1, 2, 2});
CHECK_EQ(expand_dims(x, -1).shape(), std::vector<int>{2, 2, 1});
CHECK_EQ(expand_dims(x, 1).shape(), std::vector<int>{2, 1, 2});
CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), std::vector<int>{1, 1, 1, 2, 2});
CHECK_EQ(expand_dims(x, 0).shape(), Shape{1, 2, 2});
CHECK_EQ(expand_dims(x, -1).shape(), Shape{2, 2, 1});
CHECK_EQ(expand_dims(x, 1).shape(), Shape{2, 1, 2});
CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), Shape{1, 1, 1, 2, 2});
CHECK_EQ(
expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(),
std::vector<int>{1, 1, 1, 2, 2, 1, 1, 1});
Shape{1, 1, 1, 2, 2, 1, 1, 1});
CHECK_THROWS(expand_dims(x, 3));
CHECK_THROWS(expand_dims(x, -4));
@@ -210,7 +210,7 @@ TEST_CASE("test slice") {
out = slice(x, {1}, {0});
eval(out);
CHECK_EQ(out.shape(), std::vector<int>{0});
CHECK_EQ(out.shape(), Shape{0});
out = slice(x, {0}, {1}, {1});
CHECK_EQ(out.item<int>(), 3);
@@ -353,7 +353,7 @@ TEST_CASE("test split") {
out = split(x, 3, -1);
CHECK_EQ(out.size(), 3);
for (auto i = 0; i < 3; ++i) {
CHECK_EQ(out[i].shape(), std::vector<int>{1});
CHECK_EQ(out[i].shape(), Shape{1});
CHECK_EQ(out[i].dtype(), int32);
CHECK_EQ(out[i].item<int>(), i);
}
@@ -370,13 +370,13 @@ TEST_CASE("test split") {
x = zeros({8, 12});
out = split(x, 2);
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[0].shape(), Shape{4, 12});
CHECK_EQ(out[1].shape(), Shape{4, 12});
out = split(x, 3, 1);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[1].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[2].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[0].shape(), Shape{8, 4});
CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4});
out = split(x, std::vector<int>{});
CHECK_EQ(out.size(), 1);
@@ -384,25 +384,25 @@ TEST_CASE("test split") {
out = split(x, {3, 7});
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{3, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{4, 12});
CHECK_EQ(out[2].shape(), std::vector<int>{1, 12});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12});
out = split(x, std::vector<int>{20});
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{0, 12});
CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12});
// Negative indices
out = split(x, std::vector<int>{-5});
CHECK_EQ(out[0].shape(), std::vector<int>{3, 12});
CHECK_EQ(out[1].shape(), std::vector<int>{5, 12});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12});
// Different axis
out = split(x, std::vector<int>{2, 8}, 1);
CHECK_EQ(out[0].shape(), std::vector<int>{8, 2});
CHECK_EQ(out[1].shape(), std::vector<int>{8, 6});
CHECK_EQ(out[2].shape(), std::vector<int>{8, 4});
CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4});
// Out of order indices
x = arange(5);
@@ -420,18 +420,18 @@ TEST_CASE("test swap and move axes") {
a = zeros({2});
CHECK_THROWS(swapaxes(a, 0, 1));
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, 0, 0).shape(), Shape{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), Shape{2});
a = zeros({2, 3, 4});
CHECK_THROWS(swapaxes(a, 0, -4));
CHECK_THROWS(swapaxes(a, 0, 3));
CHECK_THROWS(swapaxes(a, 3, 0));
CHECK_THROWS(swapaxes(a, -4, 0));
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
CHECK_EQ(swapaxes(a, 0, 2).shape(), Shape{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), Shape{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), Shape{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), Shape{2, 4, 3});
// Test moveaxis
a = array(0.0);
@@ -439,36 +439,36 @@ TEST_CASE("test swap and move axes") {
a = zeros({2});
CHECK_THROWS(moveaxis(a, 0, 1));
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, 0, 0).shape(), Shape{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), Shape{2});
a = zeros({2, 3, 4});
CHECK_THROWS(moveaxis(a, 0, -4));
CHECK_THROWS(moveaxis(a, 0, 3));
CHECK_THROWS(moveaxis(a, 3, 0));
CHECK_THROWS(moveaxis(a, -4, 0));
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
CHECK_EQ(moveaxis(a, 0, 2).shape(), Shape{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), Shape{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), Shape{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), Shape{2, 4, 3});
}
TEST_CASE("test transpose") {
array x(1);
auto y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{});
CHECK_EQ(y.shape(), Shape{});
CHECK_EQ(y.item<int>(), 1);
CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);
CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);
x = array({1}, {1});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{1});
CHECK_EQ(y.shape(), Shape{1});
CHECK_EQ(y.item<int>(), 1);
// Negative indices
y = transpose(x, {-1});
CHECK_EQ(y.shape(), std::vector<int>{1});
CHECK_EQ(y.shape(), Shape{1});
CHECK_EQ(y.item<int>(), 1);
CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);
@@ -477,24 +477,24 @@ TEST_CASE("test transpose") {
// Works with empty array
x = array({});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{0});
CHECK_EQ(y.shape(), Shape{0});
y.eval();
CHECK_EQ(y.size(), 0);
x = array({1, 2, 3, 4, 5, 6}, {2, 3});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y = transpose(x, {-1, 0});
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y = transpose(x, {-1, -2});
CHECK_EQ(y.shape(), std::vector<int>{3, 2});
CHECK_EQ(y.shape(), Shape{3, 2});
y.eval();
CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item<bool>());
y = transpose(x, {0, 1});
CHECK_EQ(y.shape(), std::vector<int>{2, 3});
CHECK_EQ(y.shape(), Shape{2, 3});
CHECK(array_equal(y, x).item<bool>());
y = transpose(x, {0, -1});
CHECK_EQ(y.shape(), std::vector<int>{2, 3});
CHECK_EQ(y.shape(), Shape{2, 3});
CHECK(array_equal(y, x).item<bool>());
CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument);
@@ -505,19 +505,19 @@ TEST_CASE("test transpose") {
x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2});
y = transpose(x);
CHECK_EQ(y.shape(), std::vector<int>{2, 3, 2});
CHECK_EQ(y.shape(), Shape{2, 3, 2});
auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2});
CHECK(array_equal(y, expected).item<bool>());
y = transpose(x, {0, 1, 2});
CHECK_EQ(y.shape(), std::vector<int>{2, 3, 2});
CHECK_EQ(y.shape(), Shape{2, 3, 2});
CHECK(array_equal(y, x).item<bool>());
y = transpose(x, {1, 0, 2});
CHECK_EQ(y.shape(), std::vector<int>{3, 2, 2});
CHECK_EQ(y.shape(), Shape{3, 2, 2});
expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2});
CHECK(array_equal(y, expected).item<bool>());
y = transpose(x, {0, 2, 1});
CHECK_EQ(y.shape(), std::vector<int>{2, 2, 3});
CHECK_EQ(y.shape(), Shape{2, 2, 3});
expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3});
CHECK(array_equal(y, expected).item<bool>());
@@ -542,7 +542,7 @@ TEST_CASE("test comparison ops") {
array y({});
auto z = x == y;
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.shape(), std::vector<int>{0});
CHECK_EQ(z.shape(), Shape{0});
}
// Basic cases
@@ -631,7 +631,7 @@ TEST_CASE("test comparison ops") {
auto y = zeros({2, 1});
auto z = equal(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
CHECK_EQ(z.shape(), Shape{2, 2});
auto expected = array({true, true, true, true}, {2, 2});
CHECK(array_equal(z, expected).item<bool>());
@@ -639,7 +639,7 @@ TEST_CASE("test comparison ops") {
y = array({1.0, 2.0}, {2, 1});
z = equal(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
CHECK_EQ(z.shape(), Shape{2, 2});
expected = array({true, false, false, true}, {2, 2});
CHECK(array_equal(z, expected).item<bool>());
@@ -769,15 +769,15 @@ TEST_CASE("test reduction ops") {
CHECK_THROWS_AS(sum(x, 0), std::out_of_range);
CHECK_THROWS_AS(sum(x, -1), std::out_of_range);
out = sum(x, std::vector<int>{});
CHECK_EQ(out.shape(), std::vector<int>{});
CHECK_EQ(out.shape(), Shape{});
CHECK_EQ(out.size(), 1);
x = array({});
out = sum(x);
CHECK_EQ(out.shape(), std::vector<int>{});
CHECK_EQ(out.shape(), Shape{});
CHECK_EQ(out.size(), 1);
out = sum(x, true);
CHECK_EQ(out.shape(), std::vector<int>{1});
CHECK_EQ(out.shape(), Shape{1});
out = sum(x, std::vector<int>{});
CHECK_EQ(out.shape(), x.shape());
@@ -788,7 +788,7 @@ TEST_CASE("test reduction ops") {
CHECK_EQ(out.ndim(), 0);
out = sum(x, -1, true);
CHECK_EQ(out.ndim(), 1);
CHECK_EQ(out.shape(), std::vector<int>{1});
CHECK_EQ(out.shape(), Shape{1});
CHECK_THROWS_AS(sum(x, 1), std::out_of_range);
CHECK_THROWS_AS(sum(x, -2), std::out_of_range);
@@ -797,21 +797,21 @@ TEST_CASE("test reduction ops") {
x = zeros({2, 3, 4});
out = sum(x, {0, 2});
CHECK_EQ(out.shape(), std::vector<int>{3});
CHECK_EQ(out.shape(), Shape{3});
out = sum(x, std::vector<int>{});
CHECK_EQ(out.shape(), x.shape());
out = sum(x, {0, -1});
CHECK_EQ(out.shape(), std::vector<int>{3});
CHECK_EQ(out.shape(), Shape{3});
out = sum(x, {0, -1}, true);
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out.shape(), Shape{1, 3, 1});
out = sum(x, true);
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(out.shape(), Shape{1, 1, 1});
out = sum(x);
CHECK_EQ(out.shape(), std::vector<int>{});
CHECK_EQ(out.shape(), Shape{});
CHECK_THROWS_AS(sum(x, 3), std::out_of_range);
CHECK_THROWS_AS(sum(x, -4), std::out_of_range);
@@ -986,7 +986,7 @@ TEST_CASE("test reduction ops") {
std::vector<float> nums = {0.0f, 1.0f, 2.0f, 3.0f};
x = array(nums.data(), {2, 2});
auto y = logsumexp(x, {0, 1}, true);
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
CHECK_EQ(y.shape(), Shape{1, 1});
auto result = std::log(
std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) +
std::exp(nums[3]));
@@ -1594,7 +1594,7 @@ TEST_CASE("test arithmetic binary ops") {
x = array({1.0, 2.0, 3.0}, {1, 3});
y = array({1.0, 2.0, 3.0}, {1, 3});
z = add(x, y);
CHECK_EQ(z.shape(), std::vector<int>{1, 3});
CHECK_EQ(z.shape(), Shape{1, 3});
auto eq = array_equal(z, array({2.0, 4.0, 6.0}, {1, 3}));
CHECK(eq.item<bool>());
@@ -1626,13 +1626,13 @@ TEST_CASE("test arithmetic binary ops") {
x = array({1.0, 2.0}, {1, 2});
y = array({1.0, 2.0}, {2, 1});
z = add(x, y);
CHECK_EQ(z.shape(), std::vector<int>{2, 2});
CHECK_EQ(z.shape(), Shape{2, 2});
eq = array_equal(z, array({2.0, 3.0, 3.0, 4.0}, {2, 2}));
CHECK(eq.item<bool>());
x = ones({3, 2, 1});
z = x + 2.0;
CHECK_EQ(z.shape(), std::vector<int>{3, 2, 1});
CHECK_EQ(z.shape(), Shape{3, 2, 1});
eq = array_equal(z, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1}));
CHECK(eq.item<bool>());
@@ -1642,7 +1642,7 @@ TEST_CASE("test arithmetic binary ops") {
z = x + y;
z.eval();
CHECK_EQ(z.size(), 0);
CHECK_EQ(z.shape(), std::vector<int>{0});
CHECK_EQ(z.shape(), Shape{0});
// Check subtraction
x = array({3, 2, 1});
@@ -1725,46 +1725,46 @@ TEST_CASE("test arithmetic binary ops") {
TEST_CASE("test broadcast") {
auto s = broadcast_shapes({1}, {1, 2});
CHECK_EQ(s, std::vector<int>{1, 2});
CHECK_EQ(s, Shape{1, 2});
s = broadcast_shapes({1, 2}, {1});
CHECK_EQ(s, std::vector<int>{1, 2});
CHECK_EQ(s, Shape{1, 2});
s = broadcast_shapes({2, 2}, {});
CHECK_EQ(s, std::vector<int>{2, 2});
CHECK_EQ(s, Shape{2, 2});
s = broadcast_shapes({}, {1, 1});
CHECK_EQ(s, std::vector<int>{1, 1});
CHECK_EQ(s, Shape{1, 1});
s = broadcast_shapes({1, 2, 1}, {2});
CHECK_EQ(s, std::vector<int>{1, 2, 2});
CHECK_EQ(s, Shape{1, 2, 2});
s = broadcast_shapes({2}, {1, 2, 1});
CHECK_EQ(s, std::vector<int>{1, 2, 2});
CHECK_EQ(s, Shape{1, 2, 2});
s = broadcast_shapes({2, 2, 2}, {1, 2, 1});
CHECK_EQ(s, std::vector<int>{2, 2, 2});
CHECK_EQ(s, Shape{2, 2, 2});
s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1});
CHECK_EQ(s, std::vector<int>{2, 2, 2, 1});
CHECK_EQ(s, Shape{2, 2, 2, 1});
s = broadcast_shapes({0}, {0, 0});
CHECK_EQ(s, std::vector<int>{0, 0});
CHECK_EQ(s, Shape{0, 0});
CHECK_EQ(broadcast_shapes({}, {0}), std::vector<int>{0});
CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});
s = broadcast_shapes({5, 0}, {0, 5, 0});
CHECK_EQ(s, std::vector<int>{0, 5, 0});
CHECK_EQ(s, Shape{0, 5, 0});
CHECK_EQ(broadcast_shapes({}, {0}), std::vector<int>{0});
CHECK_EQ(broadcast_shapes({1}, {0}), std::vector<int>{0});
CHECK_EQ(broadcast_shapes({1}, {0}), std::vector<int>{0});
CHECK_EQ(broadcast_shapes({1}, {0, 0}), std::vector<int>{0, 0});
CHECK_EQ(broadcast_shapes({1, 1}, {0}), std::vector<int>{1, 0});
CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), std::vector<int>{0, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), std::vector<int>{2, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), std::vector<int>{2, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), std::vector<int>{1, 2, 0});
CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});
CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});
CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});
CHECK_EQ(broadcast_shapes({1}, {0, 0}), Shape{0, 0});
CHECK_EQ(broadcast_shapes({1, 1}, {0}), Shape{1, 0});
CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), Shape{0, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), Shape{2, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), Shape{2, 0});
CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), Shape{1, 2, 0});
CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument);
CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument);
@@ -1778,19 +1778,19 @@ TEST_CASE("test broadcast") {
CHECK_EQ(broadcast_to(x, {1, 1}).item<float>(), 2.3f);
x = broadcast_to(x, {5, 1});
CHECK_EQ(x.shape(), std::vector<int>{5, 1});
CHECK_EQ(x.shape(), Shape{5, 1});
x.eval();
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0});
CHECK_EQ(x.strides(), Strides{0, 0});
CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument);
x = broadcast_to(x, {5, 5});
CHECK_EQ(x.shape(), std::vector<int>{5, 5});
CHECK_EQ(x.shape(), Shape{5, 5});
x = zeros({2, 1, 2});
x = broadcast_to(x, {4, 2, 1, 2});
CHECK_EQ(x.shape(), std::vector<int>{4, 2, 1, 2});
CHECK_EQ(x.shape(), Shape{4, 2, 1, 2});
x.eval();
CHECK_EQ(x.strides(), std::vector<size_t>{0, 2, 0, 1});
CHECK_EQ(x.strides(), Strides{0, 2, 0, 1});
// Broadcast on empty arrays works as expected
x = array({});
@@ -1801,29 +1801,29 @@ TEST_CASE("test broadcast") {
auto y = broadcast_to(x, {0});
eval(y);
CHECK_EQ(y.size(), 0);
CHECK_EQ(y.shape(), std::vector<int>{0});
CHECK_EQ(y.shape(), Shape{0});
x = array({1, 2}, {2, 1});
y = broadcast_to(x, {2, 0});
eval(y);
CHECK_EQ(y.size(), 0);
CHECK_EQ(y.shape(), std::vector<int>{2, 0});
CHECK_EQ(y.shape(), Shape{2, 0});
// Check repeat application works
x = zeros({2});
x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2});
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
CHECK_EQ(x.shape(), Shape{2, 2});
x.eval();
CHECK_EQ(x.strides(), std::vector<size_t>{0, 1});
CHECK_EQ(x.strides(), Strides{0, 1});
x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2});
CHECK_EQ(x.shape(), std::vector<int>{2, 2, 2});
CHECK_EQ(x.shape(), Shape{2, 2, 2});
x.eval();
CHECK_EQ(x.strides(), std::vector<size_t>{0, 0, 1});
CHECK_EQ(x.strides(), Strides{0, 0, 1});
// Broadcast on transposed array works
x = array({0, 1, 2, 3, 4, 5}, {2, 3});
x = broadcast_to(transpose(x), {2, 3, 2});
CHECK_EQ(x.shape(), std::vector<int>{2, 3, 2});
CHECK_EQ(x.shape(), Shape{2, 3, 2});
y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2});
CHECK(array_equal(x, y).item<bool>());
@@ -1867,16 +1867,16 @@ TEST_CASE("test gather") {
auto x = arange(20);
auto y = arange(10);
auto out = gather(x, y, 0, {1});
CHECK_EQ(out.shape(), std::vector<int>{10, 1});
CHECK_EQ(out.shape(), Shape{10, 1});
CHECK(array_equal(reshape(out, {-1}), y).item<bool>());
out = gather(x, array({15}, uint32), 0, {1});
CHECK_EQ(out.shape(), std::vector<int>{1, 1});
CHECK_EQ(out.shape(), Shape{1, 1});
CHECK_EQ(out.item<int32_t>(), 15);
// No index gather works
out = gather(x, {}, std::vector<int>{}, {10});
CHECK_EQ(out.shape(), std::vector<int>{10});
CHECK_EQ(out.shape(), Shape{10});
CHECK(array_equal(out, arange(10)).item<bool>());
// Basic test of correctness with 2D input
@@ -1884,13 +1884,13 @@ TEST_CASE("test gather") {
x = reshape(x, {4, 32});
y = array({0, 1}, uint32);
out = gather(x, y, 0, {1, 32});
CHECK_EQ(out.shape(), std::vector<int>{2, 1, 32});
CHECK_EQ(out.shape(), Shape{2, 1, 32});
CHECK(array_equal(reshape(out, {64}), arange(64)).item<bool>());
x = reshape(x, {64, 2});
y = array({0}, uint32);
out = gather(x, y, 0, {64, 1});
CHECK_EQ(out.shape(), std::vector<int>{1, 64, 1});
CHECK_EQ(out.shape(), Shape{1, 64, 1});
CHECK(array_equal(out, reshape(arange(0, 128, 2), {1, 64, 1})).item<bool>());
// Basic test of correctness with 3D input
@@ -1898,7 +1898,7 @@ TEST_CASE("test gather") {
x = reshape(x, {8, 4, 8});
y = array({0}, uint32);
out = gather(x, y, 0, {8, 1, 1});
CHECK_EQ(out.shape(), std::vector<int>{1, 8, 1, 1});
CHECK_EQ(out.shape(), Shape{1, 8, 1, 1});
CHECK(
array_equal(out, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item<bool>());
@@ -1913,10 +1913,10 @@ TEST_CASE("test take") {
// Empty takes
auto empty = astype(array({}), int32);
auto z = take(array({1}), empty);
CHECK_EQ(z.shape(), std::vector<int>{0});
CHECK_EQ(z.shape(), Shape{0});
empty = reshape(empty, {1, 0, 1});
z = take(array({1}), empty);
CHECK_EQ(z.shape(), std::vector<int>{1, 0, 1});
CHECK_EQ(z.shape(), Shape{1, 0, 1});
CHECK_THROWS(take(array({}), array(1)));
@@ -1926,7 +1926,7 @@ TEST_CASE("test take") {
// Take a single row
auto x = reshape(arange(256), {8, 4, 8});
z = take(x, array({0}, uint32), 0);
CHECK_EQ(z.shape(), std::vector<int>{1, 4, 8});
CHECK_EQ(z.shape(), Shape{1, 4, 8});
z = reshape(z, {32});
CHECK(array_equal(z, arange(32)).item<bool>());
@@ -2017,12 +2017,12 @@ TEST_CASE("test take along axis") {
out = take_along_axis(a, reshape(array({1}), {1, 1}), 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});
auto inds = reshape(astype(array({}), int32), {1, 0});
out = take_along_axis(a, inds, 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});
a = array({1, 2, 3, 4}, {2, 2});
inds = array({0, 1}, {1, 2});
@@ -2084,7 +2084,7 @@ TEST_CASE("test put along axis") {
auto inds = reshape(astype(array({}), int32), {1, 0});
out = take_along_axis(a, inds, 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
CHECK_EQ(out.shape(), Shape{1, 0});
a = array({1, 2, 3, 4}, {2, 2});
inds = array({0, 1}, {1, 2});
@@ -2506,9 +2506,9 @@ TEST_CASE("test scan op") {
TEST_CASE("test pad") {
auto x = zeros({1, 2, 3});
CHECK_EQ(pad(x, 1).shape(), std::vector<int>{3, 4, 5});
CHECK_EQ(pad(x, {0, 1}).shape(), std::vector<int>{2, 3, 4});
CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), std::vector<int>{3, 5, 7});
CHECK_EQ(pad(x, 1).shape(), Shape{3, 4, 5});
CHECK_EQ(pad(x, {0, 1}).shape(), Shape{2, 3, 4});
CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), Shape{3, 5, 7});
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
auto padded_x = pad(x, 1);
@@ -2647,20 +2647,20 @@ TEST_CASE("test where") {
TEST_CASE("test stack") {
auto x = array({});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
CHECK_EQ(stack({x}, 0).shape(), Shape{1, 0});
CHECK_EQ(stack({x}, 1).shape(), Shape{0, 1});
x = array({1, 2, 3}, {3});
CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
CHECK_EQ(stack({x}, 0).shape(), Shape{1, 3});
CHECK_EQ(stack({x}, 1).shape(), Shape{3, 1});
auto y = array({4, 5, 6}, {3});
auto z = std::vector<array>{x, y};
CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
CHECK_EQ(stack(z).shape(), Shape{2, 3});
CHECK_EQ(stack(z, 0).shape(), Shape{2, 3});
CHECK_EQ(stack(z, 1).shape(), Shape{3, 2});
CHECK_EQ(stack(z, -1).shape(), Shape{3, 2});
CHECK_EQ(stack(z, -2).shape(), Shape{2, 3});
CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
@@ -2676,20 +2676,20 @@ TEST_CASE("test stack") {
TEST_CASE("test eye") {
auto eye_3 = eye(3);
CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
CHECK_EQ(eye_3.shape(), Shape{3, 3});
auto expected_eye_3 =
array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
auto eye_3x2 = eye(3, 2);
CHECK_EQ(eye_3x2.shape(), std::vector<int>{3, 2});
CHECK_EQ(eye_3x2.shape(), Shape{3, 2});
auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2});
CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
}
TEST_CASE("test tri") {
auto _tri = tri(4, 4, 0, float32);
CHECK_EQ(_tri.shape(), std::vector<int>{4, 4});
CHECK_EQ(_tri.shape(), Shape{4, 4});
auto expected_tri = array(
{1.0f,
0.0f,
@@ -2712,8 +2712,8 @@ TEST_CASE("test tri") {
}
TEST_CASE("test tril") {
auto _tril = tril(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_tril.shape(), std::vector<int>{4, 4});
auto _tril = tril(full({4, 4}, 2.0f, float32), 0);
CHECK_EQ(_tril.shape(), Shape{4, 4});
auto expected_tri = array(
{2.0f,
0.0f,
@@ -2736,8 +2736,8 @@ TEST_CASE("test tril") {
}
TEST_CASE("test triu") {
auto _triu = triu(full(std::vector<int>{4, 4}, 2.0f, float32), 0);
CHECK_EQ(_triu.shape(), std::vector<int>{4, 4});
auto _triu = triu(full({4, 4}, 2.0f, float32), 0);
CHECK_EQ(_triu.shape(), Shape{4, 4});
auto expected_tri = array(
{2.0f,
2.0f,
@@ -2761,7 +2761,7 @@ TEST_CASE("test triu") {
TEST_CASE("test identity") {
auto id_4 = identity(4);
CHECK_EQ(id_4.shape(), std::vector<int>{4, 4});
CHECK_EQ(id_4.shape(), Shape{4, 4});
auto expected_id_4 = array(
{1.0f,
0.0f,
@@ -2785,7 +2785,7 @@ TEST_CASE("test identity") {
TEST_CASE("test eye with positive k offset") {
auto eye_3_k1 = eye(3, 4, 1);
CHECK_EQ(eye_3_k1.shape(), std::vector<int>{3, 4});
CHECK_EQ(eye_3_k1.shape(), Shape{3, 4});
auto expected_eye_3_k1 = array(
{0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{3, 4});
@@ -2794,7 +2794,7 @@ TEST_CASE("test eye with positive k offset") {
TEST_CASE("test eye with negative k offset") {
auto eye_4_k_minus1 = eye(4, 3, -1);
CHECK_EQ(eye_4_k_minus1.shape(), std::vector<int>{4, 3});
CHECK_EQ(eye_4_k_minus1.shape(), Shape{4, 3});
auto expected_eye_4_k_minus1 = array(
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});
@@ -2844,9 +2844,9 @@ TEST_CASE("test quantize dequantize") {
for (int i = 2; i <= 8; i *= 2) {
int el_per_int = 32 / i;
auto [x_q, scales, biases] = quantize(x, 128, i);
CHECK_EQ(x_q.shape(), std::vector<int>{128, 512 / el_per_int});
CHECK_EQ(scales.shape(), std::vector<int>{128, 4});
CHECK_EQ(biases.shape(), std::vector<int>{128, 4});
CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int});
CHECK_EQ(scales.shape(), Shape{128, 4});
CHECK_EQ(biases.shape(), Shape{128, 4});
auto x_hat = dequantize(x_q, scales, biases, 128, i);
auto max_diff = max(abs(x - x_hat)).item<float>();
@@ -3081,7 +3081,7 @@ TEST_CASE("test diagonal") {
out = diagonal(x, -5, 0, 1);
eval(out);
CHECK_EQ(out.shape(), std::vector<int>{0});
CHECK_EQ(out.shape(), Shape{0});
x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
out = diagonal(x, 1, 0, 1);
@@ -3337,17 +3337,17 @@ TEST_CASE("test atleast_1d") {
auto x = array(1);
auto out = atleast_1d(x);
CHECK_EQ(out.ndim(), 1);
CHECK_EQ(out.shape(), std::vector<int>{1});
CHECK_EQ(out.shape(), Shape{1});
x = array({1, 2, 3}, {3});
out = atleast_1d(x);
CHECK_EQ(out.ndim(), 1);
CHECK_EQ(out.shape(), std::vector<int>{3});
CHECK_EQ(out.shape(), Shape{3});
x = array({1, 2, 3}, {3, 1});
out = atleast_1d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
CHECK_EQ(out.shape(), Shape{3, 1});
}
TEST_CASE("test atleast_1d vector") {
@@ -3356,28 +3356,28 @@ TEST_CASE("test atleast_1d vector") {
auto out = atleast_1d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 1);
CHECK_EQ(out[0].shape(), std::vector<int>{1});
CHECK_EQ(out[0].shape(), Shape{1});
CHECK_EQ(out[1].ndim(), 1);
CHECK_EQ(out[1].shape(), std::vector<int>{3});
CHECK_EQ(out[1].shape(), Shape{3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
CHECK_EQ(out[2].shape(), Shape{3, 1});
}
TEST_CASE("test atleast_2d") {
auto x = array(1);
auto out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{1, 1});
CHECK_EQ(out.shape(), Shape{1, 1});
x = array({1, 2, 3}, {3});
out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{1, 3});
CHECK_EQ(out.shape(), Shape{1, 3});
x = array({1, 2, 3}, {3, 1});
out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
CHECK_EQ(out.shape(), Shape{3, 1});
}
TEST_CASE("test atleast_2d vector") {
@@ -3386,28 +3386,28 @@ TEST_CASE("test atleast_2d vector") {
auto out = atleast_2d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
CHECK_EQ(out[0].shape(), Shape{1, 1});
CHECK_EQ(out[1].ndim(), 2);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
CHECK_EQ(out[1].shape(), Shape{1, 3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
CHECK_EQ(out[2].shape(), Shape{3, 1});
}
TEST_CASE("test atleast_3d") {
auto x = array(1);
auto out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(out.shape(), Shape{1, 1, 1});
x = array({1, 2, 3}, {3});
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out.shape(), Shape{1, 3, 1});
x = array({1, 2, 3}, {3, 1});
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
CHECK_EQ(out.shape(), Shape{3, 1, 1});
}
TEST_CASE("test atleast_3d vector") {
@@ -3416,11 +3416,11 @@ TEST_CASE("test atleast_3d vector") {
auto out = atleast_3d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(out[0].shape(), Shape{1, 1, 1});
CHECK_EQ(out[1].ndim(), 3);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out[1].shape(), Shape{1, 3, 1});
CHECK_EQ(out[2].ndim(), 3);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
CHECK_EQ(out[2].shape(), Shape{3, 1, 1});
}
TEST_CASE("test topk") {