mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
Put the decision of the comm stream to the group
This commit is contained in:
parent
3bb6b1d44a
commit
eeb5a0d63f
@ -13,6 +13,10 @@ namespace mlx::core::distributed {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Ask the group's backend implementation which stream communication should
// run on. The request `s` is forwarded unchanged; the choice is made entirely
// by GroupImpl::communication_stream.
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
  const auto& impl = group.raw_group();
  return impl->communication_stream(s);
}
|
||||
|
||||
// Forward an all-reduce sum to the group's backend implementation.
// `input`, `output` and `stream` are passed through untouched.
void all_sum(Group group, const array& input, array& output, Stream stream) {
  const auto& impl = group.raw_group();
  impl->all_sum(input, output, stream);
}
|
||||
@ -39,6 +43,10 @@ void recv(Group group, array& out, int src, Stream stream) {
|
||||
|
||||
class EmptyGroup : public GroupImpl {
|
||||
public:
|
||||
// Resolve `s` with to_stream, supplying no device argument of its own.
Stream communication_stream(StreamOrDevice s) override {
  Stream resolved = to_stream(s);
  return resolved;
}
|
||||
|
||||
// This group always reports rank 0.
int rank() override {
  constexpr int only_rank = 0;
  return only_rank;
}
|
||||
|
@ -3,7 +3,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
|
@ -13,10 +13,15 @@ class GroupImpl {
|
||||
public:
|
||||
virtual ~GroupImpl() {}
|
||||
|
||||
// Choose the stream this communication group can operate on
|
||||
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
|
||||
|
||||
// Group operations
|
||||
virtual int rank() = 0;
|
||||
virtual int size() = 0;
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
||||
|
||||
// Actual communication operations
|
||||
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
||||
virtual void send(const array& input, int dst, Stream stream) = 0;
|
||||
@ -25,6 +30,9 @@ class GroupImpl {
|
||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||
};
|
||||
|
||||
/* Define the MLX stream that the communication should happen in. */
|
||||
Stream communication_stream(Group group, StreamOrDevice s = {});
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_sum(Group group, const array& input, array& output, Stream stream);
|
||||
|
||||
|
@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve `s` with to_stream, passing Device::cpu as the device argument.
// NOTE(review): presumably Device::cpu is the fallback used when `s` does not
// already name a stream -- confirm against to_stream.
Stream communication_stream(StreamOrDevice s) override {
  Stream resolved = to_stream(s, Device::cpu);
  return resolved;
}
|
||||
|
||||
int rank() override {
|
||||
if (rank_ < 0) {
|
||||
mpi().rank(comm_, &rank_);
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
@ -255,6 +256,10 @@ class NCCLGroup : public GroupImpl {
|
||||
initialized_ = false;
|
||||
}
|
||||
|
||||
// Resolve `s` with to_stream, passing Device::gpu as the device argument.
// NOTE(review): presumably Device::gpu is the fallback used when `s` does not
// already name a stream -- confirm against to_stream.
Stream communication_stream(StreamOrDevice s) override {
  Stream resolved = to_stream(s, Device::gpu);
  return resolved;
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
@ -4,18 +4,10 @@
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
|
||||
// Pick the device used for distributed operations.
//
// Returns Device::cpu when the Metal backend reports itself available,
// otherwise Device::gpu when the CUDA backend reports itself available.
// Throws std::runtime_error when neither backend is available.
//
// NOTE(review): the Metal branch yields cpu rather than gpu; this looks
// intentional but is not explained here -- confirm before changing.
inline mlx::core::Device get_device() {
  namespace mx = mlx::core;

  // Metal is checked first; CUDA is only queried when Metal is unavailable.
  if (mx::metal::is_available()) {
    return mx::Device::cpu;
  }
  if (mx::cu::is_available()) {
    return mx::Device::gpu;
  }

  throw std::runtime_error("No available device for distributed operations.");
}
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
namespace {
|
||||
@ -35,15 +27,16 @@ array all_sum(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -52,15 +45,16 @@ array all_max(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -69,15 +63,16 @@ array all_min(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
|
||||
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -86,11 +81,11 @@ array all_gather(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
return x;
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
auto result_shape = x.shape();
|
||||
if (result_shape.size() == 0) {
|
||||
@ -101,7 +96,7 @@ array all_gather(
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<AllGather>(to_stream(s, dev), group),
|
||||
std::make_shared<AllGather>(stream, group),
|
||||
{x});
|
||||
}
|
||||
|
||||
@ -111,11 +106,11 @@ array send(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot send to a singleton group");
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
if (dst < 0 || dst >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
@ -125,10 +120,7 @@ array send(
|
||||
}
|
||||
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<Send>(to_stream(s, dev), group, dst),
|
||||
{x});
|
||||
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
|
||||
}
|
||||
|
||||
array recv(
|
||||
@ -138,11 +130,11 @@ array recv(
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
auto dev = get_device();
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||
}
|
||||
auto stream = detail::communication_stream(group, s);
|
||||
|
||||
if (src < 0 || src >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
@ -153,7 +145,7 @@ array recv(
|
||||
return array(
|
||||
std::move(shape),
|
||||
std::move(dtype),
|
||||
std::make_shared<Recv>(to_stream(s, dev), group, src),
|
||||
std::make_shared<Recv>(stream, group, src),
|
||||
std::vector<array>{});
|
||||
}
|
||||
|
||||
|
@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve `s` with to_stream, passing Device::cpu as the device argument.
// NOTE(review): presumably Device::cpu is the fallback used when `s` does not
// already name a stream -- confirm against to_stream.
Stream communication_stream(StreamOrDevice s) override {
  Stream resolved = to_stream(s, Device::cpu);
  return resolved;
}
|
||||
|
||||
int rank() override {
|
||||
return rank_;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user