mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
fix scatter + test (#1202)
* fix scatter + test * fix test warnings * fix metal validation
This commit is contained in:
@@ -2062,6 +2062,18 @@ TEST_CASE("test scatter") {
|
||||
auto out = scatter(dst, idx, src, 0);
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
// 1D indices with 2D update
|
||||
{
|
||||
auto dst = zeros({3, 4}, int32);
|
||||
auto indices = {array({1}), array({2})};
|
||||
auto axes = {0, 1};
|
||||
auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
|
||||
auto out = scatter(dst, indices, updates, axes);
|
||||
auto expected =
|
||||
reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test is positive infinity") {
|
||||
|
Reference in New Issue
Block a user