fix scatter + test (#1202)

* fix scatter + test

* fix test warnings

* fix metal validation
This commit is contained in:
Awni Hannun
2024-06-11 14:35:12 -07:00
committed by GitHub
parent 709ccc6800
commit df964132fb
5 changed files with 54 additions and 9 deletions

View File

@@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") {
auto out = scatter(dst, idx, src, 0);
CHECK(array_equal(out, expected).item<bool>());
}
// 1D indices with 2D update
{
auto dst = zeros({3, 4}, int32);
auto indices = {array({1}), array({2})};
auto axes = {0, 1};
auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
auto out = scatter(dst, indices, updates, axes);
auto expected =
reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test is positive infinity") {