Add docs for the distributed namespace (#1184)

This commit is contained in:
Angelos Katharopoulos
2024-06-06 11:37:00 -07:00
committed by GitHub
parent 578842954c
commit 0163a8e57a
12 changed files with 202 additions and 15 deletions

View File

@@ -56,7 +56,7 @@ namespace detail {
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_reduce_sum(Group group, const array& input, array& output);
void all_sum(Group group, const array& input, array& output);
/* Perform an all reduce sum operation */
void all_gather(Group group, const array& input, array& output);

View File

@@ -260,7 +260,7 @@ Stream communication_stream() {
return comm_stream;
}
void all_reduce_sum(Group group, const array& input_, array& output) {
void all_sum(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
mpi().all_reduce(
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE

View File

@@ -31,7 +31,7 @@ Stream communication_stream() {
return comm_stream;
}
void all_reduce_sum(Group group, const array& input, array& output) {}
void all_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
} // namespace detail

View File

@@ -17,7 +17,7 @@ Group to_group(std::optional<Group> group) {
} // namespace
array all_reduce_sum(const array& x, std::optional<Group> group_) {
array all_sum(const array& x, std::optional<Group> group_) {
auto group = to_group(group_);
if (group.size() == 1) {

View File

@@ -8,7 +8,7 @@
namespace mlx::core::distributed {
array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
} // namespace mlx::core::distributed

View File

@@ -24,7 +24,7 @@ void AllReduce::eval_cpu(
switch (reduce_type_) {
case Sum:
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
distributed::detail::all_sum(group(), inputs[0], outputs[0]);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
@@ -36,7 +36,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<int>& axes) {
switch (reduce_type_) {
case Sum:
return {{all_reduce_sum(inputs[0], group())}, axes};
return {{all_sum(inputs[0], group())}, axes};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
@@ -48,7 +48,7 @@ std::vector<array> AllReduce::jvp(
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Sum:
return {all_reduce_sum(tangents[0], group())};
return {all_sum(tangents[0], group())};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}