mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 18:51:14 +08:00

* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unnecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
39 lines
1.1 KiB
C++
39 lines
1.1 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <memory>

#include "mlx/distributed/distributed.h"
|
|
|
|
namespace mlx::core::distributed::detail {
|
|
|
|
/**
|
|
* Abstract base class of a distributed group implementation.
|
|
*/
|
|
class GroupImpl {
|
|
public:
|
|
virtual ~GroupImpl() {}
|
|
|
|
virtual int rank() = 0;
|
|
virtual int size() = 0;
|
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
|
|
|
|
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
|
|
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
|
|
virtual void send(const array& input, int dst, Stream stream) = 0;
|
|
virtual void recv(array& out, int src, Stream stream) = 0;
|
|
};
|
|
|
|
/* Perform an all reduce sum operation */
|
|
void all_sum(Group group, const array& input, array& output, Stream stream);
|
|
|
|
/* Perform an all gather operation */
|
|
void all_gather(Group group, const array& input, array& output, Stream stream);
|
|
|
|
/** Send an array to the dst rank */
|
|
void send(Group group, const array& input, int dst, Stream stream);
|
|
|
|
/** Recv an array from the src rank */
|
|
void recv(Group group, array& out, int src, Stream stream);
|
|
|
|
} // namespace mlx::core::distributed::detail
|