mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
@@ -132,6 +132,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
84
mlx/backend/metal/distributed.cpp
Normal file
84
mlx/backend/metal/distributed.cpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out, const Stream s) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
}
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(
|
||||
group, in.data_shared_ptr() == nullptr ? out : in, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [in = in, out = out, group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
// TODO, consider committing the buffer and encoding a wait in the new
|
||||
// buffer rather than on the task thread
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user