fix and cleanup event signal/wait for metal (#1765)

This commit is contained in:
Awni Hannun 2025-01-10 18:37:26 -08:00 committed by GitHub
parent a4a2764a52
commit 252e423e81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 32 deletions

View File

@ -5,24 +5,18 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h" #include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
namespace mlx::core::distributed { namespace mlx::core::distributed {
void signal_and_wait(const array& in, const array& out, const Stream& s) { void signal_and_wait(const array& in, const array& out) {
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) { if (in.event().valid()) {
command_buffer->encodeSignalEvent( encode_signal(in.event());
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
} }
command_buffer->encodeWait( encode_wait(out.event());
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
} }
void AllReduce::eval_gpu( void AllReduce::eval_gpu(
@ -58,7 +52,7 @@ void AllReduce::eval_gpu(
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream()); signal_and_wait(in, out);
} }
void AllGather::eval_gpu( void AllGather::eval_gpu(
@ -79,7 +73,7 @@ void AllGather::eval_gpu(
out.event().signal(); out.event().signal();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream()); signal_and_wait(in, out);
} }
void Send::eval_gpu( void Send::eval_gpu(
@ -104,14 +98,8 @@ void Send::eval_gpu(
// Encode a signal event for the input but not a wait since we don't need to // Encode a signal event for the input but not a wait since we don't need to
// wait on the output. // wait on the output.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) { if (in.event().valid()) {
command_buffer->encodeSignalEvent( encode_signal(in.event());
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
} }
} }
@ -133,13 +121,7 @@ void Recv::eval_gpu(
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a wait event as there is no input for the recv to encode a signal. // Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream(); encode_wait(out.event());
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
} }
} // namespace mlx::core::distributed } // namespace mlx::core::distributed

View File

@ -6,6 +6,26 @@
namespace mlx::core { namespace mlx::core {
// Encode a wait for event `e` (at `e.value()`) into the command buffer of the
// stream the event belongs to.
//
// The event is taken by value so it can be moved into the completion handler.
void encode_wait(Event e) {
  // Only read before `e` is moved into the handler below.
  const auto& stream = e.stream();
  auto& device = metal::device(stream.device);

  // End encoding on this stream before encoding directly into its command
  // buffer.
  device.end_encoding(stream.index);
  auto cbuf = device.get_command_buffer(stream.index);

  auto* mtl_event = static_cast<MTL::Event*>(e.raw_event().get());
  cbuf->encodeWait(mtl_event, e.value());

  // The handler body is intentionally empty; capturing the event presumably
  // exists only to keep it alive until the command buffer completes.
  cbuf->addCompletedHandler(
      [held = std::move(e)](MTL::CommandBuffer*) {});
}
// Encode a signal of event `e` (to `e.value()`) into the command buffer of the
// stream the event belongs to.
//
// The event is taken by value so it can be moved into the completion handler.
void encode_signal(Event e) {
  // Only read before `e` is moved into the handler below.
  const auto& stream = e.stream();
  auto& device = metal::device(stream.device);

  // End encoding on this stream before encoding directly into its command
  // buffer.
  device.end_encoding(stream.index);
  auto cbuf = device.get_command_buffer(stream.index);

  auto* mtl_event = static_cast<MTL::Event*>(e.raw_event().get());
  cbuf->encodeSignalEvent(mtl_event, e.value());

  // The handler body is intentionally empty; capturing the event presumably
  // exists only to keep it alive until the command buffer completes.
  cbuf->addCompletedHandler(
      [held = std::move(e)](MTL::CommandBuffer*) {});
}
Event::Event(const Stream& stream) : stream_(stream) { Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) { auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool(); auto p = metal::new_scoped_memory_pool();

10
mlx/backend/metal/event.h Normal file
View File

@ -0,0 +1,10 @@
// Copyright © 2024 Apple Inc.
#pragma once
// Helpers for encoding event waits/signals into a Metal command buffer.
//
// NOTE(review): `Event` is used here but this header includes nothing that
// declares it; presumably every includer pulls in the Event definition first
// -- confirm, or include it here so the header is self-contained.
namespace mlx::core {
// Encode a wait on `e` at its current value into the command buffer of the
// event's stream. Takes the event by value; the definition moves it into a
// completion handler attached to that command buffer.
void encode_wait(Event e);
// Encode a signal of `e` at its current value into the command buffer of the
// event's stream. Takes the event by value; the definition moves it into a
// completion handler attached to that command buffer.
void encode_signal(Event e);
} // namespace mlx::core

View File

@ -10,6 +10,7 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
@ -321,12 +322,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.event().signal(); out.event().signal();
}; };
scheduler::enqueue(io_stream(), std::move(signal_task)); scheduler::enqueue(io_stream(), std::move(signal_task));
auto& d = metal::device(stream().device); encode_wait(out.event());
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
} }
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {