Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
			
		
		
		
	
		
			
				
	
	
		
			19 lines
		
	
	
		
			411 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			19 lines
		
	
	
		
			411 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import os
 | |
| import unittest
 | |
| 
 | |
| import mlx.core as mx
 | |
| 
 | |
| 
 | |
class MLXTestCase(unittest.TestCase):
    """Base test case that runs each test on a configurable default device.

    When the ``DEVICE`` environment variable is set, its value names an
    attribute of ``mlx.core`` (for example ``cpu`` or ``gpu``). That attribute
    is installed as the default device for the duration of each test. The
    device that was the default before the test is restored afterwards.
    """

    def setUp(self):
        # Remember the current default so tearDown can restore it and one
        # test's device choice never leaks into the next test.
        self.default = mx.default_device()
        name = os.getenv("DEVICE", None)
        if name is None:
            return
        # Normalise so that e.g. "GPU" or " gpu " select the same device as
        # "gpu"; an empty value (``DEVICE=``) is treated as unset.
        name = name.strip().lower()
        if not name:
            return
        device = getattr(mx, name, None)
        if device is None:
            # Fail with an actionable message instead of a bare
            # AttributeError about the mlx.core module.
            raise ValueError(
                f"Unknown DEVICE {name!r}; expected the name of an mlx.core "
                "device such as 'cpu' or 'gpu'."
            )
        mx.set_default_device(device)

    def tearDown(self):
        # Restore the device that was the default when setUp ran.
        mx.set_default_device(self.default)
 | 
