mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Ring distributed backend (#1784)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							2235dee906
						
					
				
				
					commit
					ccb61d7aae
				
			| @@ -34,6 +34,8 @@ class TestDistributed(mlx_tests.MLXTestCase): | ||||
|             mx.int32, | ||||
|             mx.uint32, | ||||
|             mx.float32, | ||||
|             mx.float16, | ||||
|             mx.bfloat16, | ||||
|             mx.complex64, | ||||
|         ] | ||||
|         for dt in dtypes: | ||||
|   | ||||
							
								
								
									
										61
									
								
								python/tests/ring_test_distributed.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								python/tests/ring_test_distributed.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| # Copyright © 2024 Apple Inc. | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestRingDistributed(mlx_tests.MLXTestCase): | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         world = mx.distributed.init(strict=True, backend="ring") | ||||
|  | ||||
|     def test_groups(self): | ||||
|         world = mx.distributed.init() | ||||
|         self.assertEqual(world.size(), 8) | ||||
|         self.assertTrue(0 <= world.rank() < 8) | ||||
|  | ||||
|         world2 = mx.distributed.init() | ||||
|         self.assertEqual(world.size(), world2.size()) | ||||
|         self.assertEqual(world.rank(), world2.rank()) | ||||
|  | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             sub = world.split(world.rank() % 2) | ||||
|  | ||||
|     def test_all_reduce(self): | ||||
|         world = mx.distributed.init() | ||||
|         dtypes = [ | ||||
|             (mx.int8, 0), | ||||
|             (mx.uint8, 0), | ||||
|             (mx.int16, 0), | ||||
|             (mx.uint16, 0), | ||||
|             (mx.int32, 0), | ||||
|             (mx.uint32, 0), | ||||
|             (mx.float32, 1e-6), | ||||
|             (mx.float16, 5e-3), | ||||
|             (mx.bfloat16, 1e-1), | ||||
|             (mx.complex64, 1e-6), | ||||
|         ] | ||||
|         sizes = [ | ||||
|             (7,), | ||||
|             (10,), | ||||
|             (1024,), | ||||
|             (1024, 1024), | ||||
|         ] | ||||
|         key = mx.random.key(0) | ||||
|         for dt, rtol in dtypes: | ||||
|             for sh in sizes: | ||||
|                 x = ( | ||||
|                     mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10 | ||||
|                 ).astype(dt) | ||||
|                 y = mx.distributed.all_sum(x[world.rank()]) | ||||
|                 z = sum( | ||||
|                     x[i] for i in range(world.size()) | ||||
|                 )  # to ensure that we don't sum to int32 | ||||
|                 maxrelerror = ((y - z).abs() / z.abs()).max() | ||||
|                 self.assertLessEqual(maxrelerror, rtol) | ||||
|  | ||||
|  | ||||
| # Run the unittest runner when this file is executed directly as a script. | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										25
									
								
								python/tests/run_ring_test.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								python/tests/run_ring_test.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| # Launch ring_test_distributed.py as 8 local processes, one per rank, that | ||||
| # share a temporary hostfile listing one localhost port per rank. Exits | ||||
| # non-zero if any rank fails. | ||||
|  | ||||
| tmpfile=$(mktemp) || exit 1 | ||||
| # Remove the temporary hostfile however the script exits. | ||||
| trap 'rm -f "$tmpfile"' EXIT | ||||
|  | ||||
| cat <<HOSTFILE >"$tmpfile" | ||||
| [ | ||||
|     ["127.0.0.1:5000"], | ||||
|     ["127.0.0.1:5001"], | ||||
|     ["127.0.0.1:5002"], | ||||
|     ["127.0.0.1:5003"], | ||||
|     ["127.0.0.1:5004"], | ||||
|     ["127.0.0.1:5005"], | ||||
|     ["127.0.0.1:5006"], | ||||
|     ["127.0.0.1:5007"] | ||||
| ] | ||||
| HOSTFILE | ||||
|  | ||||
| ring_test="$(dirname "${BASH_SOURCE[0]}")/ring_test_distributed.py" | ||||
|  | ||||
| pids=() | ||||
| for i in {0..7}; do | ||||
|     # NOTE(review): the pause before the last rank presumably gives the | ||||
|     # other ranks time to start up first -- confirm. | ||||
|     if ((i == 7)); then | ||||
|         sleep 1 | ||||
|     fi | ||||
|     DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE="$tmpfile" MLX_RANK=$i python "$ring_test" & | ||||
|     pids+=("$!") | ||||
| done | ||||
|  | ||||
| # A bare `wait` returns 0 regardless of how the children exited, so wait on | ||||
| # each rank individually and report failure if any of them failed. | ||||
| status=0 | ||||
| for pid in "${pids[@]}"; do | ||||
|     wait "$pid" || status=1 | ||||
| done | ||||
| exit "$status" | ||||
		Reference in New Issue
	
	Block a user