mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
@@ -132,6 +132,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
84
mlx/backend/metal/distributed.cpp
Normal file
84
mlx/backend/metal/distributed.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out, const Stream s) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
}
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(
|
||||
group, in.data_shared_ptr() == nullptr ? out : in, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [in = in, out = out, group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
// TODO, consider committing the buffer and encoding a wait in the new
|
||||
// buffer rather than on the task thread
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@@ -122,4 +123,9 @@ NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
||||
|
18
mlx/distributed/distributed_impl.h
Normal file
18
mlx/distributed/distributed_impl.h
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
/* Return the communication stream. */
|
||||
Stream communication_stream();
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_sum(Group group, const array& input, array& output);
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_gather(Group group, const array& input, array& output);
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#define LOAD_SYMBOL(symbol, variable) \
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
|
@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {
|
||||
|
||||
} // namespace
|
||||
|
||||
array all_sum(const array& x, std::optional<Group> group_) {
|
||||
array all_sum(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
array all_gather(const array& x, std::optional<Group> group_) {
|
||||
array all_gather(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<AllGather>(group),
|
||||
std::make_shared<AllGather>(to_stream(s), group),
|
||||
{x});
|
||||
}
|
||||
|
||||
|
@@ -5,10 +5,17 @@
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
array all_sum(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_sum(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
array all_gather(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice S = {});
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -3,7 +3,6 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
|
@@ -3,20 +3,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
class DistPrimitive : public Primitive {
|
||||
public:
|
||||
DistPrimitive(Group group)
|
||||
: Primitive(detail::communication_stream()), group_(group) {}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error(
|
||||
"Communication primitives cannot be run on the GPU");
|
||||
}
|
||||
DistPrimitive(Stream stream, Group group)
|
||||
: Primitive(stream), group_(group) {}
|
||||
|
||||
const Group& group() const {
|
||||
return group_;
|
||||
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
|
||||
public:
|
||||
enum ReduceType { And, Or, Sum, Prod, Min, Max };
|
||||
|
||||
AllReduce(Group group, ReduceType reduce_type)
|
||||
: DistPrimitive(group), reduce_type_(reduce_type) {}
|
||||
AllReduce(Stream stream, Group group, ReduceType reduce_type)
|
||||
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
|
||||
|
||||
class AllGather : public DistPrimitive {
|
||||
public:
|
||||
AllGather(Group group) : DistPrimitive(group) {}
|
||||
AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
@@ -4,10 +4,10 @@
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "dtype.h"
|
||||
#include "stream.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
|
Reference in New Issue
Block a user