diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 98d484c10..517043597 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -5,24 +5,18 @@ #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/event.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" #include "mlx/scheduler.h" namespace mlx::core::distributed { -void signal_and_wait(const array& in, const array& out, const Stream& s) { - auto& d = metal::device(s.device); - d.end_encoding(s.index); - auto command_buffer = d.get_command_buffer(s.index); +void signal_and_wait(const array& in, const array& out) { if (in.event().valid()) { - command_buffer->encodeSignalEvent( - static_cast<MTL::Event*>(in.event().raw_event().get()), - in.event().value()); + encode_signal(in.event()); } - command_buffer->encodeWait( - static_cast<MTL::Event*>(out.event().raw_event().get()), - out.event().value()); + encode_wait(out.event()); } void AllReduce::eval_gpu( @@ -58,7 +52,7 @@ void AllReduce::eval_gpu( }; scheduler::enqueue(detail::communication_stream(), std::move(task)); - signal_and_wait(in, out, stream()); + signal_and_wait(in, out); } void AllGather::eval_gpu( @@ -79,7 +73,7 @@ void AllGather::eval_gpu( out.event().signal(); }; scheduler::enqueue(detail::communication_stream(), std::move(task)); - signal_and_wait(in, out, stream()); + signal_and_wait(in, out); } void Send::eval_gpu( @@ -104,14 +98,8 @@ void Send::eval_gpu( // Encode a signal event for the input but not a wait since we don't need to // wait on the output. 
- auto& s = stream(); - auto& d = metal::device(s.device); - d.end_encoding(s.index); - auto command_buffer = d.get_command_buffer(s.index); if (in.event().valid()) { - command_buffer->encodeSignalEvent( - static_cast<MTL::Event*>(in.event().raw_event().get()), - in.event().value()); + encode_signal(in.event()); } } @@ -133,13 +121,7 @@ void Recv::eval_gpu( scheduler::enqueue(detail::communication_stream(), std::move(task)); // Encode a wait event as there is no input for the recv to encode a signal. - auto& s = stream(); - auto& d = metal::device(s.device); - d.end_encoding(s.index); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->encodeWait( - static_cast<MTL::Event*>(out.event().raw_event().get()), - out.event().value()); + encode_wait(out.event()); } } // namespace mlx::core::distributed diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index be29d4533..e4b78985e 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -6,6 +6,26 @@ namespace mlx::core { +void encode_wait(Event e) { + auto& d = metal::device(e.stream().device); + d.end_encoding(e.stream().index); + auto command_buffer = d.get_command_buffer(e.stream().index); + command_buffer->encodeWait( + static_cast<MTL::Event*>(e.raw_event().get()), e.value()); + command_buffer->addCompletedHandler( + [e = std::move(e)](MTL::CommandBuffer* cbuf) {}); +} + +void encode_signal(Event e) { + auto& d = metal::device(e.stream().device); + d.end_encoding(e.stream().index); + auto command_buffer = d.get_command_buffer(e.stream().index); + command_buffer->encodeSignalEvent( + static_cast<MTL::Event*>(e.raw_event().get()), e.value()); + command_buffer->addCompletedHandler( + [e = std::move(e)](MTL::CommandBuffer* cbuf) {}); +} + Event::Event(const Stream& stream) : stream_(stream) { auto dtor = [](void* ptr) { auto p = metal::new_scoped_memory_pool(); diff --git a/mlx/backend/metal/event.h b/mlx/backend/metal/event.h new file mode 100644 index 000000000..b7f29e73d --- /dev/null +++ 
b/mlx/backend/metal/event.h @@ -0,0 +1,10 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +namespace mlx::core { + +void encode_wait(Event e); + +void encode_signal(Event e); + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index e12d67d5f..3a2491e2d 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -10,6 +10,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" @@ -321,12 +322,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) { out.event().signal(); }; scheduler::enqueue(io_stream(), std::move(signal_task)); - auto& d = metal::device(stream().device); - d.end_encoding(stream().index); - auto command_buffer = d.get_command_buffer(stream().index); - command_buffer->encodeWait( - static_cast<MTL::Event*>(out.event().raw_event().get()), - out.event().value()); + encode_wait(out.event()); } void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {