// Mirror of https://github.com/ml-explore/mlx.git
// (synced 2025-06-25 18:05:21 +08:00; 42 lines, 1.1 KiB, C++)
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include "mlx/distributed/distributed.h"
|
|
|
|
namespace mlx::core::distributed::detail {
|
|
|
|
/**
|
|
* Abstract base class of a distributed group implementation.
|
|
*/
|
|
class GroupImpl {
|
|
public:
|
|
virtual ~GroupImpl() {}
|
|
|
|
virtual int rank() = 0;
|
|
virtual int size() = 0;
|
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
|
|
|
virtual void all_sum(const array& input, array& output) = 0;
|
|
virtual void all_gather(const array& input, array& output) = 0;
|
|
virtual void send(const array& input, int dst) = 0;
|
|
virtual void recv(array& out, int src) = 0;
|
|
};
|
|
|
|
/* Return the communication stream. */
|
|
Stream communication_stream();
|
|
|
|
/* Perform an all reduce sum operation */
|
|
void all_sum(Group group, const array& input, array& output);
|
|
|
|
/* Perform an all gather operation */
|
|
void all_gather(Group group, const array& input, array& output);
|
|
|
|
/** Send an array to the dst rank */
|
|
void send(Group group, const array& input, int dst);
|
|
|
|
/** Recv an array from the src rank */
|
|
void recv(Group group, array& out, int src);
|
|
|
|
} // namespace mlx::core::distributed::detail
|