diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index 0e1477e95..802068282 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -36,18 +36,17 @@ void eval(array& arr) { auto& encoder = cu::get_command_encoder(arr.primitive().stream()); // Keep used buffers alive until kernel finishes running. - std::unordered_set<std::shared_ptr<array::Data>> buffers; for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); + // Except for the donated one. + if (in.data_shared_ptr() != arr.data_shared_ptr()) { + encoder.add_temporary(in); + } } for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); + if (s.data_shared_ptr() != arr.data_shared_ptr()) { + encoder.add_temporary(s); + } } - // Remove the output if it was donated to by an input. - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - encoder.add_completed_handler([buffers = std::move(buffers)]() {}); encoder.maybe_commit(); }