mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 09:18:11 +08:00
Comms (#1097)
* Start the communications branch using MPI * Add ops and primitives * Add python bindings for distributed
This commit is contained in:
committed by
GitHub
parent
0189ab6ab6
commit
50dfb664db
@@ -6,6 +6,7 @@ nanobind_add_module(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
|
||||
107
python/src/distributed.cpp
Normal file
107
python/src/distributed.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlx::core;
|
||||
|
||||
void init_distributed(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"distributed", "mlx.core.distributed: Communication operations");
|
||||
|
||||
nb::class_<distributed::Group>(
|
||||
m,
|
||||
"Group",
|
||||
R"pbcopy(
|
||||
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
|
||||
processes that can communicate.
|
||||
)pbcopy")
|
||||
.def("rank", &distributed::Group::rank, "Get the rank of this process")
|
||||
.def("size", &distributed::Group::size, "Get the size of the group")
|
||||
.def(
|
||||
"split",
|
||||
&distributed::Group::split,
|
||||
"color"_a,
|
||||
"key"_a = -1,
|
||||
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
|
||||
R"pbdoc(
|
||||
Split the group to subgroups based on the provided color.
|
||||
|
||||
Processes that use the same color go to the same group. The ``key``
|
||||
argument defines the rank in the new group. The smaller the key the
|
||||
smaller the rank. If the key is negative then the rank in the
|
||||
current group is used.
|
||||
|
||||
Args:
|
||||
color (int): A value to group processes into subgroups.
|
||||
key (int, optional): A key to optionally change the rank ordering
|
||||
of the processes.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"is_available",
|
||||
&distributed::is_available,
|
||||
R"pbdoc(
|
||||
Check if a communication backend is available.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"init",
|
||||
&distributed::init,
|
||||
R"pbdoc(
|
||||
Initialize the communication backend and create the global communication group.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"all_reduce_sum",
|
||||
&distributed::all_reduce_sum,
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
R"pbdoc(
|
||||
All reduce sum.
|
||||
|
||||
Sum the ``x`` arrays from all processes in the group.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
group (Group): The group of processes that will participate in the
|
||||
reduction. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
|
||||
Returns:
|
||||
array: The sum of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"all_gather",
|
||||
&distributed::all_gather,
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Gather arrays from all processes.
|
||||
|
||||
Gather the ``x`` arrays from all processes in the group and concatenate
|
||||
them along the first axis. The arrays should all have the same shape.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
group (Group): The group of processes that will participate in the
|
||||
gather. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
|
||||
Returns:
|
||||
array: The concatenation of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Conbright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
@@ -18,6 +18,7 @@ void init_fft(nb::module_&);
|
||||
void init_linalg(nb::module_&);
|
||||
void init_constants(nb::module_&);
|
||||
void init_fast(nb::module_&);
|
||||
void init_distributed(nb::module_&);
|
||||
|
||||
NB_MODULE(core, m) {
|
||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||
@@ -37,6 +38,7 @@ NB_MODULE(core, m) {
|
||||
init_linalg(m);
|
||||
init_constants(m);
|
||||
init_fast(m);
|
||||
init_distributed(m);
|
||||
|
||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user