mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Add synchronize function (#1006)
* add synchronize function * fix linux * fix linux * fix and fix docs * fix test * try synchronize in stream destroy * synchronize works for both cpu and gpu
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cstdlib>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -115,13 +114,31 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
return task;
|
||||
}
|
||||
|
||||
bool start_capture(std::string path, id object) {
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
return [s, p = std::move(p)]() {
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
if (cb == nullptr) {
|
||||
cb = d.new_command_buffer(s.index);
|
||||
} else {
|
||||
d.end_encoding(s.index);
|
||||
}
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
p->set_value();
|
||||
};
|
||||
}
|
||||
|
||||
void start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||
descriptor->setCaptureObject(object);
|
||||
|
||||
if (path.length() > 0) {
|
||||
if (!path.empty()) {
|
||||
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||
auto url = NS::URL::fileURLWithPath(string);
|
||||
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||
@@ -129,15 +146,24 @@ bool start_capture(std::string path, id object) {
|
||||
}
|
||||
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
return manager->startCapture(descriptor, nullptr);
|
||||
NS::Error* error;
|
||||
bool started = manager->startCapture(descriptor, &error);
|
||||
descriptor->release();
|
||||
if (!started) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::start_capture] Failed to start: "
|
||||
<< error->localizedDescription()->utf8String();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
bool start_capture(std::string path) {
|
||||
void start_capture(std::string path) {
|
||||
auto& device = metal::device(mlx::core::Device::gpu);
|
||||
return start_capture(path, device.mtl_device());
|
||||
}
|
||||
|
||||
void stop_capture() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
@@ -55,7 +55,7 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
bool start_capture(std::string path = "");
|
||||
void start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -15,4 +16,8 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal);
|
||||
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -22,6 +22,14 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||
}
|
||||
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
throw std::runtime_error(
|
||||
"[metal::make_synchronize_task] Cannot synchronize GPU"
|
||||
" without metal backend");
|
||||
}
|
||||
|
||||
// No-ops when Metal is not available.
|
||||
size_t get_active_memory() {
|
||||
return 0;
|
||||
@@ -38,9 +46,7 @@ size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
bool start_capture(std::string path) {
|
||||
return false;
|
||||
}
|
||||
void start_capture(std::string path) {}
|
||||
void stop_capture() {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -33,6 +33,21 @@ Stream new_stream() {
|
||||
return scheduler::scheduler().new_stream(default_device());
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
|
||||
} else {
|
||||
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||
}
|
||||
f.wait();
|
||||
}
|
||||
|
||||
void synchronize() {
|
||||
synchronize(default_stream(default_device()));
|
||||
}
|
||||
|
||||
namespace scheduler {
|
||||
|
||||
/** A singleton scheduler to manage devices, streams, and task execution. */
|
||||
|
@@ -27,6 +27,7 @@ struct StreamThread {
|
||||
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
||||
|
||||
~StreamThread() {
|
||||
synchronize(stream);
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mtx);
|
||||
stop = true;
|
||||
|
@@ -29,4 +29,10 @@ inline bool operator!=(const Stream& lhs, const Stream& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
/* Synchronize with the default stream. */
|
||||
void synchronize();
|
||||
|
||||
/* Synchronize with the provided stream. */
|
||||
void synchronize(Stream);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user