redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
This commit is contained in:
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions

View File

@@ -9,7 +9,9 @@
#include <unordered_map>
#include <unordered_set>
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/fence.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@@ -19,6 +21,8 @@
namespace mlx::core {
static constexpr int MAX_ACTIVE_TASKS = 100;
/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
@@ -43,9 +47,6 @@ int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
std::deque<array> tape;
// stream events to use for synchronization
std::unordered_map<uint32_t, Event> events;
// Make an effort to choose a good output stream
Stream stream = default_stream(default_device());
for (auto& o : outputs) {
@@ -55,13 +56,17 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
std::unordered_set<uintptr_t> needs_signal;
// Map of array id that needs fence and stream it's computed on
std::unordered_map<uintptr_t, uint32_t> needs_fence;
auto synchronizer = array(
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
needs_signal.insert(synchronizer.id());
// Make an event for the synchronizer stream
// Stream fences for inter-stream synchronization
std::unordered_map<uint32_t, Fence> fences;
// Stream events for synchronization after eval
std::unordered_map<uint32_t, Event> events;
events.emplace(stream.index, Event{stream});
{
@@ -78,11 +83,6 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
// Ignore arrays already scheduled
if (in.status() == array::Status::scheduled) {
continue;
}
if (in.status() == array::Status::unscheduled) {
if (async && in.is_tracer()) {
throw std::invalid_argument(
@@ -104,7 +104,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
"https://github.com/ml-explore/mlx/issues.");
}
if (a.primitive().stream() != in.primitive().stream()) {
needs_signal.insert(in.id());
needs_fence.emplace(in.id(), in.primitive().stream().index);
}
}
@@ -185,62 +185,81 @@ array eval_impl(std::vector<array> outputs, bool async) {
auto stream = arr.primitive().stream();
// Lookup corresponding event and increment counter
// Lookup corresponding event
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
e->second.set_value(1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}
// Set the status of the array and siblings.
arr.set_status(array::Status::scheduled);
for (auto& s : arr.siblings()) {
s.set_status(array::Status::scheduled);
for (auto& in : arr.inputs()) {
if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
// Use fence to wait within a single eval
// Get the input array's stream fence and wait on the
// output arrays stream
fences[it->second].wait(stream, in);
} else if (in.event().valid()) {
if (in.event().is_signaled()) {
in.detach_event();
} else if (in.event().stream() != stream) {
// Use event to wait across async eval
in.event().wait(stream);
}
}
}
std::vector<std::shared_future<void>> arr_deps;
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
if (arr.primitive().device() == Device::gpu) {
if (!metal::is_available()) {
throw std::runtime_error("Metal GPU is not available.");
}
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
metal::eval(arr);
} else {
auto task = [arr = std::move(arr), stream, signal]() mutable {
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
input.event().wait();
}
}
scheduler::notify_new_task(stream);
auto outputs = arr.outputs();
try {
arr.primitive().eval_cpu(arr.inputs(), outputs);
} catch (const std::exception& error) {
abort_with_exception(error);
}
if (!arr.is_tracer()) {
arr.detach();
}
for (auto& out : outputs) {
out.set_status(array::Status::available);
}
cpu::eval(arr);
}
if (signal) {
arr.event().signal();
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
metal::finalize(e.stream());
}
}
scheduler::wait_for_one();
}
scheduler::notify_task_completion(stream);
};
scheduler::enqueue(stream, std::move(task));
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
if (needs_fence.find(a.id()) != needs_fence.end()) {
auto it = fences.find(stream.index);
if (it == fences.end()) {
it = fences.emplace(stream.index, Fence{stream}).first;
}
it->second.update(stream, a);
}
};
arr.set_status(array::Status::evaluated);
// TODO Maybe always want the fence coherent kernel in the same cbuf
// as the other kernels?
maybe_update_fence(arr);
for (auto& sib : arr.siblings()) {
sib.set_status(array::Status::evaluated);
maybe_update_fence(sib);
}
if (!arr.is_tracer()) {
arr.detach();
}
}
// Signal the event in its stream
for (auto& [_, e] : events) {
auto s = e.stream();
e.signal(s);
if (s.device == Device::gpu) {
metal::finalize(s);
}
}
return synchronizer;
}