Async eval (#972)

This commit is contained in:
Awni Hannun
2024-04-09 18:34:00 -07:00
committed by GitHub
parent fffe072028
commit 99abb9eff4
5 changed files with 70 additions and 3 deletions

View File

@@ -93,7 +93,9 @@ void array::detach() {
}
void array::eval() {
mlx::core::eval({*this});
if (!is_evaled()) {
mlx::core::eval({*this});
}
}
bool array::is_tracer() const {

View File

@@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(std::vector<array> outputs) {
std::shared_future<void> async_eval(std::vector<array> outputs) {
static std::shared_future<void> global_synchronizer;
// Catch up with previous async eval if needed
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
scheduler::enqueue(stream, std::move(task));
}
}
global_synchronizer = std::move(deps[synchronizer.id()]);
return global_synchronizer;
}
deps[synchronizer.id()].wait();
void eval(std::vector<array> outputs) {
async_eval(std::move(outputs)).wait();
}
std::pair<std::vector<array>, std::vector<array>> vjp(

View File

@@ -2,10 +2,13 @@
#pragma once
#include <future>
#include "mlx/array.h"
namespace mlx::core {
std::shared_future<void> async_eval(std::vector<array> outputs);
void eval(std::vector<array> outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>