mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch * load + more async CPU * use command encoder API and move more ops to use it * make fence back-end generic + CPU only fence * faster build * fix async eval * fixes + handle temporaries * fix / improve cpu conv * remove unused status, fix siblings * fix extensions * fix * fix no cpu build * format * comments * fix perf regression, remove unecessary abort * fix events, task limit cpu * fix waiting * fix donation / temporaries in normalization
This commit is contained in:
@@ -9,7 +9,9 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/backend/cpu/eval.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@@ -19,6 +21,8 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
static constexpr int MAX_ACTIVE_TASKS = 100;
|
||||
|
||||
/* This class is only meant to be used in eval
|
||||
* for synchronizing with the main thread. */
|
||||
class Synchronizer : public Primitive {
|
||||
@@ -43,9 +47,6 @@ int detail::RetainGraph::tracing_counter{0};
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::deque<array> tape;
|
||||
|
||||
// stream events to use for synchronization
|
||||
std::unordered_map<uint32_t, Event> events;
|
||||
|
||||
// Make an effort to choose a good output stream
|
||||
Stream stream = default_stream(default_device());
|
||||
for (auto& o : outputs) {
|
||||
@@ -55,13 +56,17 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<uintptr_t> needs_signal;
|
||||
// Map of array id that needs fence and stream it's computed on
|
||||
std::unordered_map<uintptr_t, uint32_t> needs_fence;
|
||||
|
||||
auto synchronizer = array(
|
||||
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
||||
needs_signal.insert(synchronizer.id());
|
||||
|
||||
// Make an event for the synchronizer stream
|
||||
// Stream fences for inter-stream synchronization
|
||||
std::unordered_map<uint32_t, Fence> fences;
|
||||
|
||||
// Stream events for synchronization after eval
|
||||
std::unordered_map<uint32_t, Event> events;
|
||||
events.emplace(stream.index, Event{stream});
|
||||
|
||||
{
|
||||
@@ -78,11 +83,6 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
// Add an input, and continue
|
||||
auto& in = a.inputs()[idx++];
|
||||
|
||||
// Ignore arrays already scheduled
|
||||
if (in.status() == array::Status::scheduled) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (in.status() == array::Status::unscheduled) {
|
||||
if (async && in.is_tracer()) {
|
||||
throw std::invalid_argument(
|
||||
@@ -104,7 +104,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
"https://github.com/ml-explore/mlx/issues.");
|
||||
}
|
||||
if (a.primitive().stream() != in.primitive().stream()) {
|
||||
needs_signal.insert(in.id());
|
||||
needs_fence.emplace(in.id(), in.primitive().stream().index);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,62 +185,81 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
|
||||
auto stream = arr.primitive().stream();
|
||||
|
||||
// Lookup corresponding event and increment counter
|
||||
// Lookup corresponding event
|
||||
auto e = events.find(stream.index);
|
||||
if (e == events.end()) {
|
||||
e = events.emplace(stream.index, Event{stream}).first;
|
||||
}
|
||||
e->second.set_value(e->second.value() + 1);
|
||||
e->second.set_value(1);
|
||||
arr.attach_event(e->second);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.attach_event(e->second);
|
||||
}
|
||||
|
||||
// Set the status of the array and siblings.
|
||||
arr.set_status(array::Status::scheduled);
|
||||
for (auto& s : arr.siblings()) {
|
||||
s.set_status(array::Status::scheduled);
|
||||
for (auto& in : arr.inputs()) {
|
||||
if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
|
||||
// Use fence to wait within a single eval
|
||||
// Get the input array's stream fence and wait on the
|
||||
// output arrays stream
|
||||
fences[it->second].wait(stream, in);
|
||||
} else if (in.event().valid()) {
|
||||
if (in.event().is_signaled()) {
|
||||
in.detach_event();
|
||||
} else if (in.event().stream() != stream) {
|
||||
// Use event to wait across async eval
|
||||
in.event().wait(stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_future<void>> arr_deps;
|
||||
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
|
||||
|
||||
if (arr.primitive().device() == Device::gpu) {
|
||||
if (!metal::is_available()) {
|
||||
throw std::runtime_error("Metal GPU is not available.");
|
||||
}
|
||||
scheduler::enqueue(stream, metal::make_task(std::move(arr), signal));
|
||||
metal::eval(arr);
|
||||
} else {
|
||||
auto task = [arr = std::move(arr), stream, signal]() mutable {
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
scheduler::notify_new_task(stream);
|
||||
auto outputs = arr.outputs();
|
||||
try {
|
||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||
} catch (const std::exception& error) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
out.set_status(array::Status::available);
|
||||
}
|
||||
cpu::eval(arr);
|
||||
}
|
||||
|
||||
if (signal) {
|
||||
arr.event().signal();
|
||||
if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
|
||||
// Commit any open streams
|
||||
for (auto& [_, e] : events) {
|
||||
if (e.stream().device == Device::gpu) {
|
||||
metal::finalize(e.stream());
|
||||
}
|
||||
}
|
||||
scheduler::wait_for_one();
|
||||
}
|
||||
|
||||
scheduler::notify_task_completion(stream);
|
||||
};
|
||||
scheduler::enqueue(stream, std::move(task));
|
||||
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
|
||||
if (needs_fence.find(a.id()) != needs_fence.end()) {
|
||||
auto it = fences.find(stream.index);
|
||||
if (it == fences.end()) {
|
||||
it = fences.emplace(stream.index, Fence{stream}).first;
|
||||
}
|
||||
it->second.update(stream, a);
|
||||
}
|
||||
};
|
||||
|
||||
arr.set_status(array::Status::evaluated);
|
||||
// TODO Maybe always want the fence coherent kernel in the same cbuf
|
||||
// as the other kernels?
|
||||
maybe_update_fence(arr);
|
||||
for (auto& sib : arr.siblings()) {
|
||||
sib.set_status(array::Status::evaluated);
|
||||
maybe_update_fence(sib);
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the event in its stream
|
||||
for (auto& [_, e] : events) {
|
||||
auto s = e.stream();
|
||||
e.signal(s);
|
||||
if (s.device == Device::gpu) {
|
||||
metal::finalize(s);
|
||||
}
|
||||
}
|
||||
|
||||
return synchronizer;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user