mirror of https://github.com/ml-explore/mlx.git
MPI ops in GPU stream for faster comms (#1356)
This commit is contained in:
parent 2fdf9eb535
commit 5f7d19d1f5
benchmarks/python/distributed_bench.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()
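The benchmark above always runs the collective on the default stream. A hypothetical extension (not part of this commit) could compare explicit stream placement using the ``stream`` keyword that this change adds to ``mx.distributed.all_sum``; ``time_all_sum_on`` below is an illustrative helper, and ``mx.cpu`` / ``mx.gpu`` are the standard MLX device handles.

# Hypothetical follow-up benchmark: time a chain of all_sum calls pinned to an
# explicit stream via the new ``stream`` keyword (a sketch, not commit code).
import time

import mlx.core as mx


def time_all_sum_on(stream):
    world = mx.distributed.init()
    x = mx.random.uniform(shape=(4096,))
    mx.eval(x)

    tic = time.perf_counter()
    for _ in range(100):
        x = mx.distributed.all_sum(x, stream=stream)
    mx.eval(x)
    toc = time.perf_counter()

    if world.rank() == 0:
        print(f"{stream}: {1e3 * (toc - tic) / 100:.5f} msec per all_sum")


time_all_sum_on(mx.cpu)
time_all_sum_on(mx.gpu)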
@@ -132,6 +132,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
mlx/backend/metal/distributed.cpp (new file, 84 lines)
@@ -0,0 +1,84 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"

namespace mlx::core::distributed {

void signal_and_wait(const array& in, const array& out, const Stream s) {
  auto& d = metal::device(s.device);
  d.end_encoding(s.index);
  auto command_buffer = d.get_command_buffer(s.index);
  if (in.event().valid()) {
    command_buffer->encodeSignalEvent(
        static_cast<MTL::Event*>(in.event().raw_event().get()),
        in.event().value());
  }
  command_buffer->encodeWait(
      static_cast<MTL::Event*>(out.event().raw_event().get()),
      out.event().value());
}

void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];
  auto& out = outputs[0];
  if (in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto task = [in = in,
               out = out,
               reduce_type = reduce_type_,
               group = group()]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    switch (reduce_type) {
      case Sum:
        distributed::detail::all_sum(
            group, in.data_shared_ptr() == nullptr ? out : in, out);
        break;
      default:
        throw std::runtime_error("Only all reduce sum is supported for now");
    }
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  signal_and_wait(in, out, stream());
}

void AllGather::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto task = [in = in, out = out, group = group()]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    distributed::detail::all_gather(group, in, out);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
  signal_and_wait(in, out, stream());
}

} // namespace mlx::core::distributed
@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
    for (auto& input : arr.inputs()) {
      if (input.event().valid() &&
          input.event().stream() != arr.primitive().stream()) {
-        // TODO, consider committing the buffer and encoding a wait in the new
-        // buffer rather than on the task thread
        input.event().wait();
      }
    }
@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/primitives.h"
+#include "mlx/distributed/primitives.h"
#include "mlx/fast_primitives.h"

#define NO_GPU_MULTI(func) \

@@ -122,4 +123,9 @@
NO_GPU_MULTI(CustomKernel)
} // namespace fast

+namespace distributed {
+NO_GPU_MULTI(AllReduce)
+NO_GPU_MULTI(AllGather)
+} // namespace distributed
+
} // namespace mlx::core
mlx/distributed/distributed_impl.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/distributed/distributed.h"

namespace mlx::core::distributed::detail {

/* Return the communication stream. */
Stream communication_stream();

/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);

/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);

} // namespace mlx::core::distributed::detail
@@ -5,6 +5,7 @@

#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
#include "mlx/scheduler.h"

#define LOAD_SYMBOL(symbol, variable) \
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"

namespace mlx::core::distributed {
@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {

} // namespace

-array all_sum(const array& x, std::optional<Group> group_) {
+array all_sum(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {

@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
  return array(
      x.shape(),
      x.dtype(),
-      std::make_shared<AllReduce>(group, AllReduce::Sum),
+      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
      {x});
}

-array all_gather(const array& x, std::optional<Group> group_) {
+array all_gather(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);

  if (group.size() == 1) {

@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
  return array(
      std::move(result_shape),
      x.dtype(),
-      std::make_shared<AllGather>(group),
+      std::make_shared<AllGather>(to_stream(s), group),
      {x});
}
@@ -5,10 +5,17 @@
#include <optional>

#include "mlx/distributed/distributed.h"
+#include "mlx/utils.h"

namespace mlx::core::distributed {

-array all_sum(const array& x, std::optional<Group> group = std::nullopt);
-array all_gather(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+array all_gather(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});

} // namespace mlx::core::distributed
@@ -3,7 +3,6 @@
#include <cassert>

#include "mlx/allocator.h"
-#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
@@ -3,20 +3,15 @@
#pragma once

#include "mlx/distributed/distributed.h"
-#include "mlx/distributed/distributed_impl.h"
#include "mlx/primitives.h"

namespace mlx::core::distributed {

class DistPrimitive : public Primitive {
 public:
-  DistPrimitive(Group group)
-      : Primitive(detail::communication_stream()), group_(group) {}
+  DistPrimitive(Stream stream, Group group)
+      : Primitive(stream), group_(group) {}

-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error(
-        "Communication primitives cannot be run on the GPU");
-  }
-
  const Group& group() const {
    return group_;

@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
 public:
  enum ReduceType { And, Or, Sum, Prod, Min, Max };

-  AllReduce(Group group, ReduceType reduce_type)
-      : DistPrimitive(group), reduce_type_(reduce_type) {}
+  AllReduce(Stream stream, Group group, ReduceType reduce_type)
+      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {

class AllGather : public DistPrimitive {
 public:
-  AllGather(Group group) : DistPrimitive(group) {}
+  AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;

  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
@@ -4,10 +4,10 @@

#include <variant>

-#include "array.h"
-#include "device.h"
-#include "dtype.h"
-#include "stream.h"
+#include "mlx/array.h"
+#include "mlx/device.h"
+#include "mlx/dtype.h"
+#include "mlx/stream.h"

namespace mlx::core {
@@ -3,6 +3,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/variant.h>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

@@ -74,8 +75,9 @@ void init_distributed(nb::module_& parent_module) {
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
      nb::sig(
-          "def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        All reduce sum.

@@ -86,6 +88,8 @@ void init_distributed(nb::module_& parent_module) {
          group (Group): The group of processes that will participate in the
            reduction. If set to ``None`` the global group is used. Default:
            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.

        Returns:
          array: The sum of all ``x`` arrays.

@@ -97,8 +101,9 @@ void init_distributed(nb::module_& parent_module) {
      "x"_a,
      nb::kw_only(),
      "group"_a = nb::none(),
+      "stream"_a = nb::none(),
      nb::sig(
-          "def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Gather arrays from all processes.

@@ -110,6 +115,8 @@ void init_distributed(nb::module_& parent_module) {
          group (Group): The group of processes that will participate in the
            gather. If set to ``None`` the global group is used. Default:
            ``None``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``
+            in which case the default stream of the default device is used.

        Returns:
          array: The concatenation of all ``x`` arrays.
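For completeness, a minimal usage sketch of the keyword these bindings add, assuming an MPI-enabled build launched with mpirun as in the benchmark at the top:

# Minimal sketch of the new ``stream`` keyword (illustrative values only).
import mlx.core as mx

world = mx.distributed.init()
x = mx.ones((1024,)) * world.rank()

# Sequence the reduction in the GPU stream; the MPI call itself still runs on
# the communication thread and is synchronized with the GPU through events.
y = mx.distributed.all_sum(x, stream=mx.gpu)

# Or keep the collective on a CPU stream.
z = mx.distributed.all_gather(x, stream=mx.cpu)

mx.eval(y, z)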