mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -48,36 +48,36 @@ TEST_CASE("test eval multiple") {
|
||||
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval with tracer") {
|
||||
TEST_CASE("test eval with tracer when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
|
||||
// Ok, x is not a node
|
||||
CHECK(!x.is_tracer());
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
|
||||
x = ones({2, 3});
|
||||
x.set_tracer(true);
|
||||
CHECK_THROWS(eval(x));
|
||||
|
||||
// Ok retain_graph=true
|
||||
eval({x}, true);
|
||||
|
||||
// Make sure all arguments are checked
|
||||
auto y = ones({2, 3});
|
||||
CHECK_THROWS(eval(x, y));
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval graph retention") {
|
||||
TEST_CASE("test eval graph retention when not tracing") {
|
||||
// Since we are not tracing it doesn't matter that the array flags are
|
||||
// tracers they will always be detached.
|
||||
auto x = array(1);
|
||||
x.set_tracer(true);
|
||||
auto y = array(2);
|
||||
auto z = x + y;
|
||||
eval({z}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(true), 3);
|
||||
CHECK(z.has_primitive());
|
||||
eval(z);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
|
||||
z.set_tracer(false);
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
@@ -85,13 +85,7 @@ TEST_CASE("test eval graph retention") {
|
||||
z = x + y;
|
||||
auto a = z + x;
|
||||
auto b = a + y;
|
||||
eval({b}, true);
|
||||
CHECK(z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(a.has_primitive());
|
||||
CHECK(a.is_evaled());
|
||||
|
||||
eval({b}, false);
|
||||
eval(b);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(!a.has_primitive());
|
||||
|
Reference in New Issue
Block a user