mirror of https://github.com/ml-explore/mlx.git
synced 2025-10-31 07:58:14 +08:00

Fix the test and add custom min/max reductions for uncommon MPI types (#2060)

Angelos Katharopoulos, committed by GitHub
parent dfae2c6989
commit ddaa4b7dcb
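A note on the second half of the title: MPI's predefined MIN/MAX reduction operators only apply to its predefined numeric datatypes, and types such as float16 and bfloat16 have no predefined MPI datatype at all, so min/max all-reduces over them need a user-defined operator registered through MPI_Op_create. The sketch below only illustrates that mechanism, using mpi4py and float16; the names (f16_type, f16_max, custom_max) are made up for the example, and MLX's actual MPI backend implements the equivalent in C++.

# Minimal sketch (mpi4py, illustration only): register a custom MPI "max"
# reduction for float16, a type MPI has no predefined datatype/operator for.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Ship float16 elements across ranks as opaque 2-byte blocks.
f16_type = MPI.BYTE.Create_contiguous(2).Commit()


def f16_max(inbuf, inoutbuf, datatype):
    # MPI hands the reduction raw buffers; reinterpret them as float16
    # and reduce element-wise into the in/out buffer.
    x = np.frombuffer(inbuf, dtype=np.float16)
    y = np.frombuffer(inoutbuf, dtype=np.float16)
    np.maximum(x, y, out=y)


custom_max = MPI.Op.Create(f16_max, commute=True)

send = (np.arange(8) + comm.Get_rank()).astype(np.float16)
recv = np.empty_like(send)
comm.Allreduce([send, f16_type], [recv, f16_type], op=custom_max)
assert np.array_equal(recv, (np.arange(8) + comm.Get_size() - 1).astype(np.float16))

custom_max.Free()
f16_type.Free()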
				
@@ -30,27 +30,51 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     def test_all_reduce(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
-            mx.int16,
-            mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
-            mx.float16,
-            mx.bfloat16,
-            mx.complex64,
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int16, 0),
+            (mx.uint16, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+            (mx.complex64, 1e-6),
         ]
-        for dt in dtypes:
-            x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_sum(x)
-            self.assertTrue(mx.all(y == world.size()))
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+        group = world.split(world.rank() % 2)
 
-        sub = world.split(world.rank() % 2)
-        for dt in dtypes:
-            x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_sum(x, group=sub)
-            self.assertTrue(mx.all(y == sub.size()))
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                for g in [world, group]:
+                    x = (
+                        mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10
+                    ).astype(dt)
+
+                    # All sum
+                    y = mx.distributed.all_sum(x[g.rank()], group=g)
+                    z = x.sum(0)
+                    maxrelerror = (y - z).abs()
+                    if rtol > 0:
+                        maxrelerror /= z.abs()
+                    maxrelerror = maxrelerror.max()
+                    self.assertLessEqual(maxrelerror, rtol)
+
+                    # All max
+                    y = mx.distributed.all_max(x[g.rank()], group=g)
+                    z = x.max(0)
+                    self.assertTrue(mx.all(y == z))
+
+                    # All min
+                    y = mx.distributed.all_min(x[g.rank()], group=g)
+                    z = x.min(0)
+                    self.assertTrue(mx.all(y == z))
 
     def test_all_gather(self):
         world = mx.distributed.init()
@@ -124,22 +148,6 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             x = mx.distributed.recv_like(x, neighbor, group=pairs)
         mx.eval(y, x)
 
-    def test_min_max(self):
-        world = mx.distributed.init()
-        base = mx.arange(16).reshape(4, 4)
-        x = base + world.rank() * 32
-
-        def _test_reduction(reduction: str = "all_max"):
-
-            target = base + ((world.size() - 1) * 16) * (reduction == "max")
-            reducer = getattr(mx.distributed, reduction)
-            y = reducer(x)
-
-            self.assertTrue(mx.allclose(y, target))
-
-        for reduction in ["all_max", "all_min"]:
-            _test_reduction(reduction)
-
 
 if __name__ == "__main__":
     unittest.main()

@@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 x = (
                     mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                 ).astype(dt)
-                for reduction in reductions:
-                    reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
-                    y = reducer_distributed(x[world.rank()])
 
-                    reducer = getattr(mx, reduction)
-                    z = reducer(x, axis=0)
-                    mx.eval(y, z)
+                # All sum
+                y = mx.distributed.all_sum(x[world.rank()])
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)
 
-                    maxrelerror = ((y - z).abs() / z.abs()).max()
-                    self.assertLessEqual(maxrelerror, rtol)
+                # All max
+                y = mx.distributed.all_max(x[world.rank()])
+                z = x.max(0)
+                self.assertTrue(mx.all(y == z))
+
+                # All min
+                y = mx.distributed.all_min(x[world.rank()])
+                z = x.min(0)
+                self.assertTrue(mx.all(y == z))
 
     def test_all_gather(self):
         world = mx.distributed.init()
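Both backends' tests now follow the same pattern: every rank draws the identical random array (fixed key), reduces its own slice with the collective, and checks the result against the matching local reduction. A stripped-down standalone version of that check, assuming the script is launched with an MPI launcher (for example mpirun -np 4 python check_reduce.py; the script name is arbitrary), might look like:

# Standalone sketch of the test pattern above, not part of the commit.
import mlx.core as mx

world = mx.distributed.init()

# Every rank generates the identical array because the key is fixed.
x = mx.random.uniform(shape=(world.size(), 16), key=mx.random.key(0))

y = mx.distributed.all_max(x[world.rank()])
assert mx.all(y == x.max(0)), "all_max mismatch"

y = mx.distributed.all_min(x[world.rank()])
assert mx.all(y == x.min(0)), "all_min mismatch"

print(f"rank {world.rank()}: reductions match")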