mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix and cleanup event signal/wait for metal (#1765)
This commit is contained in:
parent
a4a2764a52
commit
252e423e81
@ -5,24 +5,18 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/event.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
void signal_and_wait(const array& in, const array& out, const Stream& s) {
|
void signal_and_wait(const array& in, const array& out) {
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
command_buffer->encodeSignalEvent(
|
encode_signal(in.event());
|
||||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
|
||||||
in.event().value());
|
|
||||||
}
|
}
|
||||||
command_buffer->encodeWait(
|
encode_wait(out.event());
|
||||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
|
||||||
out.event().value());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllReduce::eval_gpu(
|
void AllReduce::eval_gpu(
|
||||||
@ -58,7 +52,7 @@ void AllReduce::eval_gpu(
|
|||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
|
|
||||||
signal_and_wait(in, out, stream());
|
signal_and_wait(in, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllGather::eval_gpu(
|
void AllGather::eval_gpu(
|
||||||
@ -79,7 +73,7 @@ void AllGather::eval_gpu(
|
|||||||
out.event().signal();
|
out.event().signal();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
signal_and_wait(in, out, stream());
|
signal_and_wait(in, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Send::eval_gpu(
|
void Send::eval_gpu(
|
||||||
@ -104,14 +98,8 @@ void Send::eval_gpu(
|
|||||||
|
|
||||||
// Encode a signal event for the input but not a wait since we don't need to
|
// Encode a signal event for the input but not a wait since we don't need to
|
||||||
// wait on the output.
|
// wait on the output.
|
||||||
auto& s = stream();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
command_buffer->encodeSignalEvent(
|
encode_signal(in.event());
|
||||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
|
||||||
in.event().value());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,13 +121,7 @@ void Recv::eval_gpu(
|
|||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
|
|
||||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||||
auto& s = stream();
|
encode_wait(out.event());
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
command_buffer->encodeWait(
|
|
||||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
|
||||||
out.event().value());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::distributed
|
} // namespace mlx::core::distributed
|
||||||
|
@ -6,6 +6,26 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void encode_wait(Event e) {
|
||||||
|
auto& d = metal::device(e.stream().device);
|
||||||
|
d.end_encoding(e.stream().index);
|
||||||
|
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||||
|
command_buffer->encodeWait(
|
||||||
|
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||||
|
}
|
||||||
|
|
||||||
|
void encode_signal(Event e) {
|
||||||
|
auto& d = metal::device(e.stream().device);
|
||||||
|
d.end_encoding(e.stream().index);
|
||||||
|
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||||
|
command_buffer->encodeSignalEvent(
|
||||||
|
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||||
|
}
|
||||||
|
|
||||||
Event::Event(const Stream& stream) : stream_(stream) {
|
Event::Event(const Stream& stream) : stream_(stream) {
|
||||||
auto dtor = [](void* ptr) {
|
auto dtor = [](void* ptr) {
|
||||||
auto p = metal::new_scoped_memory_pool();
|
auto p = metal::new_scoped_memory_pool();
|
||||||
|
10
mlx/backend/metal/event.h
Normal file
10
mlx/backend/metal/event.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void encode_wait(Event e);
|
||||||
|
|
||||||
|
void encode_signal(Event e);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -10,6 +10,7 @@
|
|||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/event.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/slicing.h"
|
#include "mlx/backend/metal/slicing.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
@ -321,12 +322,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out.event().signal();
|
out.event().signal();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||||
auto& d = metal::device(stream().device);
|
encode_wait(out.event());
|
||||||
d.end_encoding(stream().index);
|
|
||||||
auto command_buffer = d.get_command_buffer(stream().index);
|
|
||||||
command_buffer->encodeWait(
|
|
||||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
|
||||||
out.event().value());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
Loading…
Reference in New Issue
Block a user