Move the decision of the communication stream into the group

This commit is contained in:
Angelos Katharopoulos 2025-08-20 18:21:07 -07:00
parent 3bb6b1d44a
commit eeb5a0d63f
7 changed files with 47 additions and 24 deletions

View File

@ -13,6 +13,10 @@ namespace mlx::core::distributed {
namespace detail {
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
  // Delegate the stream choice to the concrete group implementation,
  // which pins the stream to the device it communicates on.
  auto impl = group.raw_group();
  return impl->communication_stream(s);
}
void all_sum(Group group, const array& input, array& output, Stream stream) {
  // Thin forwarding shim: the backend-specific group performs the reduction.
  auto impl = group.raw_group();
  impl->all_sum(input, output, stream);
}
@ -39,6 +43,10 @@ void recv(Group group, array& out, int src, Stream stream) {
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
  // The empty (singleton) group does no real communication, so it does not
  // pin a device: resolve whatever stream/device the caller passed.
  auto stream = to_stream(s);
  return stream;
}
int rank() override {
// A singleton (empty) group has exactly one member, so its rank is always 0.
return 0;
}

View File

@ -3,7 +3,9 @@
#pragma once
#include <memory>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {

View File

@ -13,10 +13,15 @@ class GroupImpl {
public:
// Deleted through base pointers, so the destructor must be virtual.
virtual ~GroupImpl() {}
// Choose the stream this communication group can operate on.
// Each backend pins its own device (CPU for MPI/ring, GPU for NCCL);
// the EmptyGroup accepts any stream.
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
// Group operations
// Rank of the calling process within this group (0-based).
virtual int rank() = 0;
// Number of processes in this group.
virtual int size() = 0;
// Partition the group by `color`; processes sharing a color form a new
// subgroup, ordered by `key` (-1 presumably means "use current rank" —
// TODO confirm against backend implementations).
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
// Actual communication operations
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
@ -25,6 +30,9 @@ class GroupImpl {
virtual void all_min(const array& input, array& output, Stream stream) = 0;
};
/* Choose the MLX stream that the communication should happen in. */
Stream communication_stream(Group group, StreamOrDevice s = {});
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output, Stream stream);

View File

@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
  // MPI communication is pinned to the CPU device.
  auto stream = to_stream(s, Device::cpu);
  return stream;
}
int rank() override {
if (rank_ < 0) {
mpi().rank(comm_, &rank_);

View File

@ -17,6 +17,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/dtype_utils.h"
#include "mlx/utils.h"
namespace mlx::core::distributed::nccl {
@ -255,6 +256,10 @@ class NCCLGroup : public GroupImpl {
initialized_ = false;
}
Stream communication_stream(StreamOrDevice s) override {
  // NCCL communication is pinned to the GPU device.
  auto stream = to_stream(s, Device::gpu);
  return stream;
}
int rank() override {
// Return the rank cached in the member `rank_`.
return rank_;
}

View File

@ -4,18 +4,10 @@
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
// Pick the device for distributed operations: CPU when Metal is present
// (Metal builds pin communication to the CPU stream), GPU when CUDA is
// present. Throws std::runtime_error if neither backend is available.
inline mlx::core::Device get_device() {
  if (mlx::core::metal::is_available()) {
    return mlx::core::Device::cpu;
  }
  if (mlx::core::cu::is_available()) {
    return mlx::core::Device::gpu;
  }
  throw std::runtime_error("No available device for distributed operations.");
}
namespace mlx::core::distributed {
namespace {
@ -35,15 +27,16 @@ array all_sum(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
{x});
}
@ -52,15 +45,16 @@ array all_max(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
{x});
}
@ -69,15 +63,16 @@ array all_min(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
{x});
}
@ -86,11 +81,11 @@ array all_gather(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
return x;
}
auto stream = detail::communication_stream(group, s);
auto result_shape = x.shape();
if (result_shape.size() == 0) {
@ -101,7 +96,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(to_stream(s, dev), group),
std::make_shared<AllGather>(stream, group),
{x});
}
@ -111,11 +106,11 @@ array send(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (dst < 0 || dst >= group.size()) {
std::ostringstream msg;
@ -125,10 +120,7 @@ array send(
}
return array(
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s, dev), group, dst),
{x});
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
}
array recv(
@ -138,11 +130,11 @@ array recv(
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
auto dev = get_device();
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
}
auto stream = detail::communication_stream(group, s);
if (src < 0 || src >= group.size()) {
std::ostringstream msg;
@ -153,7 +145,7 @@ array recv(
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s, dev), group, src),
std::make_shared<Recv>(stream, group, src),
std::vector<array>{});
}

View File

@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
}
}
Stream communication_stream(StreamOrDevice s) override {
  // Ring communication is pinned to the CPU device.
  auto stream = to_stream(s, Device::cpu);
  return stream;
}
int rank() override {
// Return the rank cached in the member `rank_`.
return rank_;
}