fix scatter (#821)

This commit is contained in:
Awni Hannun
2024-03-12 11:42:07 -07:00
committed by GitHub
parent 366478c560
commit 8b7532b9ab
3 changed files with 12 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cmath>
#include <numeric>
@@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") {
inds = array({0, 1});
out = scatter_add(in, inds, updates, 0);
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
// 1D scatter
{
auto dst = zeros({2, 4}, int32);
auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
auto idx = array({1});
auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
auto out = scatter(dst, idx, src, 0);
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test is positive infinity") {