Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
CHECK_EQ(dout[0].item<float>(), 4.0f);
}
// Evaling in function without graph retention throws
// Evaling in function while tracing performs graph retention
{
auto fun = [](const array& x) {
auto y = 3 * x;
eval(y);
return 2 * y;
};
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
// Ok with graph retention
auto fun1 = [](const array& x) {
auto y = 3 * x;
eval({y}, true);
eval(y);
CHECK(y.is_evaled());
CHECK(y.has_primitive());
CHECK(y.is_tracer());
return 2 * y;
};
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
@@ -251,29 +246,27 @@ TEST_CASE("test grad") {
}
{
// Evaluating in the middle of the grad function throws
// as it breaks the graph
auto fn = [](array x) {
x = x + 2.0f;
eval(x);
return square(x);
};
CHECK_THROWS(grad(fn)(array(1.0f)));
// Ok since the output is independent of y
// No graph retention since the output is independent of y
auto y = ones({3, 3});
auto fn1 = [y](array x) {
x = x + 2.0f;
eval(y);
CHECK(x.is_tracer());
CHECK(!y.is_tracer());
CHECK(y.is_evaled());
CHECK(!y.has_primitive());
return square(x);
};
auto dfdx = grad(fn1)(array(1.0f));
CHECK_EQ(dfdx.item<float>(), 6.0f);
// Retain the graph to avoid breaking it
// Graph automatically retained to compute the grad
auto fn2 = [](array x) {
x = x + 2.0f;
eval({x}, true);
eval(x);
CHECK(x.is_tracer());
CHECK(x.is_evaled());
CHECK(x.has_primitive());
return square(x);
};
dfdx = grad(fn2)(array(1.0f));
@@ -283,7 +276,8 @@ TEST_CASE("test grad") {
// Control flow in grad computation
{
auto fn = [](array x) {
if (x.item<float>(true) > 1) {
x = x + array(2.0f);
if (x.item<float>() > 3) {
return square(x);
} else {
return 4 * x;
@@ -294,7 +288,7 @@ TEST_CASE("test grad") {
CHECK_EQ(dfdx.item<float>(), 4.0f);
dfdx = grad(fn)(array(1.5f));
CHECK_EQ(dfdx.item<float>(), 3.0f);
CHECK_EQ(dfdx.item<float>(), 7.0f);
}
// Grad with multiple inputs
@@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test update state") {
auto y = array({1.0});
auto x = array({1.0, 1.0});
auto state = array({0.0, 0.0});
auto fn = [&state, &x](array y) {
x = y * x;
state = state + x;
return sum(x);
};
grad(fn)(y);
eval(state);
CHECK(!state.has_primitive());
CHECK(state.is_evaled());
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}

View File

@@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
}
TEST_CASE("test eval with tracer") {
TEST_CASE("test eval with tracer when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1);
x.set_tracer(true);
// Ok, x is not a node
CHECK(!x.is_tracer());
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
x = ones({2, 3});
x.set_tracer(true);
CHECK_THROWS(eval(x));
// Ok retain_graph=true
eval({x}, true);
// Make sure all arguments are checked
auto y = ones({2, 3});
CHECK_THROWS(eval(x, y));
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
}
TEST_CASE("test eval graph retention") {
TEST_CASE("test eval graph retention when not tracing") {
// Since we are not tracing it doesn't matter that the array flags are
// tracers they will always be detached.
auto x = array(1);
x.set_tracer(true);
auto y = array(2);
auto z = x + y;
eval({z}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(true), 3);
CHECK(z.has_primitive());
eval(z);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(), 3);
z.set_tracer(false);
CHECK_EQ(z.item<int>(), 3);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
@@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
z = x + y;
auto a = z + x;
auto b = a + y;
eval({b}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK(a.has_primitive());
CHECK(a.is_evaled());
eval({b}, false);
eval(b);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(!a.has_primitive());

View File

@@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") {
auto fun2 = [](std::vector<array> inputs) {
auto x = inputs[0] + 1;
auto y = inputs[1] + 2;
eval({x}, true);
eval(x);
auto out = add(x, y);
return std::vector<array>{out};
};