mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:
committed by
GitHub
parent
449b43762e
commit
a611b0bc82
@@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
|
||||
CHECK_EQ(dout[0].item<float>(), 4.0f);
|
||||
}
|
||||
|
||||
// Evaling in function without graph retention throws
|
||||
// Evaling in function while tracing performs graph retention
|
||||
{
|
||||
auto fun = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval(y);
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
|
||||
|
||||
// Ok with graph retention
|
||||
auto fun1 = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval({y}, true);
|
||||
eval(y);
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(y.has_primitive());
|
||||
CHECK(y.is_tracer());
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
|
||||
@@ -251,29 +246,27 @@ TEST_CASE("test grad") {
|
||||
}
|
||||
|
||||
{
|
||||
// Evaluating in the middle of the grad function throws
|
||||
// as it breaks the graph
|
||||
auto fn = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(x);
|
||||
return square(x);
|
||||
};
|
||||
CHECK_THROWS(grad(fn)(array(1.0f)));
|
||||
|
||||
// Ok since the output is independent of y
|
||||
// No graph retention since the output is independent of y
|
||||
auto y = ones({3, 3});
|
||||
auto fn1 = [y](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(y);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(!y.is_tracer());
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(!y.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
auto dfdx = grad(fn1)(array(1.0f));
|
||||
CHECK_EQ(dfdx.item<float>(), 6.0f);
|
||||
|
||||
// Retain the graph to avoid breaking it
|
||||
// Graph automatically retained to compute the grad
|
||||
auto fn2 = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval({x}, true);
|
||||
eval(x);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
dfdx = grad(fn2)(array(1.0f));
|
||||
@@ -283,7 +276,8 @@ TEST_CASE("test grad") {
|
||||
// Control flow in grad computation
|
||||
{
|
||||
auto fn = [](array x) {
|
||||
if (x.item<float>(true) > 1) {
|
||||
x = x + array(2.0f);
|
||||
if (x.item<float>() > 3) {
|
||||
return square(x);
|
||||
} else {
|
||||
return 4 * x;
|
||||
@@ -294,7 +288,7 @@ TEST_CASE("test grad") {
|
||||
CHECK_EQ(dfdx.item<float>(), 4.0f);
|
||||
|
||||
dfdx = grad(fn)(array(1.5f));
|
||||
CHECK_EQ(dfdx.item<float>(), 3.0f);
|
||||
CHECK_EQ(dfdx.item<float>(), 7.0f);
|
||||
}
|
||||
|
||||
// Grad with multiple inputs
|
||||
@@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that an array captured by reference and updated as a side effect
// inside a function passed to grad() can still be evaluated afterwards.
// After eval(state), state must report no primitive, be evaluated, and
// compare equal to {1.0, 1.0}.
TEST_CASE("test update state") {
|
||||
auto y = array({1.0});
|
||||
auto x = array({1.0, 1.0});
|
||||
auto state = array({0.0, 0.0});
|
||||
// fn reassigns the captured x and state in addition to returning sum(x).
auto fn = [&state, &x](array y) {
|
||||
x = y * x;
|
||||
state = state + x;
|
||||
return sum(x);
|
||||
};
|
||||
// The gradient returned here is discarded; only the side effect on state
// is inspected below.
grad(fn)(y);
|
||||
eval(state);
|
||||
// NOTE(review): presumably has_primitive() == false means state was detached
// from the graph built while grad() was running -- confirm against array API.
CHECK(!state.has_primitive());
|
||||
CHECK(state.is_evaled());
|
||||
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
|
||||
}
|
||||
|
||||
@@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
|
||||
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval with tracer") {
|
||||
TEST_CASE("test eval with tracer when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
|
||||
// Ok, x is not a node
|
||||
CHECK(!x.is_tracer());
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
|
||||
x = ones({2, 3});
|
||||
x.set_tracer(true);
|
||||
CHECK_THROWS(eval(x));
|
||||
|
||||
// Ok retain_graph=true
|
||||
eval({x}, true);
|
||||
|
||||
// Make sure all arguments are checked
|
||||
auto y = ones({2, 3});
|
||||
CHECK_THROWS(eval(x, y));
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval graph retention") {
|
||||
TEST_CASE("test eval graph retention when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
auto y = array(2);
|
||||
auto z = x + y;
|
||||
eval({z}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(true), 3);
|
||||
CHECK(z.has_primitive());
|
||||
eval(z);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
|
||||
z.set_tracer(false);
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
@@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
|
||||
z = x + y;
|
||||
auto a = z + x;
|
||||
auto b = a + y;
|
||||
eval({b}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(a.has_primitive());
|
||||
CHECK(a.is_evaled());
|
||||
|
||||
eval({b}, false);
|
||||
eval(b);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(!a.has_primitive());
|
||||
|
||||
@@ -183,7 +183,7 @@ TEST_CASE("test vmap with eval") {
|
||||
auto fun2 = [](std::vector<array> inputs) {
|
||||
auto x = inputs[0] + 1;
|
||||
auto y = inputs[1] + 2;
|
||||
eval({x}, true);
|
||||
eval(x);
|
||||
auto out = add(x, y);
|
||||
return std::vector<array>{out};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user