mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test grad dynamic slices") {
|
||||
{
|
||||
auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };
|
||||
auto x = array({1, 2, 3, 4}, {2, 2});
|
||||
auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;
|
||||
CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());
|
||||
}
|
||||
{
|
||||
auto fn = [](const std::vector<array>& inputs) {
|
||||
const auto& x = inputs[0];
|
||||
const auto& update = inputs[1];
|
||||
return std::vector<array>{slice_update(x, update, array({0}), {0})};
|
||||
};
|
||||
auto x = zeros({2, 2});
|
||||
auto update = array({3.f, 4.f}, {1, 2});
|
||||
auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;
|
||||
CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());
|
||||
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user