Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 16:21:27 +08:00)
Distributed layers (#1270)

Author: Angelos Katharopoulos
Committed by: GitHub
Parent commit: 69e4dd506b
This commit: 4eef8102c9
Changed file: python/tests/mlx_distributed_tests.py | 250 additions (new file)
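This commit adds a module that collects the backend-agnostic distributed tests so the MPI and ring suites can share them. Running the suites takes one process per rank; a hedged example launch (the MPI suite asserts world.size() == 8, so eight ranks are assumed, and the exact mpirun flags may differ per setup):

    mpirun -np 8 python python/tests/mpi_test_distributed.py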
python/tests/mlx_distributed_tests.py (new file)
@@ -0,0 +1,250 @@
# Copyright © 2025 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
    def test_average_gradients(self):
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum
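        # Each gradient below holds 10 float32 values, i.e. 40 bytes. The
        # asserted call counts are consistent with all_reduce_size acting as
        # a byte budget per batched all_sum: the default fits all ten
        # gradients into one call, 4 * 50 = 200 bytes yields two batches of
        # five, and 0 disables batching so each gradient is reduced alone.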

        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads, all_reduce_size=2 * 50, communication_type=mx.float16
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum

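    # all_sum should donate its input buffer, so tacking a binary op onto
    # the all_sum result must not raise peak memory.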
    def test_donation(self):
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

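    # shard_linear with "all-to-sharded" splits the output features across
    # ranks (each rank computes y[part]), while "sharded-to-all" consumes a
    # sharded input and reconstitutes the full output.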
    def test_shard_linear(self):
        # Seed the prng to have the same inputs and weights generated everywhere
        mx.random.seed(0xF0F0F0F0)

        # Prepare inputs
        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        # Create and shard some linear layers
        lin = nn.Linear(1024, 1024, bias=True)
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

        # And their quant versions
        qlin = lin.to_quantized()
        slin1 = shard_linear(qlin, "all-to-sharded")
        slin2 = shard_linear(qlin, "sharded-to-all")
        y = qlin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1))

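        # Gradient layout below: "all-to-sharded" layers (0 and 2) hold a
        # row slice of the full weight and bias, while "sharded-to-all"
        # layers (1 and 3) hold a column slice of the weight and the full,
        # replicated bias.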
        # Check the backward works as expected
        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )
        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

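    # shard_inplace calls the predicate with each parameter's path and
    # value; the returned value is presumably the axis to shard along
    # (0 or -1 here), with None leaving the parameter unsharded.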
    def test_shard_predicate(self):
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

        def sharding(path, weight):
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
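The MPI and ring suites below now inherit all of the above from MLXDistributedCommonTestCase: each keeps only a setUpClass that pins its backend, and their duplicated copies of test_average_gradients and test_donation are deleted.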
python/tests/mpi_test_distributed.py
@@ -3,11 +3,14 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
-from mlx.nn.utils import average_gradients
+import mlx_distributed_tests
 
 
-class TestDistributed(mlx_tests.MLXTestCase):
+class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
+    @classmethod
+    def setUpClass(cls):
+        world = mx.distributed.init(strict=True, backend="mpi")
+
     def test_groups(self):
         world = mx.distributed.init()
         self.assertEqual(world.size(), 8)
@@ -121,77 +124,6 @@ class TestDistributed(mlx_tests.MLXTestCase):
             x = mx.distributed.recv_like(x, neighbor, group=pairs)
         mx.eval(y, x)
 
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads, all_reduce_size=2 * 50, communication_type=mx.float16
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
 
 if __name__ == "__main__":
     unittest.main()
python/tests/ring_test_distributed.py
@@ -3,10 +3,10 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
+import mlx_distributed_tests
 
 
-class TestRingDistributed(mlx_tests.MLXTestCase):
+class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     @classmethod
     def setUpClass(cls):
         world = mx.distributed.init(strict=True, backend="ring")
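Any additional backend-specific suite can reuse the shared cases the same way. A minimal sketch, assuming a hypothetical backend registered as "mybackend" (the file name, class name, and backend string below are illustrative, not part of this commit):

    # my_backend_test_distributed.py (hypothetical file)
    import unittest

    import mlx.core as mx
    import mlx_distributed_tests


    class TestMyBackendDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
        @classmethod
        def setUpClass(cls):
            # "mybackend" is an assumed backend name, not one shipped by mlx
            world = mx.distributed.init(strict=True, backend="mybackend")


    if __name__ == "__main__":
        unittest.main()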