MPI ops in GPU stream for faster comms (#1356)

Awni Hannun authored on 2024-08-26 15:12:50 -07:00, committed by GitHub
parent 2fdf9eb535
commit 5f7d19d1f5
14 changed files with 220 additions and 26 deletions

mlx/distributed/primitives.h

@@ -3,20 +3,15 @@
 #pragma once
 #include "mlx/distributed/distributed.h"
-#include "mlx/distributed/distributed_impl.h"
 #include "mlx/primitives.h"
 namespace mlx::core::distributed {
 class DistPrimitive : public Primitive {
  public:
-  DistPrimitive(Group group)
-      : Primitive(detail::communication_stream()), group_(group) {}
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error(
-        "Communication primitives cannot be run on the GPU");
-  }
+  DistPrimitive(Stream stream, Group group)
+      : Primitive(stream), group_(group) {}
   const Group& group() const {
     return group_;
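Note on the hunk above: DistPrimitive previously pinned every collective to the dedicated detail::communication_stream() and rejected GPU evaluation outright; it now accepts the caller's Stream, forwards it to Primitive, and leaves eval_gpu to each subclass. As a rough sketch of a subclass built on the new base (the class name ExampleCollective is hypothetical, not part of this patch):

class ExampleCollective : public DistPrimitive {
 public:
  // Forward the caller-chosen stream through DistPrimitive to Primitive, so
  // the collective is scheduled on that stream rather than a fixed comm stream.
  ExampleCollective(Stream stream, Group group)
      : DistPrimitive(stream, group) {}

  // Both backends now declare an implementation; the base class no longer
  // throws for GPU evaluation.
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
};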
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
  public:
   enum ReduceType { And, Or, Sum, Prod, Min, Max };
-  AllReduce(Group group, ReduceType reduce_type)
-      : DistPrimitive(group), reduce_type_(reduce_type) {}
+  AllReduce(Stream stream, Group group, ReduceType reduce_type)
+      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
   std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
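With AllReduce now carrying a Stream and declaring eval_gpu, an op-level wrapper can thread the user's stream straight into the primitive. A minimal sketch, written as if inside this header's namespace and assuming the to_stream() helper and the array(shape, dtype, primitive, inputs) constructor from mlx::core; the function name all_sum_sketch is illustrative, not the API added by this patch:

// Build a sum all-reduce on whatever stream/device the caller picked.
array all_sum_sketch(const array& x, Group group, StreamOrDevice s = {}) {
  return array(
      x.shape(),   // an all-reduce preserves shape and dtype
      x.dtype(),
      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
      {x});
}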
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
 class AllGather : public DistPrimitive {
  public:
-  AllGather(Group group) : DistPrimitive(group) {}
+  AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
   std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
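AllGather follows the same pattern as AllReduce: its constructor gains the stream and eval_gpu becomes a real backend declaration instead of an inherited throw. For completeness, a hedged sketch of how a caller could now place the collectives on a particular stream, assuming the usual mlx::core default_stream() and Device helpers; the variable names are illustrative:

// Schedule the collective on the default GPU stream instead of a hidden,
// dedicated communication stream.
Stream s = default_stream(Device(Device::gpu));
auto summed = all_sum_sketch(x, group, s);  // sketch from the note above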