mlx/mlx/scheduler.cpp
Awni Hannun c4230747a1
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00

65 lines
1.6 KiB
C++

// Copyright © 2023 Apple Inc.
#include "mlx/scheduler.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {

// Returns the default stream for device `d`, as tracked by the global
// scheduler.
// Throws std::invalid_argument if `d` is the GPU and metal::is_available()
// reports no GPU backend.
Stream default_stream(Device d) {
  if (!metal::is_available() && d == Device::gpu) {
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  return scheduler::scheduler().get_default_stream(d);
}

// Makes `s` the default stream (delegates to the global scheduler).
// Throws std::invalid_argument if `s` is on the GPU and metal::is_available()
// reports no GPU backend.
void set_default_stream(Stream s) {
  if (!metal::is_available() && s.device == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_stream] Cannot set gpu stream without gpu backend.");
  }
  scheduler::scheduler().set_default_stream(s);
}

// Returns the stream registered under `index` in the global scheduler.
Stream get_stream(int index) {
  return scheduler::scheduler().get_stream(index);
}

// Creates a new stream on device `d` via the global scheduler.
// Throws std::invalid_argument if `d` is the GPU and metal::is_available()
// reports no GPU backend.
Stream new_stream(Device d) {
  if (!metal::is_available() && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }
  return scheduler::scheduler().new_stream(d);
}

// Creates a new stream on the current default device.
// Routed through new_stream(Device) rather than calling the scheduler
// directly, so the GPU-availability check above is not bypassed.
Stream new_stream() {
  return new_stream(default_device());
}

// Blocks the calling thread until stream `s` is synchronized.
//
// CPU streams: a task that fulfils a promise is enqueued on `s`, and the
// caller waits on the matching future. This waits for earlier work only if
// the scheduler runs a stream's tasks in submission order.
// NOTE(review): that ordering is assumed here; confirm against scheduler.h.
// All other streams are handed to metal::synchronize.
void synchronize(Stream s) {
  if (s.device == mlx::core::Device::cpu) {
    // The promise is held by shared_ptr, which keeps the enqueued lambda
    // copyable (std::promise itself is move-only). Presumably the scheduler
    // stores tasks in a copy-requiring wrapper such as std::function.
    // TODO confirm.
    auto p = std::make_shared<std::promise<void>>();
    std::future<void> f = p->get_future();
    scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
    // wait() (not get()) only blocks until the shared state is ready; it
    // does not rethrow a stored exception.
    f.wait();
  } else {
    metal::synchronize(s);
  }
}

// Synchronizes the default stream of the current default device.
void synchronize() {
  synchronize(default_stream(default_device()));
}

namespace scheduler {

/** A singleton scheduler to manage devices, streams, and task execution. */
// Function-local static: constructed on first call. C++11 guarantees the
// initialization is thread-safe. Every caller receives the same instance.
Scheduler& scheduler() {
  static Scheduler scheduler;
  return scheduler;
}

} // namespace scheduler

} // namespace mlx::core