mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:03:10 +08:00
Put the decision of the comm stream to the group
This commit is contained in:
parent
3bb6b1d44a
commit
eeb5a0d63f
@ -13,6 +13,10 @@ namespace mlx::core::distributed {
|
|||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
|
||||||
|
return group.raw_group()->communication_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
void all_sum(Group group, const array& input, array& output, Stream stream) {
|
void all_sum(Group group, const array& input, array& output, Stream stream) {
|
||||||
group.raw_group()->all_sum(input, output, stream);
|
group.raw_group()->all_sum(input, output, stream);
|
||||||
}
|
}
|
||||||
@ -39,6 +43,10 @@ void recv(Group group, array& out, int src, Stream stream) {
|
|||||||
|
|
||||||
class EmptyGroup : public GroupImpl {
|
class EmptyGroup : public GroupImpl {
|
||||||
public:
|
public:
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
|
@ -13,10 +13,15 @@ class GroupImpl {
|
|||||||
public:
|
public:
|
||||||
virtual ~GroupImpl() {}
|
virtual ~GroupImpl() {}
|
||||||
|
|
||||||
|
// Choose the stream this communication group can operate on
|
||||||
|
virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
|
||||||
|
|
||||||
|
// Group operations
|
||||||
virtual int rank() = 0;
|
virtual int rank() = 0;
|
||||||
virtual int size() = 0;
|
virtual int size() = 0;
|
||||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
||||||
|
|
||||||
|
// Actual communication operations
|
||||||
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
||||||
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
||||||
virtual void send(const array& input, int dst, Stream stream) = 0;
|
virtual void send(const array& input, int dst, Stream stream) = 0;
|
||||||
@ -25,6 +30,9 @@ class GroupImpl {
|
|||||||
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
virtual void all_min(const array& input, array& output, Stream stream) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* Define the MLX stream that the communication should happen in. */
|
||||||
|
Stream communication_stream(Group group, StreamOrDevice s = {});
|
||||||
|
|
||||||
/* Perform an all reduce sum operation */
|
/* Perform an all reduce sum operation */
|
||||||
void all_sum(Group group, const array& input, array& output, Stream stream);
|
void all_sum(Group group, const array& input, array& output, Stream stream);
|
||||||
|
|
||||||
|
@ -349,6 +349,10 @@ class MPIGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::cpu);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
if (rank_ < 0) {
|
if (rank_ < 0) {
|
||||||
mpi().rank(comm_, &rank_);
|
mpi().rank(comm_, &rank_);
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
@ -255,6 +256,10 @@ class NCCLGroup : public GroupImpl {
|
|||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::gpu);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
return rank_;
|
return rank_;
|
||||||
}
|
}
|
||||||
|
@ -4,18 +4,10 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cuda/cuda.h"
|
#include "mlx/backend/cuda/cuda.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
|
|
||||||
inline mlx::core::Device get_device() {
|
|
||||||
if (mlx::core::metal::is_available()) {
|
|
||||||
return mlx::core::Device::cpu;
|
|
||||||
} else if (mlx::core::cu::is_available()) {
|
|
||||||
return mlx::core::Device::gpu;
|
|
||||||
}
|
|
||||||
throw std::runtime_error("No available device for distributed operations.");
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -35,15 +27,16 @@ array all_sum(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Sum),
|
std::make_shared<AllReduce>(stream, group, AllReduce::Sum),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,15 +45,16 @@ array all_max(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Max),
|
std::make_shared<AllReduce>(stream, group, AllReduce::Max),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,15 +63,16 @@ array all_min(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllReduce>(to_stream(s, dev), group, AllReduce::Min),
|
std::make_shared<AllReduce>(stream, group, AllReduce::Min),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,11 +81,11 @@ array all_gather(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
auto result_shape = x.shape();
|
auto result_shape = x.shape();
|
||||||
if (result_shape.size() == 0) {
|
if (result_shape.size() == 0) {
|
||||||
@ -101,7 +96,7 @@ array all_gather(
|
|||||||
return array(
|
return array(
|
||||||
std::move(result_shape),
|
std::move(result_shape),
|
||||||
x.dtype(),
|
x.dtype(),
|
||||||
std::make_shared<AllGather>(to_stream(s, dev), group),
|
std::make_shared<AllGather>(stream, group),
|
||||||
{x});
|
{x});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,11 +106,11 @@ array send(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot send to a singleton group");
|
throw std::invalid_argument("Cannot send to a singleton group");
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
if (dst < 0 || dst >= group.size()) {
|
if (dst < 0 || dst >= group.size()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -125,10 +120,7 @@ array send(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
x.shape(),
|
x.shape(), x.dtype(), std::make_shared<Send>(stream, group, dst), {x});
|
||||||
x.dtype(),
|
|
||||||
std::make_shared<Send>(to_stream(s, dev), group, dst),
|
|
||||||
{x});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array recv(
|
array recv(
|
||||||
@ -138,11 +130,11 @@ array recv(
|
|||||||
std::optional<Group> group_ /* = std::nullopt */,
|
std::optional<Group> group_ /* = std::nullopt */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
auto group = to_group(group_);
|
auto group = to_group(group_);
|
||||||
auto dev = get_device();
|
|
||||||
|
|
||||||
if (group.size() == 1) {
|
if (group.size() == 1) {
|
||||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||||
}
|
}
|
||||||
|
auto stream = detail::communication_stream(group, s);
|
||||||
|
|
||||||
if (src < 0 || src >= group.size()) {
|
if (src < 0 || src >= group.size()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -153,7 +145,7 @@ array recv(
|
|||||||
return array(
|
return array(
|
||||||
std::move(shape),
|
std::move(shape),
|
||||||
std::move(dtype),
|
std::move(dtype),
|
||||||
std::make_shared<Recv>(to_stream(s, dev), group, src),
|
std::make_shared<Recv>(stream, group, src),
|
||||||
std::vector<array>{});
|
std::vector<array>{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -619,6 +619,10 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Stream communication_stream(StreamOrDevice s) override {
|
||||||
|
return to_stream(s, Device::cpu);
|
||||||
|
}
|
||||||
|
|
||||||
int rank() override {
|
int rank() override {
|
||||||
return rank_;
|
return rank_;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user