Fix eval in trace bugs (#612)

* Fix eval in trace bugs

* comment nit
This commit is contained in:
Awni Hannun
2024-02-02 09:57:12 -08:00
committed by GitHub
parent 506d43035c
commit cb6156d35d
4 changed files with 33 additions and 12 deletions

View File

@@ -63,7 +63,15 @@ std::function<void()> make_task(
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());