mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
@@ -2435,6 +2436,49 @@ TEST_CASE("test scatter") {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test masked_scatter") {
|
||||
if (cu::is_available()) {
|
||||
INFO("Skipping masked_scatter cuda ops tests");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wrong mask dtype
|
||||
CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2})));
|
||||
|
||||
// Mask must be broadcastable to self array
|
||||
CHECK_THROWS(masked_scatter(
|
||||
array({1, 2, 3, 4}, {2, 2}),
|
||||
array({false, true, true, false}, {4, 1}),
|
||||
array({1, 2})));
|
||||
|
||||
// 1D mask
|
||||
{
|
||||
auto self = zeros({4}, int32);
|
||||
auto mask = array({true, true, false, true});
|
||||
auto source = array({1, 2, 4});
|
||||
auto out = masked_scatter(self, mask, source);
|
||||
CHECK(array_equal(out, array({1, 2, 0, 4})).item<bool>());
|
||||
}
|
||||
|
||||
// Empty mask
|
||||
{
|
||||
auto self = zeros({4}, int32);
|
||||
auto mask = array({false, false, false, false});
|
||||
auto source = array({1, 2, 4});
|
||||
auto out = masked_scatter(self, mask, source);
|
||||
CHECK(array_equal(out, self).item<bool>());
|
||||
}
|
||||
|
||||
// Broadcasted mask
|
||||
{
|
||||
auto self = zeros({2, 2}, int32);
|
||||
auto mask = array({true, false});
|
||||
auto source = array({5, 6, 7, 8}, {2, 2});
|
||||
auto out = masked_scatter(self, mask, source);
|
||||
CHECK(array_equal(out, array({5, 6, 0, 0}, {2, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test is positive infinity") {
|
||||
array x(1.0f);
|
||||
CHECK_FALSE(isposinf(x).item<bool>());
|
||||
|
||||
Reference in New Issue
Block a user