mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Removes the retain_graph flag (#385)
* Adds global tracing flag
* Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -95,19 +95,14 @@ TEST_CASE("test jvp") {
|
||||
CHECK_EQ(dout[0].item<float>(), 4.0f);
|
||||
}
|
||||
|
||||
// Evaling in function without graph retention throws
|
||||
// Evaling in function while tracing performs graph retention
|
||||
{
|
||||
auto fun = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval(y);
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_THROWS(jvp(fun, array(1.0f), array(1.0f)));
|
||||
|
||||
// Ok with graph retention
|
||||
auto fun1 = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval({y}, true);
|
||||
eval(y);
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(y.has_primitive());
|
||||
CHECK(y.is_tracer());
|
||||
return 2 * y;
|
||||
};
|
||||
CHECK_EQ(jvp(fun1, array(1.0f), array(1.0f)).second.item<float>(), 6.0f);
|
||||
@@ -251,29 +246,27 @@ TEST_CASE("test grad") {
|
||||
}
|
||||
|
||||
{
|
||||
// Evaluating in the middle of the grad function throws
|
||||
// as it breaks the graph
|
||||
auto fn = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(x);
|
||||
return square(x);
|
||||
};
|
||||
CHECK_THROWS(grad(fn)(array(1.0f)));
|
||||
|
||||
// Ok since the output is independent of y
|
||||
// No graph retention since the output is independent of y
|
||||
auto y = ones({3, 3});
|
||||
auto fn1 = [y](array x) {
|
||||
x = x + 2.0f;
|
||||
eval(y);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(!y.is_tracer());
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(!y.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
auto dfdx = grad(fn1)(array(1.0f));
|
||||
CHECK_EQ(dfdx.item<float>(), 6.0f);
|
||||
|
||||
// Retain the graph to avoid breaking it
|
||||
// Graph automatically retained to compute the grad
|
||||
auto fn2 = [](array x) {
|
||||
x = x + 2.0f;
|
||||
eval({x}, true);
|
||||
eval(x);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
dfdx = grad(fn2)(array(1.0f));
|
||||
@@ -283,7 +276,8 @@ TEST_CASE("test grad") {
|
||||
// Control flow in grad computation
|
||||
{
|
||||
auto fn = [](array x) {
|
||||
if (x.item<float>(true) > 1) {
|
||||
x = x + array(2.0f);
|
||||
if (x.item<float>() > 3) {
|
||||
return square(x);
|
||||
} else {
|
||||
return 4 * x;
|
||||
@@ -294,7 +288,7 @@ TEST_CASE("test grad") {
|
||||
CHECK_EQ(dfdx.item<float>(), 4.0f);
|
||||
|
||||
dfdx = grad(fn)(array(1.5f));
|
||||
CHECK_EQ(dfdx.item<float>(), 3.0f);
|
||||
CHECK_EQ(dfdx.item<float>(), 7.0f);
|
||||
}
|
||||
|
||||
// Grad with multiple inputs
|
||||
@@ -1192,3 +1186,19 @@ TEST_CASE("test scan grads") {
|
||||
CHECK(array_equal(out, expected).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test update state") {
  // Arrays that live outside the differentiated function and are captured
  // by reference so the function can reassign them.
  auto scale = array({1.0});
  auto x = array({1.0, 1.0});
  auto state = array({0.0, 0.0});

  // Reassigns both captured arrays as a side effect and returns a scalar
  // so that the function can be passed to grad.
  auto update_fn = [&x, &state](array s) {
    x = s * x;
    state = state + x;
    return sum(x);
  };

  // Only the side effects matter here; the gradient itself is discarded.
  grad(update_fn)(scale);

  // After evaluation the externally held state must be evaluated, carry no
  // primitive, and hold the accumulated value {1, 1}.
  // NOTE(review): presumably this guards against arrays updated inside a
  // transformed function staying attached to the traced graph - confirm.
  eval(state);
  CHECK(!state.has_primitive());
  CHECK(state.is_evaled());
  CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}
|
||||
|
Reference in New Issue
Block a user