Ring distributed backend (#1784)

This commit is contained in:
Angelos Katharopoulos
2025-01-27 22:15:01 -08:00
committed by GitHub
parent 2235dee906
commit ccb61d7aae
17 changed files with 1078 additions and 44 deletions

View File

@@ -3,6 +3,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
@@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) {
"init",
&mx::distributed::init,
"strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"),
"backend"_a = "any",
nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
R"pbdoc(
Initialize the communication backend and create the global communication group.
Example:
import mlx.core as mx
group = mx.distributed.init(backend="ring")
Args:
strict (bool, optional): If set to False it returns a singleton group
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Select a specific distributed backend to
initialize. If set to ``any`` then try all available backends and
return the first one that succeeds. Subsequent calls will return
the first backend that was initialized. Default: ``any``
Returns:
Group: The group representing all the launched processes.

View File

@@ -34,6 +34,8 @@ class TestDistributed(mlx_tests.MLXTestCase):
mx.int32,
mx.uint32,
mx.float32,
mx.float16,
mx.bfloat16,
mx.complex64,
]
for dt in dtypes:

View File

@@ -0,0 +1,61 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestRingDistributed(mlx_tests.MLXTestCase):
    """Tests for the ``ring`` backend of ``mx.distributed``.

    ``test_groups`` asserts a world size of 8, so this module is expected to
    be launched as eight cooperating processes.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize the ring backend once, before any test method runs.

        The tests below call ``mx.distributed.init()`` without arguments and
        rely on it returning the group that was initialized here.
        """
        # Only the side effect matters; the returned group is not kept.
        mx.distributed.init(strict=True, backend="ring")

    def test_groups(self):
        """Check world size/rank, repeated ``init`` calls and ``split``."""
        world = mx.distributed.init()
        self.assertEqual(world.size(), 8)
        self.assertTrue(0 <= world.rank() < 8)

        # A second init() must describe the same group.
        world2 = mx.distributed.init()
        self.assertEqual(world.size(), world2.size())
        self.assertEqual(world.rank(), world2.rank())

        # Splitting the group is expected to raise RuntimeError here.
        with self.assertRaises(RuntimeError):
            world.split(world.rank() % 2)

    def test_all_reduce(self):
        """Compare ``all_sum`` against a locally computed sum.

        Every dtype is paired with the maximum relative error tolerated for
        it, and each pair is checked over several array shapes.
        """
        world = mx.distributed.init()
        dtypes = [
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
            (mx.complex64, 1e-6),
        ]
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        # The same key is used in every process and iteration, so each rank
        # generates the same full input and can compute the reference locally.
        key = mx.random.key(0)
        for dt, rtol in dtypes:
            for sh in sizes:
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)
                # Each rank contributes only its own slice to the reduction.
                y = mx.distributed.all_sum(x[world.rank()])
                z = sum(
                    x[i] for i in range(world.size())
                )  # to ensure that we don't sum to int32
                # NOTE(review): an element of z equal to zero would make this
                # ratio undefined -- confirm it cannot happen for these inputs.
                maxrelerror = ((y - z).abs() / z.abs()).max()
                self.assertLessEqual(
                    maxrelerror, rtol, msg=f"dtype={dt}, shape={sh}"
                )
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Launch ring_test_distributed.py as eight local processes, one per rank.
#
# A temporary JSON hostfile with one 127.0.0.1 address per rank is written and
# handed to every process through MLX_HOSTFILE; MLX_RANK selects the rank.
# The script exits non-zero if any rank exits non-zero.

tmpfile=$(mktemp)
# Remove the temporary hostfile however the script terminates.
trap 'rm -f "$tmpfile"' EXIT

cat <<HOSTFILE >"$tmpfile"
[
["127.0.0.1:5000"],
["127.0.0.1:5001"],
["127.0.0.1:5002"],
["127.0.0.1:5003"],
["127.0.0.1:5004"],
["127.0.0.1:5005"],
["127.0.0.1:5006"],
["127.0.0.1:5007"]
]
HOSTFILE

# The test file lives next to this script; quote so paths with spaces work.
ring_test="$(dirname "${BASH_SOURCE[0]}")/ring_test_distributed.py"

pids=()
for i in {0..7}; do
    if ((i == 7)); then
        # The last rank is started one second after the others
        # (presumably so the other ranks are already up -- TODO confirm).
        sleep 1
    fi
    DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE="$tmpfile" MLX_RANK=$i python "$ring_test" &
    pids+=("$!")
done

# Wait on each rank individually: a bare `wait` always returns 0 and would
# hide test failures from the caller.
status=0
for pid in "${pids[@]}"; do
    wait "$pid" || status=1
done
exit "$status"