mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00

* export and import functions * refactor + works for few primitives * nit * allow primitives with state * nit * nit * simplify serialize / deserialize * fix for constants * python bindings * maybe fix serialize failure case * add example * more primitives, training kind of works * same result for python and c++ * some fixes * fix export * template it up * some simplification * rebase * allow kwargs and multiple functions * exporter * more primitives for exporting * deal with endianness * handle invalid stream * add docstring
65 lines
1.6 KiB
C++
65 lines
1.6 KiB
C++
// Copyright © 2023 Apple Inc.
|
|
|
|
#include "mlx/scheduler.h"
|
|
#include "mlx/backend/metal/metal.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
/**
 * Look up the default stream of device `d` in the global scheduler.
 *
 * Throws std::invalid_argument when `d` is the gpu device and
 * metal::is_available() reports false.
 */
Stream default_stream(Device d) {
  // Happy path first: either the Metal backend is there, or the caller did
  // not ask for the gpu. The backend check is still evaluated first.
  if (metal::is_available() || !(d == Device::gpu)) {
    auto& sched = scheduler::scheduler();
    return sched.get_default_stream(d);
  }
  throw std::invalid_argument(
      "[default_stream] Cannot get gpu stream without gpu backend.");
}
|
|
|
|
/**
 * Register `s` with the global scheduler as a default stream.
 *
 * Throws std::invalid_argument when `s.device` is the gpu device and
 * metal::is_available() reports false.
 */
void set_default_stream(Stream s) {
  if (metal::is_available() || !(s.device == Device::gpu)) {
    // Forward to the singleton scheduler; nothing is returned to the caller.
    scheduler::scheduler().set_default_stream(s);
    return;
  }
  throw std::invalid_argument(
      "[set_default_stream] Cannot set gpu stream without gpu backend.");
}
|
|
|
|
/** Ask the global scheduler for the stream registered under `index`. */
Stream get_stream(int index) {
  auto& sched = scheduler::scheduler();
  return sched.get_stream(index);
}
|
|
|
|
/**
 * Ask the global scheduler to create a new stream on device `d`.
 *
 * Throws std::invalid_argument when `d` is the gpu device and
 * metal::is_available() reports false.
 */
Stream new_stream(Device d) {
  const bool metal_missing = !metal::is_available();
  if (metal_missing && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }

  auto& sched = scheduler::scheduler();
  return sched.new_stream(d);
}
|
|
|
|
/** Create a new stream on whatever default_device() currently returns. */
Stream new_stream() {
  // Obtain the scheduler before querying the default device, as the
  // single-expression form did.
  auto& sched = scheduler::scheduler();
  return sched.new_stream(default_device());
}
|
|
|
|
/**
 * Block the calling thread until a marker task enqueued on stream `s`
 * signals completion.
 *
 * A shared promise is handed to the enqueued task and this function waits on
 * the matching future.
 */
void synchronize(Stream s) {
  auto done = std::make_shared<std::promise<void>>();
  // Grab the future before ownership of the promise is moved into the task.
  auto finished = done->get_future();

  if (!(s.device == mlx::core::Device::cpu)) {
    // Non-cpu stream: the Metal backend builds the task around the promise
    // (presumably fulfilling it once the stream's work is done; see
    // metal::make_synchronize_task).
    scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(done)));
  } else {
    // Cpu stream: the task itself just fulfills the promise.
    scheduler::enqueue(s, [done = std::move(done)]() { done->set_value(); });
  }

  finished.wait();
}
|
|
|
|
void synchronize() {
|
|
synchronize(default_stream(default_device()));
|
|
}
|
|
|
|
namespace scheduler {
|
|
|
|
/** A singleton scheduler to manage devices, streams, and task execution. */
|
|
Scheduler& scheduler() {
|
|
static Scheduler scheduler;
|
|
return scheduler;
|
|
}
|
|
|
|
} // namespace scheduler
|
|
} // namespace mlx::core
|