fix and cleanup event signal/wait for metal (#1765)

This commit is contained in:
Awni Hannun 2025-01-10 18:37:26 -08:00 committed by GitHub
parent a4a2764a52
commit 252e423e81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 32 deletions

View File

@ -5,24 +5,18 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
void signal_and_wait(const array& in, const array& out, const Stream& s) {
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
void signal_and_wait(const array& in, const array& out) {
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
encode_signal(in.event());
}
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
encode_wait(out.event());
}
void AllReduce::eval_gpu(
@ -58,7 +52,7 @@ void AllReduce::eval_gpu(
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
signal_and_wait(in, out);
}
void AllGather::eval_gpu(
@ -79,7 +73,7 @@ void AllGather::eval_gpu(
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out, stream());
signal_and_wait(in, out);
}
void Send::eval_gpu(
@ -104,14 +98,8 @@ void Send::eval_gpu(
// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
encode_signal(in.event());
}
}
@ -133,13 +121,7 @@ void Recv::eval_gpu(
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
encode_wait(out.event());
}
} // namespace mlx::core::distributed

View File

@ -6,6 +6,26 @@
namespace mlx::core {
// Encodes a wait on event `e` (at `e.value()`) into the command buffer that
// the Metal device returns for the event's stream.
//
// The event is taken by value and finally moved into an otherwise empty
// completed-handler, so the handler owns it -- presumably to keep the
// underlying event alive until the command buffer completes.
void encode_wait(Event e) {
  const auto& stream = e.stream();
  auto& device = metal::device(stream.device);

  // End encoding on this stream before asking for its command buffer.
  device.end_encoding(stream.index);
  auto cbuf = device.get_command_buffer(stream.index);

  // Read everything needed from `e` before it is moved into the handler.
  auto* mtl_event = static_cast<MTL::Event*>(e.raw_event().get());
  const auto wait_value = e.value();
  cbuf->encodeWait(mtl_event, wait_value);

  cbuf->addCompletedHandler(
      [held = std::move(e)](MTL::CommandBuffer*) {});
}
// Encodes a signal of event `e` (to `e.value()`) into the command buffer that
// the Metal device returns for the event's stream.
//
// The event is taken by value and finally moved into an otherwise empty
// completed-handler, so the handler owns it -- presumably to keep the
// underlying event alive until the command buffer completes.
void encode_signal(Event e) {
  const auto& stream = e.stream();
  auto& device = metal::device(stream.device);

  // End encoding on this stream before asking for its command buffer.
  device.end_encoding(stream.index);
  auto cbuf = device.get_command_buffer(stream.index);

  // Read everything needed from `e` before it is moved into the handler.
  auto* mtl_event = static_cast<MTL::Event*>(e.raw_event().get());
  const auto signal_value = e.value();
  cbuf->encodeSignalEvent(mtl_event, signal_value);

  cbuf->addCompletedHandler(
      [held = std::move(e)](MTL::CommandBuffer*) {});
}
Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();

10
mlx/backend/metal/event.h Normal file
View File

@ -0,0 +1,10 @@
// Copyright © 2024 Apple Inc.
#pragma once
// Helpers for encoding event waits/signals on the Metal backend.
// NOTE(review): `Event` is neither declared nor included in this header as
// shown; presumably every includer brings in its definition first -- confirm.
namespace mlx::core {
// Encode a wait on `e` into the command buffer of `e.stream()`.
// Takes the event by value.
void encode_wait(Event e);
// Encode a signal of `e` into the command buffer of `e.stream()`.
// Takes the event by value.
void encode_signal(Event e);
} // namespace mlx::core

View File

@ -10,6 +10,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
@ -321,12 +322,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.event().signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
auto& d = metal::device(stream().device);
d.end_encoding(stream().index);
auto command_buffer = d.get_command_buffer(stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
encode_wait(out.event());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {