mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix and cleanup event signal/wait for metal (#1765)
This commit is contained in:
parent
a4a2764a52
commit
252e423e81
@ -5,24 +5,18 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out, const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
void signal_and_wait(const array& in, const array& out) {
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
encode_signal(in.event());
|
||||
}
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
encode_wait(out.event());
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
@ -58,7 +52,7 @@ void AllReduce::eval_gpu(
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out, stream());
|
||||
signal_and_wait(in, out);
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
@ -79,7 +73,7 @@ void AllGather::eval_gpu(
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out, stream());
|
||||
signal_and_wait(in, out);
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
@ -104,14 +98,8 @@ void Send::eval_gpu(
|
||||
|
||||
// Encode a signal event for the input but not a wait since we don't need to
|
||||
// wait on the output.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
encode_signal(in.event());
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,13 +121,7 @@ void Recv::eval_gpu(
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
encode_wait(out.event());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -6,6 +6,26 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void encode_signal(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
Event::Event(const Stream& stream) : stream_(stream) {
|
||||
auto dtor = [](void* ptr) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
|
10
mlx/backend/metal/event.h
Normal file
10
mlx/backend/metal/event.h
Normal file
@ -0,0 +1,10 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e);
|
||||
|
||||
void encode_signal(Event e);
|
||||
|
||||
} // namespace mlx::core
|
@ -10,6 +10,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@ -321,12 +322,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||
auto& d = metal::device(stream().device);
|
||||
d.end_encoding(stream().index);
|
||||
auto command_buffer = d.get_command_buffer(stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
encode_wait(out.event());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
Loading…
Reference in New Issue
Block a user