Shared events for synchronization + async eval (#998)

* more async eval

* fix rebase

* try correct async eval

* fix async

* more tests for async eval

* use shared events for synchronization

* comment + cleanup

* with autorelease pool

* fix no metal build

* fix compile

* fix patch

* don't eval if async eval'd

* don't use is_evaled

* comments

* more multi stream tests

* try and cleanup use of is_evaled

* use a status flag
This commit is contained in:
Awni Hannun
2024-04-17 06:16:02 -07:00
committed by GitHub
parent b18468bf81
commit 8a0677d56d
28 changed files with 424 additions and 125 deletions

View File

@@ -49,7 +49,6 @@ TEST_CASE("test arg reduce small") {
{0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
{2, 3, 4});
x.eval();
test_arg_reduce_small(
Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
test_arg_reduce_small(

View File

@@ -100,7 +100,7 @@ TEST_CASE("test jvp") {
auto fun1 = [](const array& x) {
auto y = 3 * x;
eval(y);
CHECK(y.is_evaled());
CHECK(y.is_available());
CHECK(y.has_primitive());
CHECK(y.is_tracer());
return 2 * y;
@@ -253,7 +253,7 @@ TEST_CASE("test grad") {
eval(y);
CHECK(x.is_tracer());
CHECK(!y.is_tracer());
CHECK(y.is_evaled());
CHECK(y.is_available());
CHECK(!y.has_primitive());
return square(x);
};
@@ -265,7 +265,7 @@ TEST_CASE("test grad") {
x = x + 2.0f;
eval(x);
CHECK(x.is_tracer());
CHECK(x.is_evaled());
CHECK(x.is_available());
CHECK(x.has_primitive());
return square(x);
};
@@ -1259,7 +1259,7 @@ TEST_CASE("test update state") {
grad(fn)(y);
eval(state);
CHECK(!state.has_primitive());
CHECK(state.is_evaled());
CHECK(state.is_available());
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
}

View File

@@ -56,13 +56,13 @@ TEST_CASE("test eval with tracer when not tracing") {
CHECK(!x.is_tracer());
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
CHECK(x.is_available());
x = ones({2, 3});
x.set_tracer(true);
eval(x);
CHECK(!x.has_primitive());
CHECK(x.is_evaled());
CHECK(x.is_available());
}
TEST_CASE("test eval graph retention when not tracing") {
@@ -74,20 +74,20 @@ TEST_CASE("test eval graph retention when not tracing") {
auto z = x + y;
eval(z);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(z.is_available());
CHECK_EQ(z.item<int>(), 3);
z.set_tracer(false);
CHECK_EQ(z.item<int>(), 3);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(z.is_available());
z = x + y;
auto a = z + x;
auto b = a + y;
eval(b);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(z.is_available());
CHECK(!a.has_primitive());
CHECK(a.is_evaled());
CHECK(a.is_available());
}

View File

@@ -495,12 +495,14 @@ TEST_CASE("test metal memory info") {
// Query active and peak memory
{
auto a = zeros({4096});
// Do these tests on the CPU since deallocation is synchronized
// with the main thread.
auto a = zeros({4096}, Device::cpu);
eval(a);
auto active_mem = metal::get_active_memory();
CHECK(active_mem >= 4096 * 4);
{
auto b = zeros({4096});
auto b = zeros({4096}, Device::cpu);
eval(b);
}
auto new_active_mem = metal::get_active_memory();