MPI ops in GPU stream for faster comms (#1356)

Awni Hannun
2024-08-26 15:12:50 -07:00
committed by GitHub
parent 2fdf9eb535
commit 5f7d19d1f5
14 changed files with 220 additions and 26 deletions
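In short: distributed primitives no longer hard-wire themselves to the global communication stream; the caller picks the stream, so collectives can be scheduled alongside (and overlap with) compute. A minimal usage sketch of the new API, assuming an MPI-enabled build; the stream names and shapes here are illustrative, not from the diff:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto group = distributed::init();
  // Separate streams so the all-reduce can overlap with other GPU work
  // instead of serializing behind a single global communication stream.
  auto compute = new_stream(Device::gpu);
  auto comms = new_stream(Device::gpu);
  array w = ones({1024}, compute);
  // New third argument: the stream/device to run the collective on.
  array w_sum = distributed::all_sum(w, group, comms);
  eval({w_sum});
}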

View File

@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed::detail {
/* Return the communication stream. */
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
} // namespace mlx::core::distributed::detail
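For context, here is one way a primitive can consume these hooks: a simplified sketch of the Sum case in mlx/distributed/primitives.cpp (donation and non-contiguous inputs are ignored, and the allocation call reflects the API of this era):

#include <stdexcept>

#include "mlx/allocator.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/primitives.h"

namespace mlx::core::distributed {

void AllReduce::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Allocate the output, then hand input/output to the backend hook.
  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  switch (reduce_type_) {
    case Sum:
      detail::all_sum(group(), inputs[0], outputs[0]);
      break;
    default:
      throw std::runtime_error("Only Sum is sketched here");
  }
}

} // namespace mlx::core::distributed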

View File

@@ -5,6 +5,7 @@
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/scheduler.h"
#define LOAD_SYMBOL(symbol, variable) \

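The new scheduler.h include supports the dedicated communication stream. A plausible shape for that hook, offered as an assumption rather than the verbatim mpi.cpp body:

// Lazily create one CPU stream and hand it out for all MPI traffic,
// so communication work queues independently of compute streams.
Stream communication_stream() {
  static Stream comm_stream = new_stream(Device::cpu);
  return comm_stream;
}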
View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
namespace mlx::core::distributed {

View File

@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {
} // namespace
-array all_sum(const array& x, std::optional<Group> group_) {
+array all_sum(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
return array(
x.shape(),
x.dtype(),
-std::make_shared<AllReduce>(group, AllReduce::Sum),
+std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
{x});
}
-array all_gather(const array& x, std::optional<Group> group_) {
+array all_gather(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
return array(
std::move(result_shape),
x.dtype(),
-std::make_shared<AllGather>(group),
+std::make_shared<AllGather>(to_stream(s), group),
{x});
}
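The `to_stream(s)` calls above come from mlx/utils.h and resolve the new StreamOrDevice argument. As I understand it, StreamOrDevice is a std::variant over std::monostate, Stream, and Device; a small sketch of how each state resolves:

#include "mlx/utils.h"

using namespace mlx::core;

void stream_resolution() {
  // The `{}` default (monostate): default stream of the default device.
  Stream a = to_stream({});
  // A Device: that device's default stream.
  Stream b = to_stream(Device(Device::cpu));
  // An explicit Stream: passed through unchanged.
  Stream c = to_stream(new_stream(Device(Device::cpu)));
}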

View File

@@ -5,10 +5,17 @@
#include <optional>
#include "mlx/distributed/distributed.h"
#include "mlx/utils.h"
namespace mlx::core::distributed {
-array all_sum(const array& x, std::optional<Group> group = std::nullopt);
-array all_gather(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+array all_gather(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
} // namespace mlx::core::distributed
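Both declarations stay source-compatible: existing call sites compile unchanged because the new parameter defaults to `{}`. A quick sketch of the two call styles (group setup elided; Device::gpu assumes a GPU build):

#include "mlx/distributed/ops.h"

using namespace mlx::core;

array old_style(const array& x) {
  // Unchanged call: the stream resolves from the default device.
  return distributed::all_sum(x);
}

array new_style(const array& x) {
  // New call: pin the collective to an explicit device/stream.
  return distributed::all_gather(x, std::nullopt, Device::gpu);
}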

View File

@@ -3,7 +3,6 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"

View File

@@ -3,20 +3,15 @@
#pragma once
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/primitives.h"
namespace mlx::core::distributed {
class DistPrimitive : public Primitive {
public:
-  DistPrimitive(Group group)
-      : Primitive(detail::communication_stream()), group_(group) {}
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error(
-        "Communication primitives cannot be run on the GPU");
-  }
+  DistPrimitive(Stream stream, Group group)
+      : Primitive(stream), group_(group) {}
const Group& group() const {
return group_;
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
public:
enum ReduceType { And, Or, Sum, Prod, Min, Max };
-  AllReduce(Group group, ReduceType reduce_type)
-      : DistPrimitive(group), reduce_type_(reduce_type) {}
+  AllReduce(Stream stream, Group group, ReduceType reduce_type)
+      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
class AllGather : public DistPrimitive {
public:
-  AllGather(Group group) : DistPrimitive(group) {}
+  AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
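These eval_gpu declarations are what allow the primitives to live on a GPU stream at all. The Metal backend body is not shown in this diff; one plausible shape, as a hypothetical sketch only, is to defer to the CPU collective, which is workable on Apple silicon because unified memory makes GPU arrays host-visible:

// Hypothetical sketch; not the real Metal backend code.
void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // MPI operates on host memory, so reuse the CPU path once inputs
  // are available in host-visible (unified) memory.
  eval_cpu(inputs, outputs);
}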