mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Split multi output (#461)
* Multi-output split primitive * Add the multi-output split to the ArrayIterator * Add some grad tests for split
This commit is contained in:

committed by
GitHub

parent
4e290d282f
commit
d8fabaa12b
@@ -922,6 +922,35 @@ TEST_CASE("test concatenate grads") {
|
||||
array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test split grads") {
|
||||
array x = arange(6, float32);
|
||||
eval(x);
|
||||
|
||||
{
|
||||
auto fn = [](const array& x) {
|
||||
auto parts = split(x, 3);
|
||||
return parts[0] * parts[1] + parts[2];
|
||||
};
|
||||
auto out = vjp(fn, {x}, {ones({2})}).second;
|
||||
|
||||
CHECK_EQ(out.size(), 6);
|
||||
CHECK(array_equal(out, array({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto fn = [](const array& x) {
|
||||
auto parts = split(x, 3);
|
||||
return parts[0] * parts[2];
|
||||
};
|
||||
auto out = vjp(fn, {x}, {ones({2})}).second;
|
||||
|
||||
CHECK_EQ(out.size(), 6);
|
||||
CHECK(array_equal(out, array({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f}))
|
||||
.item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test comparison grads") {
|
||||
auto x = ones({3, 1});
|
||||
auto y = zeros({1, 3});
|
||||
|
Reference in New Issue
Block a user