mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Add docs for the distributed namespace (#1184)
This commit is contained in:

committed by
GitHub

parent
578842954c
commit
0163a8e57a
@@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
]
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_reduce_sum(x)
|
||||
y = mx.distributed.all_sum(x)
|
||||
self.assertTrue(mx.all(y == world.size()))
|
||||
|
||||
sub = world.split(world.rank() % 2)
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_reduce_sum(x, group=sub)
|
||||
y = mx.distributed.all_sum(x, group=sub)
|
||||
self.assertTrue(mx.all(y == sub.size()))
|
||||
|
||||
def test_all_gather(self):
|
||||
@@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
sub_2 = world.split(world.rank() % 2)
|
||||
|
||||
x = mx.ones((1, 8)) * world.rank()
|
||||
y = mx.distributed.all_reduce_sum(x, group=sub_1)
|
||||
y = mx.distributed.all_sum(x, group=sub_1)
|
||||
z = mx.distributed.all_gather(y, group=sub_2)
|
||||
z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
|
||||
|
||||
|
Reference in New Issue
Block a user