Multi output primitives (#330)

* Multi-output primitives

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-08 16:39:08 -08:00
committed by GitHub
parent f45f70f133
commit f099ebe535
26 changed files with 2313 additions and 1039 deletions

View File

@@ -2397,3 +2397,58 @@ TEST_CASE("inner") {
expected = array({7., 0., 0., 7.}, {2, 2});
CHECK(array_equal(z, expected).item<bool>());
}
TEST_CASE("test divmod") {
auto x = array({1, 2, 3});
auto y = array({1, 1, 1});
auto out = divmod(x, y);
CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());
x = array({5, 6, 7});
y = array({2, 2, 2});
out = divmod(x, y);
CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());
// Siblings should be gone after evaling the graph
CHECK(out[0].siblings().empty());
CHECK(out[1].siblings().empty());
x = array({5.0, 6.0, 7.0});
y = array({2.0, 2.0, 2.0});
out = divmod(x, y);
CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());
x = array({1.0}, complex64);
y = array({2.0}, complex64);
CHECK_THROWS(divmod(x, y));
// Check that we can eval on both outputs
x = array({1.0});
y = array({2.0});
out = divmod(x, y);
eval(out);
CHECK_EQ(out[0].item<float>(), 0.0);
CHECK_EQ(out[1].item<float>(), 1.0);
// Check nested in the graph
x = array({1.0});
y = array({2.0});
out = divmod(x, y);
auto z = out[0] + out[1];
CHECK_EQ(z.item<float>(), 1.0);
// Check that we can still eval when one output goes out of scope
std::vector<array> out_holder;
{ out_holder.push_back(divmod(x, y)[0]); }
eval(out_holder);
CHECK_EQ(out_holder[0].item<float>(), 0.0);
// Check that we can still eval when the other output goes out of scope
out_holder.clear();
{ out_holder.push_back(divmod(x, y)[1]); }
eval(out_holder);
CHECK_EQ(out_holder[0].item<float>(), 1.0);
}