Dynamic slicing (#1741)

* dynamic slice and slice update

* python bindings + tests + fix set item

* fix compile issue

* comment

* fix jit
This commit is contained in:
Awni Hannun
2025-01-07 14:02:16 -08:00
committed by GitHub
parent c9c81d0584
commit 516ded618b
27 changed files with 941 additions and 75 deletions

View File

@@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") {
}
}
}
TEST_CASE("test grad dynamic slices") {
{
auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); };
auto x = array({1, 2, 3, 4}, {2, 2});
auto out = vjp(fn, x, array({1, 1}, {1, 2})).second;
CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item<bool>());
}
{
auto fn = [](const std::vector<array>& inputs) {
const auto& x = inputs[0];
const auto& update = inputs[1];
return std::vector<array>{slice_update(x, update, array({0}), {0})};
};
auto x = zeros({2, 2});
auto update = array({3.f, 4.f}, {1, 2});
auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second;
CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item<bool>());
CHECK(allclose(outs[1], ones({1, 2})).item<bool>());
}
}

View File

@@ -250,7 +250,7 @@ TEST_CASE("test QR factorization") {
// Unsupported types throw
CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
array A = array({{2., 3., 1., 2.}, {2, 2}});
array A = array({2., 3., 1., 2.}, {2, 2});
auto [Q, R] = linalg::qr(A, Device::cpu);
auto out = matmul(Q, R);
CHECK(allclose(out, A).item<bool>());

View File

@@ -353,6 +353,50 @@ TEST_CASE("test slice update") {
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
}
TEST_CASE("test dynamic slice") {
auto src = reshape(arange(6), {2, 3});
CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {3}, {1, 1}));
CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1}));
CHECK_THROWS(slice(src, array({1}), {0}, {2, 4}));
CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1}));
auto out = slice(src, array({1}), {0}, {1, 2});
auto expected = array({3, 4}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
out = slice(src, array({1, 1}), {0, 1}, {1, 2});
expected = array({4, 5}, {1, 2});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test dynamic slice update") {
auto src = zeros({2, 3}, int32);
auto upd = ones({1, 2}, int32);
CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1}), {3}));
CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0}));
upd = ones({4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
upd = ones({1, 4}, int32);
CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0}));
upd = ones({1, 2}, int32);
auto out = slice_update(src, upd, array({1}), {0});
auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
upd = ones({1, 2}, int32);
out = slice_update(src, upd, array({1, 1}), {0, 1});
expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
CHECK(array_equal(out, expected).item<bool>());
}
TEST_CASE("test split") {
array x = array(1);
CHECK_THROWS(split(x, 0));
@@ -720,7 +764,7 @@ TEST_CASE("test is inf") {
CHECK_FALSE(any(isinf(z)).item<bool>());
array w = array({1.0f, inf, 2.0f});
CHECK(array_equal({false, true, false}, isinf(w)).item<bool>());
CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
array a(1.0f, bfloat16);
CHECK_FALSE(isinf(a).item<bool>());

View File

@@ -686,7 +686,7 @@ TEST_CASE("test laplace") {
CHECK(std::abs(sample_variance - expected_variance) < 0.01);
// Expected kurtosis of Laplace distribution is 3.
array fourth_pows = power(out - sample_mean, {4});
array fourth_pows = power(out - sample_mean, array(4));
float sample_kurtosis =
mean(fourth_pows).item<float>() / std::pow(sample_variance, 2) - 3;
float expected_kurtosis = 3.0;

View File

@@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") {
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
}
}
TEST_CASE("test vmap dynamic slices") {
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{slice(inputs[0], array({1}), {0}, {2})};
};
auto x = reshape(arange(12), {3, 4});
auto out = vmap(fun)({x})[0];
CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0];
CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4}))
.item<bool>());
}
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{
slice_update(inputs[0], inputs[1], array({1}), {0})};
};
auto x = zeros({2, 2});
auto upd = ones({2, 1});
auto out = vmap(fun)({x, upd})[0];
CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item<bool>());
out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0];
CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item<bool>());
}
}