mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
@@ -74,8 +75,9 @@ void init_distributed(nb::module_& parent_module) {
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
"def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce sum.
|
||||
|
||||
@@ -86,6 +88,8 @@ void init_distributed(nb::module_& parent_module) {
|
||||
group (Group, optional): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream or Device, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The sum of all ``x`` arrays.
|
||||
@@ -97,8 +101,9 @@ void init_distributed(nb::module_& parent_module) {
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
"def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Gather arrays from all processes.
|
||||
|
||||
@@ -110,6 +115,8 @@ void init_distributed(nb::module_& parent_module) {
|
||||
group (Group, optional): The group of processes that will participate in the
|
||||
gather. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream or Device, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The concatenation of all ``x`` arrays.
|
||||
|
Reference in New Issue
Block a user