mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Async eval (#972)
This commit is contained in:
@@ -93,7 +93,9 @@ void array::detach() {
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
if (!is_evaled()) {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
|
@@ -38,7 +38,13 @@ class Synchronizer : public Primitive {
|
||||
// are currently under a function transformation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
|
||||
void eval(std::vector<array> outputs) {
|
||||
std::shared_future<void> async_eval(std::vector<array> outputs) {
|
||||
static std::shared_future<void> global_synchronizer;
|
||||
// Catch up with previous async eval if needed
|
||||
if (global_synchronizer.valid()) {
|
||||
global_synchronizer.wait();
|
||||
}
|
||||
|
||||
std::function<void(const array&)> recurse;
|
||||
std::queue<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
@@ -152,8 +158,12 @@ void eval(std::vector<array> outputs) {
|
||||
scheduler::enqueue(stream, std::move(task));
|
||||
}
|
||||
}
|
||||
global_synchronizer = std::move(deps[synchronizer.id()]);
|
||||
return global_synchronizer;
|
||||
}
|
||||
|
||||
deps[synchronizer.id()].wait();
|
||||
void eval(std::vector<array> outputs) {
|
||||
async_eval(std::move(outputs)).wait();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
|
@@ -2,10 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::shared_future<void> async_eval(std::vector<array> outputs);
|
||||
|
||||
void eval(std::vector<array> outputs);
|
||||
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
|
Reference in New Issue
Block a user