mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 20:18:15 +08:00
Add docs for the distributed namespace (#1184)
This commit is contained in:

committed by
GitHub

parent
578842954c
commit
0163a8e57a
@@ -56,7 +56,7 @@ namespace detail {
|
||||
Stream communication_stream();
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_reduce_sum(Group group, const array& input, array& output);
|
||||
void all_sum(Group group, const array& input, array& output);
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_gather(Group group, const array& input, array& output);
|
||||
|
@@ -260,7 +260,7 @@ Stream communication_stream() {
|
||||
return comm_stream;
|
||||
}
|
||||
|
||||
void all_reduce_sum(Group group, const array& input_, array& output) {
|
||||
void all_sum(Group group, const array& input_, array& output) {
|
||||
array input = ensure_row_contiguous(input_);
|
||||
mpi().all_reduce(
|
||||
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
|
||||
|
@@ -31,7 +31,7 @@ Stream communication_stream() {
|
||||
return comm_stream;
|
||||
}
|
||||
|
||||
void all_reduce_sum(Group group, const array& input, array& output) {}
|
||||
void all_sum(Group group, const array& input, array& output) {}
|
||||
void all_gather(Group group, const array& input, array& output) {}
|
||||
|
||||
} // namespace detail
|
||||
|
@@ -17,7 +17,7 @@ Group to_group(std::optional<Group> group) {
|
||||
|
||||
} // namespace
|
||||
|
||||
array all_reduce_sum(const array& x, std::optional<Group> group_) {
|
||||
array all_sum(const array& x, std::optional<Group> group_) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
|
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_sum(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -24,7 +24,7 @@ void AllReduce::eval_cpu(
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
|
||||
distributed::detail::all_sum(group(), inputs[0], outputs[0]);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
@@ -36,7 +36,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
const std::vector<int>& axes) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {{all_reduce_sum(inputs[0], group())}, axes};
|
||||
return {{all_sum(inputs[0], group())}, axes};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
@@ -48,7 +48,7 @@ std::vector<array> AllReduce::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_reduce_sum(tangents[0], group())};
|
||||
return {all_sum(tangents[0], group())};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
|
Reference in New Issue
Block a user