MPI ops in GPU stream for faster comms (#1356)

This commit is contained in:
Awni Hannun
2024-08-26 15:12:50 -07:00
committed by GitHub
parent 2fdf9eb535
commit 5f7d19d1f5
14 changed files with 220 additions and 26 deletions

View File

@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {
} // namespace
array all_sum(const array& x, std::optional<Group> group_) {
array all_sum(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(group, AllReduce::Sum),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
{x});
}
array all_gather(const array& x, std::optional<Group> group_) {
array all_gather(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(group),
std::make_shared<AllGather>(to_stream(s), group),
{x});
}