Fix detach for multi-output primitives (#480)

This commit is contained in:
Angelos Katharopoulos 2024-01-17 14:08:07 -08:00 committed by GitHub
parent 78102a47ad
commit 135fd796d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 6 deletions

View File

@ -82,6 +82,12 @@ array::array(
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;

View File

@ -71,9 +71,6 @@ std::function<void()> make_task(
[s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
p->set_value();
scheduler::notify_task_completion(s);

View File

@ -292,9 +292,6 @@ void eval(const std::vector<array>& outputs) {
arr.primitive().eval_cpu(arr.inputs(), outputs);
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
if (p) {
p->set_value();