Add Masked Scatter (#2663)

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
CCYeh
2025-11-19 23:53:32 +01:00
committed by GitHub
parent 7f4b7e553c
commit b3825ac149
26 changed files with 1099 additions and 51 deletions

View File

@@ -7,6 +7,7 @@
#include "doctest/doctest.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/mlx.h"
using namespace mlx::core;
@@ -2435,6 +2436,49 @@ TEST_CASE("test scatter") {
}
}
TEST_CASE("test masked_scatter") {
if (cu::is_available()) {
INFO("Skipping masked_scatter cuda ops tests");
return;
}
// Wrong mask dtype
CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2})));
// Mask must be broadcastable to self array
CHECK_THROWS(masked_scatter(
array({1, 2, 3, 4}, {2, 2}),
array({false, true, true, false}, {4, 1}),
array({1, 2})));
// 1D mask
{
auto self = zeros({4}, int32);
auto mask = array({true, true, false, true});
auto source = array({1, 2, 4});
auto out = masked_scatter(self, mask, source);
CHECK(array_equal(out, array({1, 2, 0, 4})).item<bool>());
}
// Empty mask
{
auto self = zeros({4}, int32);
auto mask = array({false, false, false, false});
auto source = array({1, 2, 4});
auto out = masked_scatter(self, mask, source);
CHECK(array_equal(out, self).item<bool>());
}
// Broadcasted mask
{
auto self = zeros({2, 2}, int32);
auto mask = array({true, false});
auto source = array({5, 6, 7, 8}, {2, 2});
auto out = masked_scatter(self, mask, source);
CHECK(array_equal(out, array({5, 6, 0, 0}, {2, 2})).item<bool>());
}
}
TEST_CASE("test is positive infinity") {
array x(1.0f);
CHECK_FALSE(isposinf(x).item<bool>());