From b74bc5e634a01e66986b8ded42bb12b7531b45db Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 31 Jul 2025 03:55:29 -0700 Subject: [PATCH] [CUDA] Saving primitive inputs faster --- mlx/backend/cuda/eval.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index 0e1477e95..802068282 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -36,18 +36,17 @@ void eval(array& arr) { auto& encoder = cu::get_command_encoder(arr.primitive().stream()); // Keep used buffers alive until kernel finishes running. - std::unordered_set> buffers; for (auto& in : arr.inputs()) { - buffers.insert(in.data_shared_ptr()); + // Except for the donated one. + if (in.data_shared_ptr() != arr.data_shared_ptr()) { + encoder.add_temporary(in); + } } for (auto& s : arr.siblings()) { - buffers.insert(s.data_shared_ptr()); + if (s.data_shared_ptr() != arr.data_shared_ptr()) { + encoder.add_temporary(s); + } } - // Remove the output if it was donated to by an input. - if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { - buffers.erase(it); - } - encoder.add_completed_handler([buffers = std::move(buffers)]() {}); encoder.maybe_commit(); }