diff --git a/mlx/array.cpp b/mlx/array.cpp index 7f8ee92ec..207496286 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -82,6 +82,12 @@ array::array( } void array::detach() { + for (auto& s : array_desc_->siblings) { + s.array_desc_->inputs.clear(); + s.array_desc_->siblings.clear(); + s.array_desc_->position = 0; + s.array_desc_->primitive = nullptr; + } array_desc_->inputs.clear(); array_desc_->siblings.clear(); array_desc_->position = 0; diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 557397c4f..c120744b1 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -71,9 +71,6 @@ std::function make_task( [s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable { if (!arr.is_tracer()) { arr.detach(); - for (auto s : arr.siblings()) { - s.detach(); - } } p->set_value(); scheduler::notify_task_completion(s); diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 7e77ee210..08dd92e0a 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -292,9 +292,6 @@ void eval(const std::vector& outputs) { arr.primitive().eval_cpu(arr.inputs(), outputs); if (!arr.is_tracer()) { arr.detach(); - for (auto s : arr.siblings()) { - s.detach(); - } } if (p) { p->set_value();