mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
94 lines
2.0 KiB
C++
94 lines
2.0 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#include "doctest/doctest.h"
|
|
|
|
#include "mlx/mlx.h"
|
|
|
|
using namespace mlx::core;
|
|
|
|
TEST_CASE("test eval") {
|
|
{
|
|
array x(1.0);
|
|
array y(1);
|
|
array z(true);
|
|
eval({x, y, z});
|
|
CHECK_EQ(x.item<float>(), 1.0);
|
|
}
|
|
|
|
{
|
|
array x(1.0);
|
|
array y = ones({2, 2});
|
|
array z(true);
|
|
eval({x, y, z});
|
|
CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
|
|
}
|
|
}
|
|
|
|
TEST_CASE("test eval multiple") {
|
|
auto x = ones({10, 10});
|
|
auto y = ones({10, 10});
|
|
eval({x, y});
|
|
CHECK(array_equal(x, y).item<bool>());
|
|
|
|
auto a = x + y;
|
|
auto b = x - y;
|
|
eval({a, b});
|
|
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
|
|
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
|
|
|
x = ones({10, 10});
|
|
y = ones({10, 10});
|
|
eval(x, y);
|
|
CHECK(array_equal(x, y).item<bool>());
|
|
|
|
a = x + y;
|
|
b = x - y;
|
|
eval(a, b);
|
|
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
|
|
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
|
|
}
|
|
|
|
TEST_CASE("test eval with tracer when not tracing") {
|
|
// Since we are not tracing it doesn't matter that the array flags are
|
|
// tracers they will always be detached.
|
|
auto x = array(1);
|
|
x.set_tracer(true);
|
|
CHECK(!x.is_tracer());
|
|
eval(x);
|
|
CHECK(!x.has_primitive());
|
|
CHECK(x.is_available());
|
|
|
|
x = ones({2, 3});
|
|
x.set_tracer(true);
|
|
eval(x);
|
|
CHECK(!x.has_primitive());
|
|
CHECK(x.is_available());
|
|
}
|
|
|
|
TEST_CASE("test eval graph retention when not tracing") {
|
|
// Since we are not tracing it doesn't matter that the array flags are
|
|
// tracers they will always be detached.
|
|
auto x = array(1);
|
|
x.set_tracer(true);
|
|
auto y = array(2);
|
|
auto z = x + y;
|
|
eval(z);
|
|
CHECK(!z.has_primitive());
|
|
CHECK(z.is_available());
|
|
CHECK_EQ(z.item<int>(), 3);
|
|
|
|
z.set_tracer(false);
|
|
CHECK_EQ(z.item<int>(), 3);
|
|
CHECK(!z.has_primitive());
|
|
CHECK(z.is_available());
|
|
|
|
z = x + y;
|
|
auto a = z + x;
|
|
auto b = a + y;
|
|
eval(b);
|
|
CHECK(!z.has_primitive());
|
|
CHECK(z.is_available());
|
|
CHECK(!a.has_primitive());
|
|
CHECK(a.is_available());
|
|
}
|