Removes the retain_graph flag (#385)

* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
Angelos Katharopoulos
2024-01-07 15:16:51 -08:00
committed by GitHub
parent 449b43762e
commit a611b0bc82
22 changed files with 209 additions and 207 deletions

View File

@@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
CHECK_EQ(dout[0].item<float>(), 4.0f);
}
// Evaling in function without graph retention throws
// Evaling in function while tracing performs graph retention
{
auto fun = [](const array& x) {
auto y = 3 * x;
eval(y);
return 2 * y;
};
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
// Ok with graph retention
auto fun1 = [](const array& x) {
auto y = 3 * x;
eval({y}, true);
eval(y);
CHECK(y.is_evaled());
CHECK(y.has_primitive());
CHECK(y.is_tracer());
return 2 * y;
};
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
@@ -251,29 +246,27 @@ TEST_CASE("test grad") {
}
{
// Evaluating in the middle of the grad function throws
// as it breaks the graph
auto fn = [](array x) {
x = x + 2.0f;
eval(x);
return square(x);
};
CHECK_THROWS(grad(fn)(array(1.0f)));
// Ok since the output is independent of y
// No graph retention since the output is independent of y
auto y = ones({3, 3});
auto fn1 = [y](array x) {
x = x + 2.0f;
eval(y);
CHECK(x.is_tracer());
CHECK(!y.is_tracer());
CHECK(y.is_evaled());
CHECK(!y.has_primitive());
return square(x);
};
auto dfdx = grad(fn1)(array(1.0f));
CHECK_EQ(dfdx.item<float>(), 6.0f);
// Retain the graph to avoid breaking it
// Graph automatically retained to compute the grad
auto fn2 = [](array x) {
x = x + 2.0f;
eval({x}, true);
eval(x);
CHECK(x.is_tracer());
CHECK(x.is_evaled());
CHECK(x.has_primitive());
return square(x);
};
dfdx = grad(fn2)(array(1.0f));
@@ -283,7 +276,8 @@ TEST_CASE("test grad") {
// Control flow in grad computation
{
auto fn = [](array x) {
if (x.item<float>(true) > 1) {
x = x + array(2.0f);
if (x.item<float>() > 3) {
return square(x);
} else {
return 4 * x;
@@ -294,7 +288,7 @@ TEST_CASE("test grad") {
CHECK_EQ(dfdx.item<float>(), 4.0f);
dfdx = grad(fn)(array(1.5f));
CHECK_EQ(dfdx.item<float>(), 3.0f);
CHECK_EQ(dfdx.item<float>(), 7.0f);
}
// Grad with multiple inputs
@@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
CHECK(array_equal(out, expected).item<bool>());
}
}
TEST_CASE("test update state") {
auto y = array({1.0});
auto x = array({1.0, 1.0});
auto state = array({0.0, 0.0});
auto fn = [&state, &x](array y) {
x = y * x;
state = state + x;
return sum(x);
};
grad(fn)(y);
eval(state);
CHECK(!state.has_primitive());
CHECK(state.is_evaled());
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}