More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
std::vector<int> slice_sizes = {1, 1, 2, 2};
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});