mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 18:11:17 +08:00
Fix detach for multi-output primitives (#480)
This commit is contained in:
parent
78102a47ad
commit
135fd796d2
@ -82,6 +82,12 @@ array::array(
|
||||
}
|
||||
|
||||
void array::detach() {
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
|
@ -71,9 +71,6 @@ std::function<void()> make_task(
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
for (auto s : arr.siblings()) {
|
||||
s.detach();
|
||||
}
|
||||
}
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
|
@ -292,9 +292,6 @@ void eval(const std::vector<array>& outputs) {
|
||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
for (auto s : arr.siblings()) {
|
||||
s.detach();
|
||||
}
|
||||
}
|
||||
if (p) {
|
||||
p->set_value();
|
||||
|
Loading…
Reference in New Issue
Block a user