| 
									
										
										
										
											2023-11-30 11:12:53 -08:00
										 |  |  | // Copyright © 2023 Apple Inc.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-29 10:52:08 -08:00
										 |  |  | #include "doctest/doctest.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "mlx/mlx.h"
 | 
					
						
							|  |  |  | #include "mlx/scheduler.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace mlx::core; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test stream management") { | 
					
						
							|  |  |  |   auto s1 = default_stream(default_device()); | 
					
						
							|  |  |  |   CHECK_EQ(s1.device, default_device()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto s2 = new_stream(default_device()); | 
					
						
							|  |  |  |   CHECK_EQ(s2.device, default_device()); | 
					
						
							|  |  |  |   CHECK_NE(s1, s2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Check that default streams have the correct devices
 | 
					
						
							|  |  |  |   if (metal::is_available()) { | 
					
						
							|  |  |  |     auto s_gpu = default_stream(Device::gpu); | 
					
						
							|  |  |  |     CHECK_EQ(s_gpu.device, Device::gpu); | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   auto s_cpu = default_stream(Device::cpu); | 
					
						
							|  |  |  |   CHECK_EQ(s_cpu.device, Device::cpu); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   s_cpu = new_stream(Device::cpu); | 
					
						
							|  |  |  |   CHECK_EQ(s_cpu.device, Device::cpu); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (metal::is_available()) { | 
					
						
							|  |  |  |     auto s_gpu = new_stream(Device::gpu); | 
					
						
							|  |  |  |     CHECK_EQ(s_gpu.device, Device::gpu); | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test asynchronous launch") { | 
					
						
							|  |  |  |   auto s1 = default_stream(default_device()); | 
					
						
							|  |  |  |   auto s2 = new_stream(default_device()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Make sure streams execute asynchronously
 | 
					
						
							|  |  |  |   int x = 1; | 
					
						
							|  |  |  |   auto p1 = std::make_shared<std::promise<void>>(); | 
					
						
							|  |  |  |   auto p2 = std::make_shared<std::promise<void>>(); | 
					
						
							|  |  |  |   auto f1 = p1->get_future().share(); | 
					
						
							|  |  |  |   auto f2 = p2->get_future().share(); | 
					
						
							|  |  |  |   auto fn1 = [&x, p = std::move(p1)]() { | 
					
						
							|  |  |  |     x++; | 
					
						
							|  |  |  |     p->set_value(); | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  |   auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() { | 
					
						
							|  |  |  |     f.wait(); | 
					
						
							|  |  |  |     x *= 5; | 
					
						
							|  |  |  |     p->set_value(); | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // fn2 is launched first and is waiting on fn1 but since
 | 
					
						
							|  |  |  |   // they are on different streams there is no deadlock.
 | 
					
						
							|  |  |  |   scheduler::enqueue(s2, std::move(fn2)); | 
					
						
							|  |  |  |   scheduler::enqueue(s1, std::move(fn1)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   f2.wait(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CHECK_EQ(x, 10); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test stream placement") { | 
					
						
							|  |  |  |   auto s1 = default_stream(default_device()); | 
					
						
							|  |  |  |   auto s2 = new_stream(default_device()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     // Wait on stream 1
 | 
					
						
							|  |  |  |     auto p = std::make_shared<std::promise<void>>(); | 
					
						
							|  |  |  |     auto f = p->get_future().share(); | 
					
						
							|  |  |  |     scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Do some work on stream 2
 | 
					
						
							|  |  |  |     auto x = zeros({100}, float32, s2); | 
					
						
							|  |  |  |     auto y = ones({100}, float32, s2); | 
					
						
							|  |  |  |     auto z = add(x, y, s2); | 
					
						
							|  |  |  |     eval(z); | 
					
						
							|  |  |  |     p->set_value(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     // Wait on stream 1
 | 
					
						
							|  |  |  |     auto p = std::make_shared<std::promise<void>>(); | 
					
						
							|  |  |  |     auto f = p->get_future().share(); | 
					
						
							|  |  |  |     scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Do some work on stream 2
 | 
					
						
							|  |  |  |     auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); }; | 
					
						
							|  |  |  |     auto x = zeros({100}, s2); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // The whole vjp computation should happen
 | 
					
						
							|  |  |  |     // on the second stream otherwise this will hang.
 | 
					
						
							|  |  |  |     auto [out, dout] = vjp(fn, x, ones({100}, s2)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // The whole jvp computation should happen on the
 | 
					
						
							|  |  |  |     // second stream.
 | 
					
						
							|  |  |  |     std::tie(out, dout) = jvp(fn, x, ones({100}, s2)); | 
					
						
							|  |  |  |     eval(out, dout); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     p->set_value(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test scheduler races") { | 
					
						
							|  |  |  |   auto x = zeros({1}); | 
					
						
							|  |  |  |   auto y = zeros({100}); | 
					
						
							|  |  |  |   eval(x, y); | 
					
						
							|  |  |  |   auto a = exp(x); | 
					
						
							|  |  |  |   eval(a); | 
					
						
							|  |  |  |   a = exp(x); | 
					
						
							|  |  |  |   for (int i = 0; i < 10000; ++i) { | 
					
						
							|  |  |  |     y = exp(y); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   eval(a, y); | 
					
						
							|  |  |  | } |