Split multi output (#461)

* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
Angelos Katharopoulos
2024-01-16 13:33:55 -08:00
committed by GitHub
parent 4e290d282f
commit d8fabaa12b
12 changed files with 202 additions and 5 deletions

View File

@@ -922,6 +922,35 @@ TEST_CASE("test concatenate grads") {
array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());
}
TEST_CASE("test split grads") {
array x = arange(6, float32);
eval(x);
{
auto fn = [](const array& x) {
auto parts = split(x, 3);
return parts[0] * parts[1] + parts[2];
};
auto out = vjp(fn, {x}, {ones({2})}).second;
CHECK_EQ(out.size(), 6);
CHECK(array_equal(out, array({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f}))
.item<bool>());
}
{
auto fn = [](const array& x) {
auto parts = split(x, 3);
return parts[0] * parts[2];
};
auto out = vjp(fn, {x}, {ones({2})}).second;
CHECK_EQ(out.size(), 6);
CHECK(array_equal(out, array({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f}))
.item<bool>());
}
}
TEST_CASE("test comparison grads") {
auto x = ones({3, 1});
auto y = zeros({1, 3});