mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
fix scatter (#821)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1919,6 +1919,16 @@ TEST_CASE("test scatter") {
|
||||
inds = array({0, 1});
|
||||
out = scatter_add(in, inds, updates, 0);
|
||||
CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
|
||||
|
||||
// 1D scatter
|
||||
{
|
||||
auto dst = zeros({2, 4}, int32);
|
||||
auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
|
||||
auto idx = array({1});
|
||||
auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
|
||||
auto out = scatter(dst, idx, src, 0);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test is positive infinity") {
|
||||
|
Reference in New Issue
Block a user