Fix eval in trace bugs (#612)

* Fix eval in trace bugs

* comment nit
This commit is contained in:
Awni Hannun
2024-02-02 09:57:12 -08:00
committed by GitHub
parent 506d43035c
commit cb6156d35d
4 changed files with 33 additions and 12 deletions

View File

@@ -180,7 +180,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
@@ -197,7 +197,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}

View File

@@ -63,7 +63,15 @@ std::function<void()> make_task(
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());

View File

@@ -1779,10 +1779,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator&&(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
}
return logical_and(a, b);
}
@@ -1797,11 +1793,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator||(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument(
"[operator||] is only supported for bool arrays.");
}
return logical_or(a, b);
}