Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
}
TEST_CASE("test eval with tracer") {
TEST_CASE("test eval with tracer when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1);
x.set_tracer(true);
// Ok, x is not a node
CHECK(!x.is_tracer());
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
x = ones({2, 3});
x.set_tracer(true);
CHECK_THROWS(eval(x));
// Ok retain_graph=true
eval({x}, true);
// Make sure all arguments are checked
auto y = ones({2, 3});
CHECK_THROWS(eval(x, y));
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
}
TEST_CASE("test eval graph retention") {
TEST_CASE("test eval graph retention when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1);
x.set_tracer(true);
auto y = array(2);
auto z = x + y;
eval({z}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(true), 3);
CHECK(z.has_primitive());
eval(z);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(), 3);
z.set_tracer(false);
CHECK_EQ(z.item<int>(), 3);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
@@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
z = x + y;
auto a = z + x;
auto b = a + y;
eval({b}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK(a.has_primitive());
CHECK(a.is_evaled());
eval({b}, false);
eval(b);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(!a.has_primitive());