Add Masked Scatter (#2663)

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
CCYeh
2025-11-19 23:53:32 +01:00
committed by GitHub
parent 7f4b7e553c
commit b3825ac149
26 changed files with 1099 additions and 51 deletions

View File

@@ -13,6 +13,8 @@
#include "mlx/graph_utils.h"
#include "mlx/mlx.h"
#include "mlx/backend/cuda/cuda.h"
using namespace mlx::core;
TEST_CASE("test stop gradient") {
@@ -1353,3 +1355,45 @@ TEST_CASE("test grad dynamic slices") {
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
}
}
TEST_CASE("test masked_scatter autograd") {
if (cu::is_available()) {
INFO("Skipping masked_scatter cuda autograd tests");
return;
}
// Test jvp
{
auto self = array({10.f, 20.f, 30.f, 40.f}, {4});
auto mask = array({false, true, false, true}, bool_);
auto src = array({7.f, 8.f}, {2});
auto self_tan = array({1.f, 2.f, 3.f, 4.f}, {4});
auto src_tan = array({9.f, 11.f}, {2});
auto fun = [&mask](const std::vector<array>& in) {
return std::vector<array>{masked_scatter(in[0], mask, in[1])};
};
auto outs = jvp(fun, {self, src}, {self_tan, src_tan}).second;
CHECK_EQ(outs.size(), 1);
CHECK(array_equal(outs[0], array({1.f, 9.f, 3.f, 11.f}, {4})).item<bool>());
}
// Test vjp
{
auto self = array({10.f, 20.f, 30.f, 40.f}, {4});
auto mask = array({true, false, false, true}, bool_);
auto src = array({7.f, 8.f}, {2});
auto f_sum = [&mask](const std::vector<array>& xs) {
return std::vector<array>{sum(masked_scatter(xs[0], mask, xs[1]))};
};
auto v = vjp(f_sum, {self, src}, {array(1.f)});
const auto& grads = v.second;
CHECK(array_equal(grads[0], array({0.f, 1.f, 1.f, 0.f}, {4})).item<bool>());
CHECK(array_equal(grads[1], array({1.f, 1.f}, {2})).item<bool>());
}
}