mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00
Compare commits
3 Commits
v0.12.1
...
priority-q
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5b46b9bc52 | ||
![]() |
fd94be28ea | ||
![]() |
9051fa1eaa |
@@ -18,6 +18,17 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/* This class is only meant to be used in eval
|
||||||
|
* for synchronizing with the main thread. */
|
||||||
|
class Synchronizer : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Synchronizer(Stream stream) : Primitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||||
|
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||||
|
void print(std::ostream&) override {}
|
||||||
|
};
|
||||||
|
|
||||||
// Initialize the static tracing counter from transforms_impl.h .
|
// Initialize the static tracing counter from transforms_impl.h .
|
||||||
//
|
//
|
||||||
// This is used to implement the in_tracing() function the returns true if we
|
// This is used to implement the in_tracing() function the returns true if we
|
||||||
@@ -188,31 +199,44 @@ void simplify(const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void eval(const std::vector<array>& outputs) {
|
void eval(const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&)> recurse;
|
std::function<int(const array&)> recurse;
|
||||||
std::queue<array> tape;
|
std::unordered_map<std::uintptr_t, int> cache;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
|
||||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||||
|
struct ArrayWithPriority {
|
||||||
|
int depth;
|
||||||
|
int order;
|
||||||
|
array x;
|
||||||
|
};
|
||||||
|
auto cmp = [](const ArrayWithPriority& a, const ArrayWithPriority& b) {
|
||||||
|
return (a.depth > b.depth) || (a.depth == b.depth && a.order > b.order);
|
||||||
|
};
|
||||||
|
std::priority_queue<
|
||||||
|
ArrayWithPriority,
|
||||||
|
std::vector<ArrayWithPriority>,
|
||||||
|
decltype(cmp)>
|
||||||
|
tape(cmp);
|
||||||
|
int order;
|
||||||
|
|
||||||
std::set<std::uintptr_t> output_primitives;
|
// Make an effort to choose a good output stream
|
||||||
for (auto& arr : outputs) {
|
Stream stream = default_stream(default_device());
|
||||||
if (!arr.is_evaled()) {
|
for (auto& o : outputs) {
|
||||||
output_primitives.insert(arr.primitive_id());
|
if (!o.is_evaled() && o.has_primitive()) {
|
||||||
|
stream = o.primitive().stream();
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto synchronizer =
|
||||||
|
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
|
||||||
|
|
||||||
recurse = [&](const array& a) {
|
recurse = [&](const array& a) {
|
||||||
auto id = a.id();
|
auto id = a.id();
|
||||||
if (cache.find(id) != cache.end()) {
|
if (auto it = cache.find(id); it != cache.end()) {
|
||||||
return;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
int input_depth = 0;
|
||||||
for (auto in : a.inputs()) {
|
for (auto in : a.inputs()) {
|
||||||
// Pop fake outputs from the output set so we know who to synchronize
|
input_depth = std::max(input_depth, recurse(in));
|
||||||
// with at the end
|
|
||||||
if (auto it = output_primitives.find(in.primitive_id());
|
|
||||||
it != output_primitives.end()) {
|
|
||||||
output_primitives.erase(it);
|
|
||||||
}
|
|
||||||
recurse(in);
|
|
||||||
// If one of the inputs is being computed on a different
|
// If one of the inputs is being computed on a different
|
||||||
// stream, we need to manage the dependency.
|
// stream, we need to manage the dependency.
|
||||||
if (!in.is_evaled()) {
|
if (!in.is_evaled()) {
|
||||||
@@ -221,37 +245,28 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.insert(id);
|
cache.insert({id, input_depth + 1});
|
||||||
for (auto& s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
cache.insert(s.id());
|
cache.insert({s.id(), input_depth + 1});
|
||||||
}
|
}
|
||||||
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
|
||||||
if (!a.has_primitive()) {
|
if (!a.has_primitive()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[eval] Attempting to eval an array without a primitive.");
|
"[eval] Attempting to eval an array without a primitive.");
|
||||||
}
|
}
|
||||||
tape.push(a);
|
tape.push({input_depth + 1, order++, a});
|
||||||
}
|
}
|
||||||
|
return input_depth + 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
// We have to store the output primitive ids because the arrays are
|
recurse(synchronizer);
|
||||||
// detached during eval and we need to use them for synchronization
|
uintptr_t synch_id = synchronizer.primitive_id();
|
||||||
// at the end of this function
|
deps.insert({synch_id, std::shared_future<void>{}});
|
||||||
std::vector<std::uintptr_t> output_primitive_ids;
|
|
||||||
for (auto& arr : outputs) {
|
|
||||||
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
|
||||||
recurse(arr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert output dependencies
|
|
||||||
for (auto pid : output_primitives) {
|
|
||||||
deps.insert({pid, std::shared_future<void>{}});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
auto arr = std::move(tape.front());
|
auto val = std::move(tape.top());
|
||||||
|
auto arr = std::move(val.x);
|
||||||
tape.pop();
|
tape.pop();
|
||||||
if (arr.is_evaled()) {
|
if (arr.is_evaled()) {
|
||||||
if (!arr.is_tracer() && arr.has_primitive()) {
|
if (!arr.is_tracer() && arr.has_primitive()) {
|
||||||
@@ -305,9 +320,8 @@ void eval(const std::vector<array>& outputs) {
|
|||||||
scheduler::enqueue(stream, std::move(task));
|
scheduler::enqueue(stream, std::move(task));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto id : output_primitives) {
|
|
||||||
deps[id].wait();
|
deps[synch_id].wait();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
Reference in New Issue
Block a user