mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	 8a0677d56d
			
		
	
	8a0677d56d
	
	
	
		
			
			* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
		
			
				
	
	
		
			94 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			94 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright © 2023 Apple Inc.
 | |
| 
 | |
| #include "doctest/doctest.h"
 | |
| 
 | |
| #include "mlx/mlx.h"
 | |
| 
 | |
| using namespace mlx::core;
 | |
| 
 | |
| TEST_CASE("test eval") {
 | |
|   {
 | |
|     array x(1.0);
 | |
|     array y(1);
 | |
|     array z(true);
 | |
|     eval({x, y, z});
 | |
|     CHECK_EQ(x.item<float>(), 1.0);
 | |
|   }
 | |
| 
 | |
|   {
 | |
|     array x(1.0);
 | |
|     array y = ones({2, 2});
 | |
|     array z(true);
 | |
|     eval({x, y, z});
 | |
|     CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
 | |
|   }
 | |
| }
 | |
| 
 | |
| TEST_CASE("test eval multiple") {
 | |
|   auto x = ones({10, 10});
 | |
|   auto y = ones({10, 10});
 | |
|   eval({x, y});
 | |
|   CHECK(array_equal(x, y).item<bool>());
 | |
| 
 | |
|   auto a = x + y;
 | |
|   auto b = x - y;
 | |
|   eval({a, b});
 | |
|   CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
 | |
|   CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
 | |
| 
 | |
|   x = ones({10, 10});
 | |
|   y = ones({10, 10});
 | |
|   eval(x, y);
 | |
|   CHECK(array_equal(x, y).item<bool>());
 | |
| 
 | |
|   a = x + y;
 | |
|   b = x - y;
 | |
|   eval(a, b);
 | |
|   CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
 | |
|   CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
 | |
| }
 | |
| 
 | |
| TEST_CASE("test eval with tracer when not tracing") {
 | |
|   // Since we are not tracing it doesn't matter that the array flags are
 | |
|   // tracers they will always be detached.
 | |
|   auto x = array(1);
 | |
|   x.set_tracer(true);
 | |
|   CHECK(!x.is_tracer());
 | |
|   eval(x);
 | |
|   CHECK(!x.has_primitive());
 | |
|   CHECK(x.is_available());
 | |
| 
 | |
|   x = ones({2, 3});
 | |
|   x.set_tracer(true);
 | |
|   eval(x);
 | |
|   CHECK(!x.has_primitive());
 | |
|   CHECK(x.is_available());
 | |
| }
 | |
| 
 | |
| TEST_CASE("test eval graph retention when not tracing") {
 | |
|   // Since we are not tracing it doesn't matter that the array flags are
 | |
|   // tracers they will always be detached.
 | |
|   auto x = array(1);
 | |
|   x.set_tracer(true);
 | |
|   auto y = array(2);
 | |
|   auto z = x + y;
 | |
|   eval(z);
 | |
|   CHECK(!z.has_primitive());
 | |
|   CHECK(z.is_available());
 | |
|   CHECK_EQ(z.item<int>(), 3);
 | |
| 
 | |
|   z.set_tracer(false);
 | |
|   CHECK_EQ(z.item<int>(), 3);
 | |
|   CHECK(!z.has_primitive());
 | |
|   CHECK(z.is_available());
 | |
| 
 | |
|   z = x + y;
 | |
|   auto a = z + x;
 | |
|   auto b = a + y;
 | |
|   eval(b);
 | |
|   CHECK(!z.has_primitive());
 | |
|   CHECK(z.is_available());
 | |
|   CHECK(!a.has_primitive());
 | |
|   CHECK(a.is_available());
 | |
| }
 |