// Copyright © 2023 Apple Inc.

										 |  |  | #include "doctest/doctest.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cstdlib>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "mlx/mlx.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace mlx::core; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST_CASE("test device placement") { | 
					
						
							|  |  |  |   auto device = default_device(); | 
					
						
							|  |  |  |   Device d = metal::is_available() ? Device::gpu : Device::cpu; | 
					
						
							|  |  |  |   if (std::getenv("DEVICE") == nullptr) { | 
					
						
							|  |  |  |     CHECK_EQ(device, d); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   array x(1.0f); | 
					
						
							|  |  |  |   array y(1.0f); | 
					
						
							|  |  |  |   auto z = add(x, y, default_device()); | 
					
						
							|  |  |  |   if (metal::is_available()) { | 
					
						
							|  |  |  |     z = add(x, y, Device::gpu); | 
					
						
							|  |  |  |     z = add(x, y, Device(Device::gpu, 0)); | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument); | 
					
						
							|  |  |  |     CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Set the default device to the CPU
 | 
					
						
							|  |  |  |   set_default_device(Device::cpu); | 
					
						
							|  |  |  |   CHECK_EQ(default_device(), Device::cpu); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Revert
 | 
					
						
							|  |  |  |   set_default_device(device); | 
					
						
							|  |  |  | } |