Distributed layers (#1270)
committed by GitHub
parent 69e4dd506b
commit 4eef8102c9
@@ -3,11 +3,14 @@
 import unittest
 
 import mlx.core as mx
-import mlx_tests
-from mlx.nn.utils import average_gradients
+import mlx_distributed_tests
 
 
-class TestDistributed(mlx_tests.MLXTestCase):
+class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
+    @classmethod
+    def setUpClass(cls):
+        world = mx.distributed.init(strict=True, backend="mpi")
+
     def test_groups(self):
         world = mx.distributed.init()
         self.assertEqual(world.size(), 8)
@@ -121,77 +124,6 @@ class TestDistributed(mlx_tests.MLXTestCase):
         x = mx.distributed.recv_like(x, neighbor, group=pairs)
         mx.eval(y, x)
 
-    def test_average_gradients(self):
-        original_all_sum = mx.distributed.all_sum
-        n_calls = 0
-        xtype = None
-
-        def new_all_sum(x, **kwargs):
-            nonlocal n_calls
-            nonlocal xtype
-
-            n_calls += 1
-            if xtype is not None:
-                self.assertEqual(xtype, x.dtype)
-
-            return original_all_sum(x, **kwargs)
-
-        mx.distributed.all_sum = new_all_sum
-
-        try:
-            grads = [mx.ones(10) for i in range(10)]
-            new_grads = average_gradients(grads)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 1)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-            n_calls = 0
-            new_grads = average_gradients(grads, all_reduce_size=0)
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 10)
-
-            n_calls = 0
-            xtype = mx.float16
-            new_grads = average_gradients(
-                grads, all_reduce_size=2 * 50, communication_type=mx.float16
-            )
-            mx.eval(new_grads)
-            self.assertEqual(len(new_grads), 10)
-            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
-            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
-            self.assertEqual(n_calls, 2)
-
-        finally:
-            mx.distributed.all_sum = original_all_sum
-
-    def test_donation(self):
-        x = mx.random.normal((1024,))
-        mx.eval(x)
-        mx.synchronize()
-
-        mx.reset_peak_memory()
-        scale = mx.array(2.0)
-        y = mx.distributed.all_sum(x)
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_only = mx.get_peak_memory()
-        y = mx.distributed.all_sum(x) * scale
-        mx.eval(y)
-        mx.synchronize()
-        all_sum_with_binary = mx.get_peak_memory()
-
-        self.assertEqual(all_sum_only, all_sum_with_binary)
-
 
 if __name__ == "__main__":
     unittest.main()
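Worth noting, since the moved test encodes the contract of average_gradients: ten float32 gradients of ten elements are 40 bytes each, 400 bytes total, so the default budget covers them in a single all_sum; all_reduce_size=4 * 50 = 200 bytes fits five gradients per call (two calls); all_reduce_size=0 disables batching (ten calls, one per gradient); and communication_type=mx.float16 halves each gradient to 20 bytes, so a 100-byte budget again yields two calls while the returned averages stay float32. A toy model of that counting, assuming greedy bucket-filling (n_all_sum_calls is a hypothetical helper, not part of mlx):

import mlx.core as mx

def n_all_sum_calls(grads, all_reduce_size, comm_dtype=mx.float32):
    # Predicts how many all_sum calls greedy bucketing would issue for
    # a given byte budget; a model of the test's expected counts, not
    # the actual average_gradients implementation.
    if all_reduce_size <= 0:
        return len(grads)  # batching disabled: one call per gradient
    calls, bucket_bytes = 0, 0
    for g in grads:
        nbytes = g.size * comm_dtype.size
        if bucket_bytes + nbytes > all_reduce_size and bucket_bytes > 0:
            calls += 1  # flush the full bucket
            bucket_bytes = 0
        bucket_bytes += nbytes
    return calls + (1 if bucket_bytes > 0 else 0)

grads = [mx.ones(10) for _ in range(10)]  # 10 gradients x 40 bytes
assert n_all_sum_calls(grads, 1 << 20) == 1  # any budget >= 400 B: one call
assert n_all_sum_calls(grads, 4 * 50) == 2  # 200 B buckets
assert n_all_sum_calls(grads, 0) == 10  # per-gradient calls
assert n_all_sum_calls(grads, 2 * 50, mx.float16) == 2  # 20 B per gradient

The deleted test_donation checks a separate property: the all_sum result is not referenced after the multiply, so its buffer can be donated to that binary op, and peak memory with the extra multiply must equal peak memory of the all_sum alone.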