simplify tests use compile now

This commit is contained in:
Awni Hannun
2024-01-16 22:06:19 -08:00
parent 1c3f82ca17
commit ed4d867092
5 changed files with 75 additions and 52 deletions

View File

@@ -105,60 +105,98 @@ TEST_CASE("test enable and disable compile") {
CHECK_THROWS(compile(nullptr));
}
auto add_scalars(const std::vector<array>&) {
auto a = array(-1.0f);
auto b = array(-1.0f);
return std::vector<array>{abs(a), abs(b)};
};
auto max_scalars(const std::vector<array>&) {
auto a = array({-1.0f, 2.0f});
auto b = maximum(a, array(0.0f));
auto c = maximum(-a, array(0.0f));
auto d = b + c;
return std::vector<array>{b, c, d};
};
TEST_CASE("test simplify scalars") {
{
auto a = array(-1.0f);
auto b = array(-1.0f);
auto c = abs(a);
auto d = abs(b);
simplify({c, d});
auto cfun = compile(add_scalars);
auto out = cfun({});
auto c = out[0];
auto d = out[1];
CHECK(c.inputs()[0].id() == d.inputs()[0].id());
}
{
auto a = array({-1.0f, 2.0f});
auto b = maximum(a, array(0.0f));
auto c = maximum(-a, array(0.0f));
auto d = b + c;
simplify({d});
auto out = compile(max_scalars)({a});
auto b = out[0];
auto c = out[1];
auto d = out[2];
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
}
}
// TODO rework these tests for compile
/*TEST_CASE("test simplify") {
auto exp_two(const std::vector<array>& inputs) {
auto a = inputs[0];
return std::vector<array>{exp(a) + exp(a)};
};
TEST_CASE("test simplify") {
auto a = array({1.0f, 2.0f});
auto b = exp(a) + exp(a);
simplify(b);
auto b = compile(exp_two)({a})[0];
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
}
auto add_diff(const std::vector<array>& inputs) {
auto a = inputs[0];
return std::vector<array>{cos(a) + sin(a)};
};
TEST_CASE("test no simplify") {
auto a = array({1.0f, 2.0f});
auto b = cos(a) + sin(a);
simplify(b);
auto b = compile(add_diff)({a})[0];
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
}
auto multi_one(const std::vector<array>&) {
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = c[0] + d[0];
auto f = c[1] + d[1];
return std::vector<array>{e, f};
}
auto multi_two(const std::vector<array>&) {
auto a = array(1.0);
auto b = array(1.0);
auto c = divmod(a, b);
return std::vector<array>{c};
}
auto multi_three(const std::vector<array>&) {
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = stack({c[0], c[1], d[0], d[1]});
return std::vector<array>{e};
}
TEST_CASE("test simplify multi output") {
{
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = c[0] + d[0];
auto f = c[1] + d[1];
simplify({e, f});
auto out = compile(multi_one)({});
auto e = out[0];
auto f = out[1];
CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id());
CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id());
}
{
auto a = array(1.0);
auto b = array(1.0);
auto c = divmod(a, b);
simplify(c);
auto c = compile(multi_two)({});
CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id());
CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id());
CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id());
@@ -167,14 +205,9 @@ TEST_CASE("test simplify multi output") {
// Make sure the output order of multi-output primitives
// is respected in simplification
{
auto a = array(1.0);
auto b = array(2.0);
auto c = divmod(a, b);
auto d = divmod(a, b);
auto e = stack({c[0], c[1], d[0], d[1]});
simplify(e);
auto e = compile(multi_three)({})[0];
CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item<bool>());
CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id());
CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id());
}
}*/
}