mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:11:44 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") {
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap dynamic slices") {
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};
|
||||
};
|
||||
auto x = reshape(arange(12), {3, 4});
|
||||
auto out = vmap(fun)({x})[0];
|
||||
CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());
|
||||
|
||||
out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];
|
||||
CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{
|
||||
slice_update(inputs[0], inputs[1], array({1}), {0})};
|
||||
};
|
||||
auto x = zeros({2, 2});
|
||||
auto upd = ones({2, 1});
|
||||
|
||||
auto out = vmap(fun)({x, upd})[0];
|
||||
CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());
|
||||
|
||||
out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];
|
||||
CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user