2024-08-27 06:12:50 +08:00
|
|
|
// Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "mlx/distributed/distributed.h"
|
|
|
|
|
|
|
|
namespace mlx::core::distributed::detail {
|
|
|
|
|
2025-01-07 09:33:15 +08:00
|
|
|
/**
|
|
|
|
* Abstract base class of a distributed group implementation.
|
|
|
|
*/
|
|
|
|
class GroupImpl {
|
|
|
|
public:
|
2025-01-28 14:15:01 +08:00
|
|
|
virtual ~GroupImpl() {}
|
|
|
|
|
2025-01-07 09:33:15 +08:00
|
|
|
virtual int rank() = 0;
|
|
|
|
virtual int size() = 0;
|
|
|
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
|
|
|
|
2025-03-07 11:23:38 +08:00
|
|
|
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
|
|
|
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
|
|
|
virtual void send(const array& input, int dst, Stream stream) = 0;
|
|
|
|
virtual void recv(array& out, int src, Stream stream) = 0;
|
2025-01-07 09:33:15 +08:00
|
|
|
};
|
|
|
|
|
2024-08-27 06:12:50 +08:00
|
|
|
/* Perform an all reduce sum operation */
|
2025-03-07 11:23:38 +08:00
|
|
|
void all_sum(Group group, const array& input, array& output, Stream stream);
|
2024-08-27 06:12:50 +08:00
|
|
|
|
2024-08-27 14:01:37 +08:00
|
|
|
/* Perform an all gather operation */
|
2025-03-07 11:23:38 +08:00
|
|
|
void all_gather(Group group, const array& input, array& output, Stream stream);
|
2024-08-27 06:12:50 +08:00
|
|
|
|
2024-08-27 14:01:37 +08:00
|
|
|
/** Send an array to the dst rank */
|
2025-03-07 11:23:38 +08:00
|
|
|
void send(Group group, const array& input, int dst, Stream stream);
|
2024-08-27 14:01:37 +08:00
|
|
|
|
|
|
|
/** Recv an array from the src rank */
|
2025-03-07 11:23:38 +08:00
|
|
|
void recv(Group group, array& out, int src, Stream stream);
|
2024-08-27 14:01:37 +08:00
|
|
|
|
2024-08-27 06:12:50 +08:00
|
|
|
} // namespace mlx::core::distributed::detail
|