Adds send/recv ops in distributed (#1366)

This commit is contained in:
Angelos Katharopoulos
2024-08-26 23:01:37 -07:00
committed by GitHub
parent 1d94ac3f90
commit cdb59faea6
13 changed files with 345 additions and 19 deletions

View File

@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
DEFINE_PRINT(AllGather);
};
class Send : public DistPrimitive {
public:
Send(Stream stream, Group group, int dst)
: DistPrimitive(stream, group), dst_(dst) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_PRINT(Send);
private:
int dst_;
};
class Recv : public DistPrimitive {
public:
Recv(Stream stream, Group group, int src)
: DistPrimitive(stream, group), src_(src) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Recv);
private:
int src_;
};
} // namespace mlx::core::distributed