From 5f7d19d1f5dc7bdb6216f0e7043634ef69b448e6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 26 Aug 2024 15:12:50 -0700 Subject: [PATCH] MPI ops in GPU stream for faster comms (#1356) --- benchmarks/python/distributed_bench.py | 66 ++++++++++++++++++++ mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/distributed.cpp | 84 ++++++++++++++++++++++++++ mlx/backend/metal/metal.cpp | 2 - mlx/backend/no_metal/primitives.cpp | 6 ++ mlx/distributed/distributed_impl.h | 18 ++++++ mlx/distributed/mpi/mpi.cpp | 1 + mlx/distributed/no_distributed.cpp | 1 + mlx/distributed/ops.cpp | 14 +++-- mlx/distributed/ops.h | 11 +++- mlx/distributed/primitives.cpp | 1 - mlx/distributed/primitives.h | 22 +++---- mlx/utils.h | 8 +-- python/src/distributed.cpp | 11 +++- 14 files changed, 220 insertions(+), 26 deletions(-) create mode 100644 benchmarks/python/distributed_bench.py create mode 100644 mlx/backend/metal/distributed.cpp create mode 100644 mlx/distributed/distributed_impl.h diff --git a/benchmarks/python/distributed_bench.py b/benchmarks/python/distributed_bench.py new file mode 100644 index 000000000..b6fcef19e --- /dev/null +++ b/benchmarks/python/distributed_bench.py @@ -0,0 +1,66 @@ +# Copyright © 2024 Apple Inc. + +""" +Run with: + mpirun -n 2 python /path/to/distributed_bench.py +""" + +import time + +import mlx.core as mx + + +def time_fn(fn, *args, **kwargs): + msg = kwargs.pop("msg", None) + world = mx.distributed.init() + if world.rank() == 0: + if msg: + print(f"Timing {msg} ...", end=" ") + else: + print(f"Timing {fn.__name__} ...", end=" ") + + # warmup + for _ in range(5): + mx.eval(fn(*args, **kwargs)) + + num_iters = 100 + tic = time.perf_counter() + for _ in range(num_iters): + x = mx.eval(fn(*args, **kwargs)) + toc = time.perf_counter() + + msec = 1e3 * (toc - tic) / num_iters + if world.rank() == 0: + print(f"{msec:.5f} msec") + + +def time_all_sum(): + shape = (4096,) + x = mx.random.uniform(shape=shape) + mx.eval(x) + + def sine(x): + for _ in range(20): + x = mx.sin(x) + return x + + time_fn(sine, x) + + def all_sum_plain(x): + for _ in range(20): + x = mx.distributed.all_sum(x) + return x + + time_fn(all_sum_plain, x) + + def all_sum_with_sine(x): + for _ in range(20): + x = mx.sin(x) + x = mx.distributed.all_sum(x) + return x + + time_fn(all_sum_with_sine, x) + + +if __name__ == "__main__": + time_all_sum() diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 5c470e4cf..d337205da 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -132,6 +132,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp new file mode 100644 index 000000000..64f7c979c --- /dev/null +++ b/mlx/backend/metal/distributed.cpp @@ -0,0 +1,84 @@ +// Copyright © 2024 Apple Inc. 
+
+#include <cassert>
+
+#include "mlx/allocator.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/distributed/ops.h"
+#include "mlx/distributed/primitives.h"
+#include "mlx/scheduler.h"
+
+namespace mlx::core::distributed {
+
+void signal_and_wait(const array& in, const array& out, const Stream s) {
+  auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
+  auto command_buffer = d.get_command_buffer(s.index);
+  if (in.event().valid()) {
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::Event*>(in.event().raw_event().get()),
+        in.event().value());
+  }
+  command_buffer->encodeWait(
+      static_cast<MTL::Event*>(out.event().raw_event().get()),
+      out.event().value());
+}
+
+void AllReduce::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  }
+
+  auto task = [in = in,
+               out = out,
+               reduce_type = reduce_type_,
+               group = group()]() mutable {
+    if (in.event().valid()) {
+      in.event().wait();
+    }
+    switch (reduce_type) {
+      case Sum:
+        distributed::detail::all_sum(
+            group, in.data_shared_ptr() == nullptr ? out : in, out);
+        break;
+      default:
+        throw std::runtime_error("Only all reduce sum is supported for now");
+    }
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  signal_and_wait(in, out, stream());
+}
+
+void AllGather::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto task = [in = in, out = out, group = group()]() mutable {
+    if (in.event().valid()) {
+      in.event().wait();
+    }
+    distributed::detail::all_gather(group, in, out);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+  signal_and_wait(in, out, stream());
+}
+
+} // namespace mlx::core::distributed
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index 0bdb60ede..8755e73f1 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
     for (auto& input : arr.inputs()) {
       if (input.event().valid() &&
           input.event().stream() != arr.primitive().stream()) {
-        // TODO, consider committing the buffer and encoding a wait in the new
-        // buffer rather than on the task thread
         input.event().wait();
       }
     }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index ac9e6ca68..a98313e5d 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/primitives.h"
+#include "mlx/distributed/primitives.h"
 #include "mlx/fast_primitives.h"
 
 #define NO_GPU_MULTI(func) \
@@ -122,4 +123,9 @@ NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
 } // namespace fast
 
+namespace distributed {
+NO_GPU_MULTI(AllReduce)
+NO_GPU_MULTI(AllGather)
+} // namespace distributed
+
 } // namespace mlx::core
diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h
new file mode 100644
index 000000000..42fd5aac3
--- /dev/null
+++ b/mlx/distributed/distributed_impl.h
@@ -0,0 +1,18 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed::detail {
+
+/* Return the communication stream. */
+Stream communication_stream();
+
+/* Perform an all reduce sum operation */
+void all_sum(Group group, const array& input, array& output);
+
+/* Perform an all gather operation */
+void all_gather(Group group, const array& input, array& output);
+
+} // namespace mlx::core::distributed::detail
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index 4ea4d8573..5c60f9e97 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -5,6 +5,7 @@
 
 #include "mlx/backend/common/copy.h"
 #include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
 #include "mlx/scheduler.h"
 
 #define LOAD_SYMBOL(symbol, variable) \
diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp
index fcf346ad8..9c3e19227 100644
--- a/mlx/distributed/no_distributed.cpp
+++ b/mlx/distributed/no_distributed.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 #include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
 
 namespace mlx::core::distributed {
 
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 67f03cd31..54f8f483b 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -17,7 +17,10 @@ Group to_group(std::optional<Group> group) {
 
 } // namespace
 
-array all_sum(const array& x, std::optional<Group> group_) {
+array all_sum(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
@@ -27,11 +30,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
   return array(
       x.shape(),
       x.dtype(),
-      std::make_shared<AllReduce>(group, AllReduce::Sum),
+      std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
       {x});
 }
 
-array all_gather(const array& x, std::optional<Group> group_) {
+array all_gather(
+    const array& x,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
   auto group = to_group(group_);
 
   if (group.size() == 1) {
@@ -47,7 +53,7 @@ array all_gather(const array& x, std::optional<Group> group_) {
   return array(
       std::move(result_shape),
       x.dtype(),
-      std::make_shared<AllGather>(group),
+      std::make_shared<AllGather>(to_stream(s), group),
       {x});
 }
 
diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h
index bc3fab08d..85e9e99cc 100644
--- a/mlx/distributed/ops.h
+++ b/mlx/distributed/ops.h
@@ -5,10 +5,17 @@
 #include <optional>
 
 #include "mlx/distributed/distributed.h"
+#include "mlx/utils.h"
 
 namespace mlx::core::distributed {
 
-array all_sum(const array& x, std::optional<Group> group = std::nullopt);
-array all_gather(const array& x, std::optional<Group> group = std::nullopt);
+array all_sum(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+array all_gather(
+    const array& x,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
 
 } // namespace mlx::core::distributed
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index c4b786787..a115ea087 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -3,7 +3,6 @@
 #include <cassert>
 
 #include "mlx/allocator.h"
-#include "mlx/backend/common/copy.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/ops.h"
diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h
index 8107f4b12..4bf40b41c 100644
--- a/mlx/distributed/primitives.h
+++ b/mlx/distributed/primitives.h
@@ -3,20 +3,15 @@
 #pragma once
 
 #include "mlx/distributed/distributed.h"
+#include "mlx/distributed/distributed_impl.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core::distributed {
 
 class DistPrimitive : public Primitive {
  public:
-  DistPrimitive(Group group)
-      : Primitive(detail::communication_stream()), group_(group) {}
-
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error(
-        "Communication primitives cannot be run on the GPU");
-  }
+  DistPrimitive(Stream stream, Group group)
+      : Primitive(stream), group_(group) {}
 
   const Group& group() const {
     return group_;
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
  public:
   enum ReduceType { And, Or, Sum, Prod, Min, Max };
 
-  AllReduce(Group group, ReduceType reduce_type)
-      : DistPrimitive(group), reduce_type_(reduce_type) {}
+  AllReduce(Stream stream, Group group, ReduceType reduce_type)
+      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
   std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
 class AllGather : public DistPrimitive {
  public:
 
-  AllGather(Group group) : DistPrimitive(group) {}
+  AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
   std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
diff --git a/mlx/utils.h b/mlx/utils.h
index 9f39f5092..e536da55f 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -4,10 +4,10 @@
 
 #include <variant>
 
-#include "array.h"
-#include "device.h"
-#include "dtype.h"
-#include "stream.h"
+#include "mlx/array.h"
+#include "mlx/device.h"
+#include "mlx/dtype.h"
+#include "mlx/stream.h"
 
 namespace mlx::core {
 
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index e0a11a4fc..6395a642b 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -3,6 +3,7 @@
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/variant.h>
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
@@ -74,8 +75,9 @@ void init_distributed(nb::module_& parent_module) {
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
+      "stream"_a = nb::none(),
       nb::sig(
-          "def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         All reduce sum.
 
@@ -86,6 +88,8 @@ void init_distributed(nb::module_& parent_module) {
          group (Group): The group of processes that will participate in the
            reduction. If set to ``None`` the global group is used. Default:
            ``None``.
+         stream (Stream, optional): Stream or device. Defaults to ``None``
+           in which case the default stream of the default device is used.
 
        Returns:
          array: The sum of all ``x`` arrays.
@@ -97,8 +101,9 @@ void init_distributed(nb::module_& parent_module) {
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
+      "stream"_a = nb::none(),
       nb::sig(
-          "def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
+          "def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Gather arrays from all processes.
 
@@ -110,6 +115,8 @@ void init_distributed(nb::module_& parent_module) { group (Group): The group of processes that will participate in the gather. If set to ``None`` the global group is used. Default: ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. Returns: array: The concatenation of all ``x`` arrays.
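
Illustrative usage (not part of the patch): with this change, mx.distributed.all_sum and
mx.distributed.all_gather take an optional stream argument in addition to group, so the
communication can be scheduled on the GPU stream alongside other work. The snippet below
is a minimal sketch of that API, using the same setup as the benchmark script above; run
it with mpirun -n 2, as in distributed_bench.py.

    # Hypothetical usage sketch of the stream argument added by this patch.
    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.random.uniform(shape=(4096,))

    # Schedule the reduction on the default device (mx.gpu on a Metal build),
    # or pass stream=mx.cpu to keep the communication on the CPU stream.
    y = mx.distributed.all_sum(x, stream=mx.gpu)
    mx.eval(y)

    if world.rank() == 0:
        print(world.size(), y.dtype, y.shape)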