mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -353,6 +353,50 @@ TEST_CASE("test slice update") {
|
||||
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test dynamic slice") {
|
||||
auto src = reshape(arange(6), {2, 3});
|
||||
CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));
|
||||
|
||||
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));
|
||||
CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));
|
||||
|
||||
auto out = slice(src, array({1}), {0}, {1, 2});
|
||||
auto expected = array({3, 4}, {1, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
out = slice(src, array({1, 1}), {0, 1}, {1, 2});
|
||||
expected = array({4, 5}, {1, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test dynamic slice update") {
|
||||
auto src = zeros({2, 3}, int32);
|
||||
auto upd = ones({1, 2}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {3}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));
|
||||
|
||||
upd = ones({4}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
|
||||
upd = ones({1, 4}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));
|
||||
|
||||
upd = ones({1, 2}, int32);
|
||||
auto out = slice_update(src, upd, array({1}), {0});
|
||||
auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
upd = ones({1, 2}, int32);
|
||||
out = slice_update(src, upd, array({1, 1}), {0, 1});
|
||||
expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test split") {
|
||||
array x = array(1);
|
||||
CHECK_THROWS(split(x, 0));
|
||||
@@ -720,7 +764,7 @@ TEST_CASE("test is inf") {
|
||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||
|
||||
array w = array({1.0f, inf, 2.0f});
|
||||
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
|
||||
CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
|
||||
|
||||
array a(1.0f, bfloat16);
|
||||
CHECK_FALSE(isinf(a).item<bool>());
|
||||
|
Reference in New Issue
Block a user