Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <numeric>
@@ -205,6 +204,46 @@ TEST_CASE("test split") {
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
}
TEST_CASE("test swap and move axes") {
// Test swapaxes
array a(0.0);
CHECK_THROWS(swapaxes(a, 0, 0));
a = zeros({2});
CHECK_THROWS(swapaxes(a, 0, 1));
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
a = zeros({2, 3, 4});
CHECK_THROWS(swapaxes(a, 0, -4));
CHECK_THROWS(swapaxes(a, 0, 3));
CHECK_THROWS(swapaxes(a, 3, 0));
CHECK_THROWS(swapaxes(a, -4, 0));
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
// Test moveaxis
a = array(0.0);
CHECK_THROWS(moveaxis(a, 0, 0));
a = zeros({2});
CHECK_THROWS(moveaxis(a, 0, 1));
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
a = zeros({2, 3, 4});
CHECK_THROWS(moveaxis(a, 0, -4));
CHECK_THROWS(moveaxis(a, 0, 3));
CHECK_THROWS(moveaxis(a, 3, 0));
CHECK_THROWS(moveaxis(a, -4, 0));
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
}
TEST_CASE("test transpose") {
array x(1);
auto y = transpose(x);
@@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") {
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
}
}

View File

@@ -248,3 +248,104 @@ TEST_CASE("test vmap creation ops") {
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test vmap slice") {
{
auto fun = [](array in) { return slice(in, {4}, {8}, {2}); };
auto x = reshape(arange(16), {2, 8});
auto out = vmap(fun)(x);
auto expected = reshape(array({4, 6, 12, 14}), {2, 2});
CHECK(array_equal(out, expected).item<bool>());
}
{
auto fun = [](array in) { return slice(in, {0, 1}, {2, 3}); };
auto x = reshape(arange(12), {2, 2, 3});
auto out = vmap(fun, 1, 0)(x);
auto expected = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2});
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test vmap concatenate") {
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{concatenate(inputs, 0)};
};
auto x = reshape(arange(4), {2, 2});
auto y = reshape(arange(4), {2, 2});
auto out = vmap(fun)({x, y})[0];
auto expected = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4});
CHECK(array_equal(out, expected).item<bool>());
out = vmap(fun, {1, 1})({x, y})[0];
expected = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4});
CHECK(array_equal(out, expected).item<bool>());
out = vmap(fun, {0, 1})({x, y})[0];
expected = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 4});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test vmap gather") {
{
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {0, -1})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 2, 3, 2, 2});
out = vmap(fun, {0, -1}, {3})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
}
{
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {0, 0})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
}
{
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {-1, 0})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
}
{
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
std::vector<int> slice_sizes = {1, 1, 2, 2};
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
z = array({0, 1, 0, 0, 1, 0}, {3, 2});
out = vmap(fun, {-1, 0, 1})({x, y, z})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
}
}