mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
Dynamic slicing (#1741)
* dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit
This commit is contained in:
@@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test grad dynamic slices") {
|
||||
{
|
||||
auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };
|
||||
auto x = array({1, 2, 3, 4}, {2, 2});
|
||||
auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;
|
||||
CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());
|
||||
}
|
||||
{
|
||||
auto fn = [](const std::vector<array>& inputs) {
|
||||
const auto& x = inputs[0];
|
||||
const auto& update = inputs[1];
|
||||
return std::vector<array>{slice_update(x, update, array({0}), {0})};
|
||||
};
|
||||
auto x = zeros({2, 2});
|
||||
auto update = array({3.f, 4.f}, {1, 2});
|
||||
auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;
|
||||
CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());
|
||||
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +250,7 @@ TEST_CASE("test QR factorization") {
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
|
||||
|
||||
array A = array({{2., 3., 1., 2.}, {2, 2}});
|
||||
array A = array({2., 3., 1., 2.}, {2, 2});
|
||||
auto [Q, R] = linalg::qr(A, Device::cpu);
|
||||
auto out = matmul(Q, R);
|
||||
CHECK(allclose(out, A).item<bool>());
|
||||
|
||||
@@ -353,6 +353,50 @@ TEST_CASE("test slice update") {
|
||||
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test dynamic slice") {
|
||||
auto src = reshape(arange(6), {2, 3});
|
||||
CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));
|
||||
CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));
|
||||
|
||||
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));
|
||||
CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));
|
||||
|
||||
auto out = slice(src, array({1}), {0}, {1, 2});
|
||||
auto expected = array({3, 4}, {1, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
out = slice(src, array({1, 1}), {0, 1}, {1, 2});
|
||||
expected = array({4, 5}, {1, 2});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test dynamic slice update") {
|
||||
auto src = zeros({2, 3}, int32);
|
||||
auto upd = ones({1, 2}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {3}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));
|
||||
|
||||
upd = ones({4}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
|
||||
upd = ones({1, 4}, int32);
|
||||
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
|
||||
CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));
|
||||
|
||||
upd = ones({1, 2}, int32);
|
||||
auto out = slice_update(src, upd, array({1}), {0});
|
||||
auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
|
||||
upd = ones({1, 2}, int32);
|
||||
out = slice_update(src, upd, array({1, 1}), {0, 1});
|
||||
expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test split") {
|
||||
array x = array(1);
|
||||
CHECK_THROWS(split(x, 0));
|
||||
@@ -720,7 +764,7 @@ TEST_CASE("test is inf") {
|
||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||
|
||||
array w = array({1.0f, inf, 2.0f});
|
||||
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
|
||||
CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
|
||||
|
||||
array a(1.0f, bfloat16);
|
||||
CHECK_FALSE(isinf(a).item<bool>());
|
||||
|
||||
@@ -686,7 +686,7 @@ TEST_CASE("test laplace") {
|
||||
CHECK(std::abs(sample_variance - expected_variance) < 0.01);
|
||||
|
||||
// Expected kurtosis of Laplace distribution is 3.
|
||||
array fourth_pows = power(out - sample_mean, {4});
|
||||
array fourth_pows = power(out - sample_mean, array(4));
|
||||
float sample_kurtosis =
|
||||
mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;
|
||||
float expected_kurtosis = 3.0;
|
||||
|
||||
@@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") {
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap dynamic slices") {
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};
|
||||
};
|
||||
auto x = reshape(arange(12), {3, 4});
|
||||
auto out = vmap(fun)({x})[0];
|
||||
CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());
|
||||
|
||||
out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];
|
||||
CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return std::vector<array>{
|
||||
slice_update(inputs[0], inputs[1], array({1}), {0})};
|
||||
};
|
||||
auto x = zeros({2, 2});
|
||||
auto upd = ones({2, 1});
|
||||
|
||||
auto out = vmap(fun)({x, upd})[0];
|
||||
CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());
|
||||
|
||||
out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];
|
||||
CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user