mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Use a dummy primitive to only sync with one output (#453)
* Use a dummy primitive to only sync with one output * Fix test and choose stream with slight care
This commit is contained in:
parent
41cc7bdfdb
commit
4bc446be08
@ -18,6 +18,17 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* This class is only meant to be used in eval
|
||||
* for synchronizing with the main thread. */
|
||||
class Synchronizer : public Primitive {
|
||||
public:
|
||||
explicit Synchronizer(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void print(std::ostream&) override {}
|
||||
};
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
@ -193,25 +204,24 @@ void eval(const std::vector<array>& outputs) {
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||
|
||||
std::set<std::uintptr_t> output_primitives;
|
||||
for (auto& arr : outputs) {
|
||||
if (!arr.is_evaled()) {
|
||||
output_primitives.insert(arr.primitive_id());
|
||||
// Make an effort to choose a good output stream
|
||||
Stream stream = default_stream(default_device());
|
||||
for (auto& o : outputs) {
|
||||
if (!o.is_evaled() && o.has_primitive()) {
|
||||
stream = o.primitive().stream();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto synchronizer =
|
||||
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
|
||||
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (auto in : a.inputs()) {
|
||||
// Pop fake outputs from the output set so we know who to synchronize
|
||||
// with at the end
|
||||
if (auto it = output_primitives.find(in.primitive_id());
|
||||
it != output_primitives.end()) {
|
||||
output_primitives.erase(it);
|
||||
}
|
||||
recurse(in);
|
||||
// If one of the inputs is being computed on a different
|
||||
// stream, we need to manage the dependency.
|
||||
@ -234,20 +244,9 @@ void eval(const std::vector<array>& outputs) {
|
||||
}
|
||||
};
|
||||
|
||||
// We have to store the output primitive ids because the arrays are
|
||||
// detached during eval and we need to use them for synchronization
|
||||
// at the end of this function
|
||||
std::vector<std::uintptr_t> output_primitive_ids;
|
||||
for (auto& arr : outputs) {
|
||||
if (!arr.is_evaled() || (!arr.is_tracer() && arr.has_primitive())) {
|
||||
recurse(arr);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert output dependencies
|
||||
for (auto pid : output_primitives) {
|
||||
deps.insert({pid, std::shared_future<void>{}});
|
||||
}
|
||||
recurse(synchronizer);
|
||||
uintptr_t synch_id = synchronizer.primitive_id();
|
||||
deps.insert({synch_id, std::shared_future<void>{}});
|
||||
|
||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||
while (!tape.empty()) {
|
||||
@ -305,9 +304,8 @@ void eval(const std::vector<array>& outputs) {
|
||||
scheduler::enqueue(stream, std::move(task));
|
||||
}
|
||||
}
|
||||
for (auto id : output_primitives) {
|
||||
deps[id].wait();
|
||||
}
|
||||
|
||||
deps[synch_id].wait();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
|
Loading…
Reference in New Issue
Block a user