mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
This commit is contained in:
@@ -26,6 +26,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
|
@@ -544,11 +544,12 @@ Device& device(mlx::core::Device) {
|
||||
return metal_device;
|
||||
}
|
||||
|
||||
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
auto dtor = [](void* ptr) {
|
||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||
};
|
||||
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
return std::unique_ptr<void, std::function<void(void*)>>(
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
|
30
mlx/backend/metal/event.cpp
Normal file
30
mlx/backend/metal/event.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
Event::Event(const Stream& stream) : stream_(stream) {
|
||||
auto dtor = [](void* ptr) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
static_cast<MTL::SharedEvent*>(ptr)->release();
|
||||
};
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
event_ = std::shared_ptr<void>(
|
||||
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
|
||||
->waitUntilSignaledValue(value(), -1)) {
|
||||
throw std::runtime_error("[Event::wait] Timed out");
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -55,17 +55,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
std::function<void()> make_task(array arr, bool signal) {
|
||||
auto task = [arr = std::move(arr), signal]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
// TODO, consider committing the buffer and encoding a wait in the new
|
||||
// buffer rather than on the task thread
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
@@ -88,13 +91,16 @@ std::function<void()> make_task(
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
if (p) {
|
||||
|
||||
if (signal) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(arr.event().raw_event().get()),
|
||||
arr.event().value());
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers), p = std::move(p)](
|
||||
[s, buffers = std::move(buffers), event = arr.event()](
|
||||
MTL::CommandBuffer* cbuf) {
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
|
@@ -2,9 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/stream.h"
|
||||
@@ -12,11 +10,9 @@
|
||||
namespace mlx::core::metal {
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
Reference in New Issue
Block a user