mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			36 lines
		
	
	
		
			843 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			843 B
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright © 2023 Apple Inc.
 | |
| 
 | |
| #include "doctest/doctest.h"
 | |
| 
 | |
| #include <cstdlib>
 | |
| 
 | |
| #include "mlx/mlx.h"
 | |
| 
 | |
| using namespace mlx::core;
 | |
| 
 | |
| TEST_CASE("test device placement") {
 | |
|   auto device = default_device();
 | |
|   Device d = metal::is_available() ? Device::gpu : Device::cpu;
 | |
|   if (std::getenv("DEVICE") == nullptr) {
 | |
|     CHECK_EQ(device, d);
 | |
|   }
 | |
| 
 | |
|   array x(1.0f);
 | |
|   array y(1.0f);
 | |
|   auto z = add(x, y, default_device());
 | |
|   if (metal::is_available()) {
 | |
|     z = add(x, y, Device::gpu);
 | |
|     z = add(x, y, Device(Device::gpu, 0));
 | |
|   } else {
 | |
|     CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
 | |
|     CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
 | |
|   }
 | |
| 
 | |
|   // Set the default device to the CPU
 | |
|   set_default_device(Device::cpu);
 | |
|   CHECK_EQ(default_device(), Device::cpu);
 | |
| 
 | |
|   // Revert
 | |
|   set_default_device(device);
 | |
| }
 | 
