Add docs for the distributed namespace (#1184)

This commit is contained in:
Angelos Katharopoulos
2024-06-06 11:37:00 -07:00
committed by GitHub
parent 578842954c
commit 0163a8e57a
12 changed files with 202 additions and 15 deletions

View File

@@ -69,13 +69,13 @@ void init_distributed(nb::module_& parent_module) {
)pbdoc");
m.def(
"all_reduce_sum",
&distributed::all_reduce_sum,
"all_sum",
&distributed::all_sum,
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
nb::sig(
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
R"pbdoc(
All reduce sum.

View File

@@ -37,13 +37,13 @@ class TestDistributed(mlx_tests.MLXTestCase):
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_reduce_sum(x)
y = mx.distributed.all_sum(x)
self.assertTrue(mx.all(y == world.size()))
sub = world.split(world.rank() % 2)
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_reduce_sum(x, group=sub)
y = mx.distributed.all_sum(x, group=sub)
self.assertTrue(mx.all(y == sub.size()))
def test_all_gather(self):
@@ -87,7 +87,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
sub_2 = world.split(world.rank() % 2)
x = mx.ones((1, 8)) * world.rank()
y = mx.distributed.all_reduce_sum(x, group=sub_1)
y = mx.distributed.all_sum(x, group=sub_1)
z = mx.distributed.all_gather(y, group=sub_2)
z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)