mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			38 lines
		
	
	
		
			845 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			38 lines
		
	
	
		
			845 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import io
 | |
| import unittest
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx_tests
 | |
| 
 | |
| 
 | |
class TestGraph(mlx_tests.MLXTestCase):
    """Smoke tests for ``mx.export_to_dot``."""

    def _assert_export_nonempty(self, *outputs):
        """Export ``outputs`` to a fresh in-memory buffer and check it is non-empty.

        Args:
            *outputs: The arrays forwarded to ``mx.export_to_dot`` after the
                file argument.
        """
        buffer = io.StringIO()
        mx.export_to_dot(buffer, *outputs)
        # getvalue() returns the full buffer contents regardless of the
        # current stream position, so no seek(0) is needed first.
        # assertGreater reports the actual length on failure instead of the
        # uninformative "False is not true".
        self.assertGreater(len(buffer.getvalue()), 0)

    def test_to_dot(self):
        # Simply test that a few cases run.
        # Nothing too specific about the graph format
        # for now to keep it flexible
        a = mx.array(1.0)
        self._assert_export_nonempty(a)

        # Result of a binary op on two arrays
        b = mx.array(2.0)
        c = a + b
        self._assert_export_nonempty(c)

        # Multi output case
        c = mx.divmod(a, b)
        self._assert_export_nonempty(*c)
 | |
| 
 | |
| 
 | |
# When this file is executed directly (not imported), hand control to the
# runner provided by the project's mlx_tests helper module.
# NOTE(review): MLXTestRunner is defined outside this file; presumably it
# wraps unittest's runner -- confirm in mlx_tests.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
 | 
