MPI ops in GPU stream for faster comms (#1356)

This commit is contained in:
Awni Hannun
2024-08-26 15:12:50 -07:00
committed by GitHub
parent 2fdf9eb535
commit 5f7d19d1f5
14 changed files with 220 additions and 26 deletions

View File

@@ -3,6 +3,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
@@ -74,8 +75,9 @@ void init_distributed(nb::module_& parent_module) {
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce sum.
@@ -86,6 +88,8 @@ void init_distributed(nb::module_& parent_module) {
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The sum of all ``x`` arrays.
@@ -97,8 +101,9 @@ void init_distributed(nb::module_& parent_module) {
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
"def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Gather arrays from all processes.
@@ -110,6 +115,8 @@ void init_distributed(nb::module_& parent_module) {
group (Group): The group of processes that will participate in the
gather. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The concatenation of all ``x`` arrays.