mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
TEST_CASE("test stop gradient") {
|
||||
@@ -1353,3 +1355,45 @@ TEST_CASE("test grad dynamic slices") {
|
||||
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test masked_scatter autograd") {
|
||||
if (cu::is_available()) {
|
||||
INFO("Skipping masked_scatter cuda autograd tests");
|
||||
return;
|
||||
}
|
||||
|
||||
// Test jvp
|
||||
{
|
||||
auto self = array({10.f, 20.f, 30.f, 40.f}, {4});
|
||||
auto mask = array({false, true, false, true}, bool_);
|
||||
auto src = array({7.f, 8.f}, {2});
|
||||
|
||||
auto self_tan = array({1.f, 2.f, 3.f, 4.f}, {4});
|
||||
auto src_tan = array({9.f, 11.f}, {2});
|
||||
|
||||
auto fun = [&mask](const std::vector<array>& in) {
|
||||
return std::vector<array>{masked_scatter(in[0], mask, in[1])};
|
||||
};
|
||||
|
||||
auto outs = jvp(fun, {self, src}, {self_tan, src_tan}).second;
|
||||
CHECK_EQ(outs.size(), 1);
|
||||
CHECK(array_equal(outs[0], array({1.f, 9.f, 3.f, 11.f}, {4})).item<bool>());
|
||||
}
|
||||
|
||||
// Test vjp
|
||||
{
|
||||
auto self = array({10.f, 20.f, 30.f, 40.f}, {4});
|
||||
auto mask = array({true, false, false, true}, bool_);
|
||||
auto src = array({7.f, 8.f}, {2});
|
||||
|
||||
auto f_sum = [&mask](const std::vector<array>& xs) {
|
||||
return std::vector<array>{sum(masked_scatter(xs[0], mask, xs[1]))};
|
||||
};
|
||||
|
||||
auto v = vjp(f_sum, {self, src}, {array(1.f)});
|
||||
const auto& grads = v.second;
|
||||
|
||||
CHECK(array_equal(grads[0], array({0.f, 1.f, 1.f, 0.f}, {4})).item<bool>());
|
||||
CHECK(array_equal(grads[1], array({1.f, 1.f}, {2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user