mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Removes the retain_graph
flag (#385)
* Adds global tracing flag * Removes retain_graph in favor of is_tracer
This commit is contained in:

committed by
GitHub

parent
449b43762e
commit
a611b0bc82
@@ -46,39 +46,36 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
}
|
||||
|
||||
|
@@ -25,7 +25,6 @@ std::shared_ptr<void> new_scoped_memory_pool();
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph);
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
Reference in New Issue
Block a user