Ring distributed backend (#1784)

This commit is contained in:
Angelos Katharopoulos
2025-01-27 22:15:01 -08:00
committed by GitHub
parent 2235dee906
commit ccb61d7aae
17 changed files with 1078 additions and 44 deletions

View File

@@ -3,6 +3,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
@@ -58,14 +59,26 @@ void init_distributed(nb::module_& parent_module) {
"init",
&mx::distributed::init,
"strict"_a = false,
nb::sig("def init(strict: bool = False) -> Group"),
"backend"_a = "any",
nb::sig("def init(strict: bool = False, backend: str = 'any') -> Group"),
R"pbdoc(
Initialize the communication backend and create the global communication group.
Example:
import mlx.core as mx
group = mx.distributed.init(backend="ring")
Args:
strict (bool, optional): If set to False it returns a singleton group
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Select a specific distributed backend to
initialize. If set to ``any`` then try all available backends and
return the first one that succeeds. Subsequent calls will return
the first backend that was initialized. Default: ``any``
Returns:
Group: The group representing all the launched processes.