mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add docs for the distributed namespace (#1184)
This commit is contained in:

committed by
GitHub

parent
578842954c
commit
0163a8e57a
@@ -69,13 +69,13 @@ void init_distributed(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"all_reduce_sum",
|
||||
&distributed::all_reduce_sum,
|
||||
"all_sum",
|
||||
&distributed::all_sum,
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce sum.
|
||||
|
||||
|
@@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
]
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_reduce_sum(x)
|
||||
y = mx.distributed.all_sum(x)
|
||||
self.assertTrue(mx.all(y == world.size()))
|
||||
|
||||
sub = world.split(world.rank() % 2)
|
||||
for dt in dtypes:
|
||||
x = mx.ones((2, 2, 4), dtype=dt)
|
||||
y = mx.distributed.all_reduce_sum(x, group=sub)
|
||||
y = mx.distributed.all_sum(x, group=sub)
|
||||
self.assertTrue(mx.all(y == sub.size()))
|
||||
|
||||
def test_all_gather(self):
|
||||
@@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
sub_2 = world.split(world.rank() % 2)
|
||||
|
||||
x = mx.ones((1, 8)) * world.rank()
|
||||
y = mx.distributed.all_reduce_sum(x, group=sub_1)
|
||||
y = mx.distributed.all_sum(x, group=sub_1)
|
||||
z = mx.distributed.all_gather(y, group=sub_2)
|
||||
z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
|
||||
|
||||
|
Reference in New Issue
Block a user