mlx/mlx/scheduler.cpp
Awni Hannun 3d405fb3b1
Add synchronize function (#1006)
* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
2024-04-22 08:25:46 -07:00

61 lines
1.6 KiB
C++

// Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
/// Return the scheduler's current default stream for device `d`.
///
/// @param d  Device whose default stream is requested.
/// @return   The default Stream registered with the scheduler for `d`.
/// @throws std::invalid_argument if `d` is the GPU and the Metal backend
///         reports itself unavailable.
Stream default_stream(Device d) {
  // Query the backend first, mirroring the original evaluation order.
  const bool gpu_backend_present = metal::is_available();
  if (!gpu_backend_present && d == Device::gpu) {
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.get_default_stream(d);
}
/// Make `s` the scheduler's default stream for its device.
///
/// @param s  Stream to install as the default for `s.device`.
/// @throws std::invalid_argument if `s` targets the GPU and the Metal
///         backend reports itself unavailable.
void set_default_stream(Stream s) {
  const bool gpu_backend_present = metal::is_available();
  if (!gpu_backend_present && s.device == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_stream] Cannot set gpu stream without gpu backend.");
  }
  scheduler::scheduler().set_default_stream(s);
}
/// Ask the scheduler to create a fresh stream on device `d`.
///
/// @param d  Device the new stream will run on.
/// @return   The newly created Stream.
/// @throws std::invalid_argument if `d` is the GPU and the Metal backend
///         reports itself unavailable.
Stream new_stream(Device d) {
  const bool gpu_backend_present = metal::is_available();
  if (!gpu_backend_present && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.new_stream(d);
}
/// Create a new stream on the current default device.
///
/// Delegates to new_stream(Device) rather than calling the scheduler
/// directly. The gpu-backend availability check therefore applies to this
/// overload as well; previously it was silently skipped here.
///
/// @return The newly created Stream.
/// @throws std::invalid_argument if the default device is the GPU and the
///         Metal backend reports itself unavailable.
Stream new_stream() {
  return new_stream(default_device());
}
/// Enqueue a marker task on stream `s` and block until that task has run.
///
/// For CPU streams the marker is a lambda that fulfils a promise. For any
/// other device, the task comes from metal::make_synchronize_task, which
/// receives the promise. NOTE(review): this acts as a full barrier only if
/// the stream runs tasks in enqueue order. TODO confirm against the
/// scheduler.
///
/// Uses future::wait() rather than get(), so an exception stored in the
/// promise (including broken_promise) is not rethrown here.
void synchronize(Stream s) {
  auto done = std::make_shared<std::promise<void>>();
  std::future<void> ready = done->get_future();

  const bool on_cpu = (s.device == Device::cpu);
  if (on_cpu) {
    scheduler::enqueue(s, [done = std::move(done)]() { done->set_value(); });
  } else {
    scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(done)));
  }

  ready.wait();
}
/// Synchronize with the default stream of the current default device.
///
/// @throws std::invalid_argument (via default_stream) if that device is the
///         GPU and the Metal backend reports itself unavailable.
void synchronize() {
  auto stream = default_stream(default_device());
  synchronize(stream);
}
namespace scheduler {

/// Access the single process-wide Scheduler that manages devices, streams,
/// and task execution.
///
/// The instance is a function-local static. It is constructed on the first
/// call (initialization is thread-safe since C++11) and lives until program
/// exit.
Scheduler& scheduler() {
  static Scheduler instance;
  return instance;
}

} // namespace scheduler
} // namespace mlx::core