mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
This commit is contained in:
@@ -49,7 +49,6 @@ TEST_CASE("test arg reduce small") {
|
||||
{0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
|
||||
0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
|
||||
{2, 3, 4});
|
||||
x.eval();
|
||||
test_arg_reduce_small(
|
||||
Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
|
||||
test_arg_reduce_small(
|
||||
|
@@ -100,7 +100,7 @@ TEST_CASE("test jvp") {
|
||||
auto fun1 = [](const array& x) {
|
||||
auto y = 3 * x;
|
||||
eval(y);
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(y.is_available());
|
||||
CHECK(y.has_primitive());
|
||||
CHECK(y.is_tracer());
|
||||
return 2 * y;
|
||||
@@ -253,7 +253,7 @@ TEST_CASE("test grad") {
|
||||
eval(y);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(!y.is_tracer());
|
||||
CHECK(y.is_evaled());
|
||||
CHECK(y.is_available());
|
||||
CHECK(!y.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
@@ -265,7 +265,7 @@ TEST_CASE("test grad") {
|
||||
x = x + 2.0f;
|
||||
eval(x);
|
||||
CHECK(x.is_tracer());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.is_available());
|
||||
CHECK(x.has_primitive());
|
||||
return square(x);
|
||||
};
|
||||
@@ -1259,7 +1259,7 @@ TEST_CASE("test update state") {
|
||||
grad(fn)(y);
|
||||
eval(state);
|
||||
CHECK(!state.has_primitive());
|
||||
CHECK(state.is_evaled());
|
||||
CHECK(state.is_available());
|
||||
CHECK(array_equal(state, array({1.0, 1.0})).item<bool>());
|
||||
}
|
||||
|
||||
|
@@ -56,13 +56,13 @@ TEST_CASE("test eval with tracer when not tracing") {
|
||||
CHECK(!x.is_tracer());
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.is_available());
|
||||
|
||||
x = ones({2, 3});
|
||||
x.set_tracer(true);
|
||||
eval(x);
|
||||
CHECK(!x.has_primitive());
|
||||
CHECK(x.is_evaled());
|
||||
CHECK(x.is_available());
|
||||
}
|
||||
|
||||
TEST_CASE("test eval graph retention when not tracing") {
|
||||
@@ -74,20 +74,20 @@ TEST_CASE("test eval graph retention when not tracing") {
|
||||
auto z = x + y;
|
||||
eval(z);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(z.is_available());
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
|
||||
z.set_tracer(false);
|
||||
CHECK_EQ(z.item<int>(), 3);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(z.is_available());
|
||||
|
||||
z = x + y;
|
||||
auto a = z + x;
|
||||
auto b = a + y;
|
||||
eval(b);
|
||||
CHECK(!z.has_primitive());
|
||||
CHECK(z.is_evaled());
|
||||
CHECK(z.is_available());
|
||||
CHECK(!a.has_primitive());
|
||||
CHECK(a.is_evaled());
|
||||
CHECK(a.is_available());
|
||||
}
|
||||
|
@@ -495,12 +495,14 @@ TEST_CASE("test metal memory info") {
|
||||
|
||||
// Query active and peak memory
|
||||
{
|
||||
auto a = zeros({4096});
|
||||
// Do these tests on the CPU since deallocation is synchronized
|
||||
// with the main thread.
|
||||
auto a = zeros({4096}, Device::cpu);
|
||||
eval(a);
|
||||
auto active_mem = metal::get_active_memory();
|
||||
CHECK(active_mem >= 4096 * 4);
|
||||
{
|
||||
auto b = zeros({4096});
|
||||
auto b = zeros({4096}, Device::cpu);
|
||||
eval(b);
|
||||
}
|
||||
auto new_active_mem = metal::get_active_memory();
|
||||
|
Reference in New Issue
Block a user