Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <numeric>
@@ -205,6 +204,46 @@ TEST_CASE("test split") {
CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
}
TEST_CASE("test swap and move axes") {
// Test swapaxes
array a(0.0);
CHECK_THROWS(swapaxes(a, 0, 0));
a = zeros({2});
CHECK_THROWS(swapaxes(a, 0, 1));
CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector<int>{2});
a = zeros({2, 3, 4});
CHECK_THROWS(swapaxes(a, 0, -4));
CHECK_THROWS(swapaxes(a, 0, 3));
CHECK_THROWS(swapaxes(a, 3, 0));
CHECK_THROWS(swapaxes(a, -4, 0));
CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector<int>{4, 3, 2});
CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
// Test moveaxis
a = array(0.0);
CHECK_THROWS(moveaxis(a, 0, 0));
a = zeros({2});
CHECK_THROWS(moveaxis(a, 0, 1));
CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector<int>{2});
CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector<int>{2});
a = zeros({2, 3, 4});
CHECK_THROWS(moveaxis(a, 0, -4));
CHECK_THROWS(moveaxis(a, 0, 3));
CHECK_THROWS(moveaxis(a, 3, 0));
CHECK_THROWS(moveaxis(a, -4, 0));
CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector<int>{3, 2, 4});
CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector<int>{3, 4, 2});
CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector<int>{2, 4, 3});
}
TEST_CASE("test transpose") {
array x(1);
auto y = transpose(x);
@@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") {
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
{4, 3});
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
}
}