More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -395,7 +395,7 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4});
out = split(x, std::vector<int>{});
out = split(x, Shape{});
CHECK_EQ(out.size(), 1);
CHECK_EQ(out[0].shape(), x.shape());
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12});
out = split(x, std::vector<int>{20});
out = split(x, Shape{20});
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12});
// Negative indices
out = split(x, std::vector<int>{-5});
out = split(x, Shape{-5});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12});
// Different axis
out = split(x, std::vector<int>{2, 8}, 1);
out = split(x, {2, 8}, 1);
CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4});
// Out of order indices
x = arange(5);
out = split(x, std::vector<int>{2, 1, 2});
out = split(x, {2, 1, 2});
CHECK(array_equal(out[0], array({0, 1})).item<bool>());
CHECK(array_equal(out[1], array({})).item<bool>());
CHECK(array_equal(out[2], array({1})).item<bool>());

View File

@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
CHECK_THROWS(categorical(logits, -3));
// Invalid requested shapes
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
CHECK_THROWS(categorical(logits, 1, Shape{1}));
CHECK_THROWS(categorical(logits, 1, Shape{11}));
CHECK_THROWS(categorical(logits, 1, {10, 1}));
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});

View File

@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
std::vector<int> slice_sizes = {1, 1, 2, 2};
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});