mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {
 
 } // namespace
 
-array all_sum(const array& x, std::optional<Group> group_) {
+array all_sum(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(group, AllReduce::Sum),
+      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
       {x});
 }
 
-array all_gather(const array& x, std::optional<Group> group_) {
+array all_gather(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(group),
+      std::make_shared<AllGather>(to_stream(s), group),
       {x});
 }
 
|
||||
Reference in New Issue
Block a user