# Copyright © 2024 Apple Inc.
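
# Tests for the NCCL distributed backend: all_sum collectives across dtypes and
# shapes, gradient averaging with batched reductions, buffer donation through
# all_sum, and sharded linear/convolution layers.
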
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients


class TestNCCLDistributed(mlx_tests.MLXTestCase):
    @classmethod
    def setUpClass(cls):
        world = mx.distributed.init(strict=True, backend="nccl")
        rank = world.rank()
        # Pin each rank to one of the (up to 8) local GPUs.
        mx.set_default_device(mx.Device(mx.gpu, rank % 8))

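    # all_sum across ranks should match summing the stacked per-rank slices
    # locally: exactly for integer dtypes, within a relative tolerance for
    # floating point dtypes.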
    def test_all_reduce(self):
        world = mx.distributed.init()
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)

        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)

                # All sum
                y = mx.distributed.all_sum(x[world.rank()])
                z = x.sum(0)
                maxrelerror = (y - z).abs()
                if rtol > 0:
                    maxrelerror /= z.abs()
                maxrelerror = maxrelerror.max()
                self.assertLessEqual(maxrelerror, rtol)

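    # average_gradients should batch its all_sum calls according to
    # all_reduce_size and cast to communication_type for the reduction while
    # returning gradients in their original dtype.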
    def test_average_gradients(self):
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        # Wrap all_sum to count how many reductions are issued and to check the
        # dtype that is actually communicated.
        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum
        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

            # A 200-byte bucket holds 5 of the 40-byte float32 gradients, so the
            # 10 gradients are reduced in 2 calls.
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            # Disabling batching reduces every gradient separately.
            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            # Communicating in float16 halves the per-gradient size, and the
            # averaged gradients still come back as float32.
            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads,
                all_reduce_size=2 * 50,
                communication_type=mx.float16,
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum

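    # The all_sum output buffer should be donated to the following binary op,
    # so scaling the result must not increase peak memory.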
    def test_donation(self):
        x = mx.random.normal((1024,))
        mx.eval(x)
        mx.synchronize()

        mx.reset_peak_memory()
        scale = mx.array(2.0)
        y = mx.distributed.all_sum(x)
        mx.eval(y)
        mx.synchronize()
        all_sum_only = mx.get_peak_memory()
        y = mx.distributed.all_sum(x) * scale
        mx.eval(y)
        mx.synchronize()
        all_sum_with_binary = mx.get_peak_memory()

        self.assertEqual(all_sum_only, all_sum_with_binary)

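    # Sharded linear layers should reproduce the unsharded reference:
    # "all-to-sharded" takes the full input and returns this rank's slice of
    # the output, while "sharded-to-all" takes this rank's input slice and
    # returns the full output. Gradients of the sharded weights must match the
    # corresponding slices of the unsharded gradients.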
    def test_shard_linear(self):
        # Seed the prng to have the same inputs and weights generated everywhere
        mx.random.seed(0xF0F0F0F0)

        # Prepare inputs
        world = mx.distributed.init()
        part = (
            slice(None),
            slice(
                world.rank() * 1024 // world.size(),
                (world.rank() + 1) * 1024 // world.size(),
            ),
        )
        x = mx.random.normal((4, 1024))

        # Create and shard some linear layers
        lin = nn.Linear(1024, 1024, bias=True)
        slin1 = shard_linear(lin, "all-to-sharded")
        slin2 = shard_linear(lin, "sharded-to-all")
        y = lin(x)
        y1 = slin1(x)
        y2 = slin2(x[part])
        self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
        self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))

        # Check the backward works as expected
        def dummy_loss(model, x, y):
            return (model(x) * y).sum()

        mod = nn.Sequential(
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
        )
        smod = nn.Sequential(
            shard_linear(mod.layers[0], "all-to-sharded"),
            shard_linear(mod.layers[1], "sharded-to-all"),
            shard_linear(mod.layers[2], "all-to-sharded"),
            shard_linear(mod.layers[3], "sharded-to-all"),
        )

        grad1 = nn.value_and_grad(mod, dummy_loss)
        grad2 = nn.value_and_grad(smod, dummy_loss)

        x = mx.random.normal((4, 128))
        y = mx.random.normal((4, 128))

        l1, g1 = grad1(mod, x, y)
        l2, g2 = grad2(smod, x, y)
        mx.eval(l1, g1, l2, g2)

        part = slice(
            world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
        )

        self.assertTrue(mx.allclose(l1, l2))
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["weight"][part],
                g2["layers"][0]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["weight"][part],
                g2["layers"][2]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["weight"][:, part],
                g2["layers"][1]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["weight"][:, part],
                g2["layers"][3]["weight"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][0]["bias"][part],
                g2["layers"][0]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][2]["bias"][part],
                g2["layers"][2]["bias"],
                atol=1e-6,
                rtol=1e-4,
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
            )
        )
        self.assertTrue(
            mx.allclose(
                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
            )
        )

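    # shard_inplace with a custom predicate: even-indexed conv layers are
    # sharded along the output channels (axis 0), odd-indexed layers along the
    # input channels (axis -1) with unsharded biases, and the odd layers
    # all_sum their outputs so the sharded model matches the dense one.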
    def test_shard_predicate(self):
        mx.random.seed(0xF0F0F0F0)

        class MyConv(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                self.aggregate = kwargs.pop("aggregate", False)
                self.conv = nn.Conv2d(*args, **kwargs)

            def __call__(self, x):
                x = self.conv(x)
                if self.aggregate:
                    x = mx.distributed.all_sum(x)
                return x

        def sharding(path, weight):
            parts = path.split(".")
            even = int(parts[1]) % 2 == 0
            if even:
                return 0
            else:
                return -1 if parts[-1] != "bias" else None

        mod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3),
        )
        smod = nn.Sequential(
            MyConv(3, 128, kernel_size=3),
            MyConv(128, 128, kernel_size=3, aggregate=True),
            MyConv(128, 128, kernel_size=3),
            MyConv(128, 3, kernel_size=3, aggregate=True),
        )
        smod.update(mod.parameters())
        shard_inplace(smod, sharding)

        x = mx.random.normal((4, 16, 16, 3))
        y1 = mod(x)
        y2 = smod(x)
        self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))


if __name__ == "__main__":
    mlx_tests.MLXTestRunner()