Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -1983,6 +1983,12 @@ TEST_CASE("test take") {
CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());
out = take(a, array({0, 1}), 1);
CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());
// Indices have wrong shape
a = zeros({2, 3, 4});
CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
}
TEST_CASE("test take along axis") {
@@ -2001,12 +2007,6 @@ TEST_CASE("test take along axis") {
out = take_along_axis(a, array({1}), -1);
CHECK_EQ(out.item<int>(), 1);
// Indices have wrong shape
a = zeros({2, 3, 4});
CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
// Empty arrays
a = reshape(array({}), {1, 0});
CHECK_THROWS(take_along_axis(a, array({1}), 0));
@@ -2057,6 +2057,48 @@ TEST_CASE("test take along axis") {
.item<bool>());
}
TEST_CASE("test put along axis") {
// No zero dim arrays
auto a = array(1);
auto v = array(1);
CHECK_THROWS(put_along_axis(a, array(0), v, 0));
// Index and array size mismatches
a = arange(5);
CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 1));
CHECK_THROWS(put_along_axis(a, array({1}, {1, 1}), array({0}), 0));
CHECK_THROWS(put_along_axis(a, array(1), array(0), -1));
auto expected = array({0, 0, 2, 3, 4});
auto out = put_along_axis(a, array({1}), array({0}), 0);
CHECK(array_equal(out, expected).item<bool>());
// Empty arrays
a = reshape(array({}), {1, 0});
CHECK_THROWS(put_along_axis(a, array({1}), array({0}), 0));
auto inds = reshape(astype(array({}), int32), {1, 0});
out = take_along_axis(a, inds, 0);
eval(out); // Make sure it runs
CHECK_EQ(out.shape(), std::vector<int>{1, 0});
a = array({1, 2, 3, 4}, {2, 2});
inds = array({0, 1}, {1, 2});
out = put_along_axis(a, inds, array({0}), 0);
expected = array({0, 2, 3, 0}, {2, 2});
CHECK(array_equal(out, expected).item<bool>());
inds = array({0, 0, 1, 1}, {2, 2}, int32);
auto values = array({2, 3, 4, 5}, {2, 2}, int32);
out = put_along_axis(a, inds, values, 0);
CHECK(array_equal(out, array({2, 3, 4, 5}, {2, 2})).item<bool>());
inds = array({0, 1}, {2, 1});
out = put_along_axis(a, inds, array({0}), 1);
expected = array({0, 2, 3, 0}, {2, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test scatter") {
// More indices than dimensions
CHECK_THROWS(scatter(array(0), array({1}), array(1), 0));