| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | // Copyright © 2023 Apple Inc.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | #include "doctest/doctest.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "mlx/mlx.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace mlx::core; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test eval") { | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     array x(1.0); | 
					
						
							|  |  |  |     array y(1); | 
					
						
							|  |  |  |     array z(true); | 
					
						
							|  |  |  |     eval({x, y, z}); | 
					
						
							|  |  |  |     CHECK_EQ(x.item<float>(), 1.0); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     array x(1.0); | 
					
						
							|  |  |  |     array y = ones({2, 2}); | 
					
						
							|  |  |  |     array z(true); | 
					
						
							|  |  |  |     eval({x, y, z}); | 
					
						
							|  |  |  |     CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>()); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test eval multiple") { | 
					
						
							|  |  |  |   auto x = ones({10, 10}); | 
					
						
							|  |  |  |   auto y = ones({10, 10}); | 
					
						
							|  |  |  |   eval({x, y}); | 
					
						
							|  |  |  |   CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto a = x + y; | 
					
						
							|  |  |  |   auto b = x - y; | 
					
						
							|  |  |  |   eval({a, b}); | 
					
						
							|  |  |  |   CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>()); | 
					
						
							|  |  |  |   CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   x = ones({10, 10}); | 
					
						
							|  |  |  |   y = ones({10, 10}); | 
					
						
							|  |  |  |   eval(x, y); | 
					
						
							|  |  |  |   CHECK(array_equal(x, y).item<bool>()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   a = x + y; | 
					
						
							|  |  |  |   b = x - y; | 
					
						
							|  |  |  |   eval(a, b); | 
					
						
							|  |  |  |   CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>()); | 
					
						
							|  |  |  |   CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>()); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  | TEST_CASE("test eval with tracer when not tracing") { | 
					
						
							|  |  |  |   // Since we are not tracing it doesn't matter that the array flags are
 | 
					
						
							|  |  |  |   // tracers they will always be detached.
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   auto x = array(1); | 
					
						
							|  |  |  |   x.set_tracer(true); | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   CHECK(!x.is_tracer()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   eval(x); | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   CHECK(!x.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(x.is_available()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   x = ones({2, 3}); | 
					
						
							|  |  |  |   x.set_tracer(true); | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   eval(x); | 
					
						
							|  |  |  |   CHECK(!x.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(x.is_available()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  | TEST_CASE("test eval graph retention when not tracing") { | 
					
						
							|  |  |  |   // Since we are not tracing it doesn't matter that the array flags are
 | 
					
						
							|  |  |  |   // tracers they will always be detached.
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   auto x = array(1); | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   x.set_tracer(true); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   auto y = array(2); | 
					
						
							|  |  |  |   auto z = x + y; | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   eval(z); | 
					
						
							|  |  |  |   CHECK(!z.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(z.is_available()); | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   CHECK_EQ(z.item<int>(), 3); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   z.set_tracer(false); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   CHECK_EQ(z.item<int>(), 3); | 
					
						
							|  |  |  |   CHECK(!z.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(z.is_available()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   z = x + y; | 
					
						
							|  |  |  |   auto a = z + x; | 
					
						
							|  |  |  |   auto b = a + y; | 
					
						
							| 
									
										
										
										
											2024-01-07 15:16:51 -08:00
										 |  |  |   eval(b); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   CHECK(!z.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(z.is_available()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  |   CHECK(!a.has_primitive()); | 
					
						
							| 
									
										
										
										
											2024-04-17 06:16:02 -07:00
										 |  |  |   CHECK(a.is_available()); | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | } |